In February 2026 I ran a bunch of experiments trying to figure out which open-source models work best for self-hosted inference on workstation hardware. I rented 2x RTX PRO 6000 Blackwell GPUs on RunPod, served models with vLLM, and used benchmarks like HumanEval+, GSM8K, IFEval, and MATH Hard as sanity checks — not to score the models, but to make sure quantization and inference were working correctly and to find the right parameters. This page has all the raw data from those experiments.

Related post: Running Open Models (2026-Feb)

Test Hardware
GPUs 2x RTX PRO 6000 Blackwell
Total VRAM 192 GB GDDR7
Compute SM120
Interconnect PCIe
Serving Engine vLLM 0.15–0.16
Platform RunPod
Required workarounds: NCCL_P2P_DISABLE=1 · --disable-custom-all-reduce · --enforce-eager (some configs)

Summary

These benchmark scores are for verifying that inference is working, not for ranking models against each other. Most of these benchmarks are saturated and don't reflect real-world capability. Take the numbers as “is vLLM producing correct output with this quantization?” rather than “which model is smarter?”

Quality Sanity Checks

IFEval (Instruction Following) — higher is better
Qwen3.5-122B FP8
86.9%
Qwen3.5-35B FP8
86.1%
Qwen3.5-122B NVFP4
85.6%
Qwen3.5-35B AWQ
85.4%
gpt-oss-120b
76.9%
MiniMax M2.5
37.2%
Perplexity — lower is better
Qwen3.5-122B FP8
7.07
Qwen3.5-122B NVFP4
7.48
Qwen3.5-35B FP8
9.59
MiniMax M2.5
12.45
gpt-oss-120b
30.25
MATH Hard — higher is better
Qwen3.5-122B NVFP4
57.3%
Qwen3.5-122B FP8
56.8%
gpt-oss-120b
55.5%
MiniMax M2.5
44.9%

Speed: Batched Throughput (tok/s)

Output tokens per second at rate=16 — higher is better
Qwen3.5-35B AWQ
2,817 tok/s
gpt-oss-120b
2,656 tok/s
Qwen3.5-35B FP8
2,623 tok/s
Qwen3.5-122B NVFP4
1,759 tok/s
MiniMax M2.5
1,688 tok/s
Qwen3.5-122B FP8
1,569 tok/s

Speed: Single-Stream Latency (tok/s)

True single-user experience at concurrency=1 (one request finishes before the next starts). This is the interactive speed you feel.

Output tokens per second at concurrency=1 — higher is better
Qwen3.5-122B NVFP4
102 tok/s (9.2ms/tok)
Qwen3.5-122B FP8
91 tok/s (10.4ms/tok)
gpt-oss-120b
62 tok/s (16.1ms/tok)

Concurrent Full-Context Sessions (128K tokens)

How many simultaneous 128K-token conversations fit in VRAM after loading model weights. Determined by remaining memory available for KV cache at gpu_memory_utilization=0.90.

Simultaneous 128K-token sessions that fit in 192 GB VRAM
Qwen3.5-35B AWQ
54 sessions
Qwen3.5-35B FP8
49 sessions
Qwen3.5-122B NVFP4
28 sessions
gpt-oss-120b
20 sessions
Qwen3.5-122B FP8
14 sessions
MiniMax M2.5
2 sessions
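The session counts above fall out of simple arithmetic: VRAM left after weights, divided by per-token KV size times context length. A minimal sketch of that calculation (the layer count, KV head count, and head dimension below are illustrative assumptions, not values measured from any of these models):

```python
# Rough estimate of how many full-context sessions fit in leftover VRAM.
# Architecture numbers used in the example are ILLUSTRATIVE, not measured.

def kv_bytes_per_token(num_layers: int, num_kv_heads: int, head_dim: int,
                       bytes_per_elem: int = 2) -> int:
    # K and V caches, per layer, per token (bytes_per_elem=2 assumes FP16/BF16 KV)
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem

def max_sessions(total_vram_gib: float, weights_gib: float,
                 context_len: int, per_token_bytes: int,
                 utilization: float = 0.90) -> int:
    # vLLM reserves total * gpu_memory_utilization; weights eat into that first
    budget_bytes = (total_vram_gib * utilization - weights_gib) * 1024**3
    return int(budget_bytes // (per_token_bytes * context_len))

# Made-up example: 48 layers, 8 KV heads, head dim 128, 72 GiB of weights
per_tok = kv_bytes_per_token(48, 8, 128)            # 196,608 bytes/token
print(max_sessions(192, 72, 131072, per_tok))       # → 4
```

The real numbers in the table come from vLLM's own KV-cache report at startup, which accounts for paged-attention block granularity and per-model KV layouts this sketch ignores.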

Full Results Table

| Model | Quant | IFEval | MATH Hard | Perplexity | 1-stream | rate=16 | GiB/GPU | 128K Sess. |
|---|---|---|---|---|---|---|---|---|
| Qwen3.5-122B-A10B | NVFP4 | 85.6% | 57.3% | 7.48 | 102 | 1,759 | 35.8 | 28 |
| Qwen3.5-122B-A10B | FP8 | 86.9% | 56.8% | 7.07 | 91 | 1,569 | 64 | 14 |
| Qwen3.5-35B-A3B | FP8 | 86.1% | — | 9.59 | — | 2,623 | — | 49 |
| Qwen3.5-35B-A3B | AWQ | 85.4% | — | — | — | 2,817 | — | 54 |
| gpt-oss-120b | MXFP4 | 76.9% | 55.5% | 30.25 | 62 | 2,656 | 32 | 20 |
| MiniMax M2.5 | NVFP4 | 37.2% | 44.9% | 12.45 | — | 1,688 | — | 2 |

Experiment Details

Project 1: Initial Pod Validation Feb 19

First time setting up the RunPod pod with these GPUs. Immediately ran into a hang — the RTX PRO 6000 reports SM120 compute capability, and vLLM's detection logic (120 // 10 = 12) doesn't recognize it as SM100 family. So a lot of optimized kernels just don't load: FlashInfer, VLLM_CUTLASS, some DeepGEMM variants — all gated behind is_device_capability_family(100). The kind of thing you only find out by running into it.

Required for all experiments

NCCL_P2P_DISABLE=1 and --disable-custom-all-reduce to get tensor parallelism working over PCIe at all.
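Concretely, every launch in these experiments carried those workaround flags; a sketch of the baseline invocation (the model name here is just an example):

```shell
# PCIe tensor parallelism on this pod needs NCCL P2P disabled and vLLM's
# custom all-reduce turned off; --enforce-eager was only needed for some configs.
NCCL_P2P_DISABLE=1 vllm serve Qwen/Qwen3.5-35B-A3B-FP8 \
    --tensor-parallel-size 2 \
    --disable-custom-all-reduce \
    --enforce-eager
```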

Project 2: MiniMax M2.5 vs Step 3.5 Flash Feb 19

Tried to get MiniMax M2.5 NVFP4 (~126 GB) and Step 3.5 Flash running. MiniMax loaded fine.

Step 3.5 Flash couldn't run at all on this hardware/vLLM combo:

  • NVFP4 — not supported for MoE models (vLLM issue #31782)
  • GGUF Q4 — vLLM dequantizes to BF16, exceeding VRAM
  • FP8 — weights alone exceed 192 GB capacity
Lesson

Not all quantization formats work with all architectures. NVFP4 + MoE support is incomplete in vLLM.

Project 3: MiniMax M2.5 vs Qwen3-235B bugs found Feb 20

This one was a mess. The MiniMax quality scores came out terribly wrong and it took a while to figure out why. Turned out to be three bugs stacking on top of each other.

Speed Results (Valid)

| Metric | MiniMax M2.5 | Qwen3-235B |
|---|---|---|
| Output tok/s | 66 | 42 |
| Throughput (rate=20) | 2.16 req/s | 1.64 req/s |

Original Quality (Broken)

| Benchmark | Qwen3-235B | MiniMax M2.5 |
|---|---|---|
| HumanEval+ | 94.5% | 49.7% |
| GSM8K | 88.2% | 54.8% |
| IFEval | 86.0% | 25.1% |
Bug 1 — Think-Tag Proxy Failure

MiniMax's chat template injects <think> as a forced prefix — it's not part of generated output. The model only generates </think>. The proxy regex required both tags, so it never matched. Reasoning content was never stripped. Zero opening tags found in 1,450+ GSM8K samples.
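To illustrate the failure mode, here is a strict regex next to a tolerant variant that handles a template-injected opening tag. This is my reconstruction of the bug, not the proxy's actual code:

```python
import re

# The strict pattern requires BOTH tags. When the chat template injects
# <think> as a prompt prefix, the model's output contains only </think>,
# so this never matches and reasoning content leaks through.
strict = re.compile(r"<think>.*?</think>", re.DOTALL)

def strip_reasoning(text: str) -> str:
    """Tolerant variant: if only a closing tag is present, drop everything
    up to and including it."""
    if "<think>" in text:
        return strict.sub("", text).strip()
    head, sep, tail = text.partition("</think>")
    return tail.strip() if sep else text.strip()

out = "step 1... step 2...</think>The answer is 42."
print(strict.sub("", out))    # unchanged: strict regex never matched
print(strip_reasoning(out))   # "The answer is 42."
```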

Bug 2 — Model Version Mismatch

Published benchmarks are for MiniMax-Text-01 (Jan 2025, 456B params). The model tested was MiniMax M2.5 (Feb 2026, reasoning/agentic model). M2.5 does not publish HumanEval, GSM8K, or IFEval scores.

Bug 3 — MMLU Evaluation Artifact

Used mmlu_generative instead of standard loglikelihood due to API constraints. Post-processing extraction succeeded at only 65–75%. Scores unreliable for both models.

After Fix

With the proxy fixed and max_gen_toks increased from 256 to 4096:

54.8%
GSM8K Before
broken proxy
95.75%
GSM8K After
+41 percentage points
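The rerun looked roughly like this (an illustrative lm-eval invocation: the model identifier is a placeholder, and exact flag support varies by harness version):

```shell
# Point lm-eval at the local vLLM OpenAI-compatible endpoint with a larger
# generation budget so thinking doesn't exhaust tokens before the answer.
lm_eval --model local-completions \
    --model_args base_url=http://localhost:8000/v1/completions,model=minimax-m2.5 \
    --tasks gsm8k \
    --gen_kwargs max_gen_toks=4096
```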
Project 4: OpenAI gpt-oss-120b on SM120 Feb 21

Tried running gpt-oss-120b (117B MoE, 128 experts, 4 active) on SM120. This model uses OpenAI's Harmony response format (reasoning_content field) instead of <think> tags, which avoids the whole think-tag mess.

Backend Compatibility on SM120

| Variant | Backend | Status |
|---|---|---|
| MXFP4 | Marlin | works |
| FP8-dynamic | Triton | garbage output |

FP8 backend cascade on SM120:

  1. FLASHINFER — SM100 only, unavailable
  2. DEEPGEMM — disabled or produces garbage
  3. TRITON — selected but produces garbage tokens
  4. MARLIN — crashes with thread config errors

Quality Results (MXFP4)

95.1%
GSM8K
76.3%
IFEval
45.7%
HumanEval+
46% extraction failure

Speed

44
tok/s single request
2,764
tok/s peak batched
rate=16

Working Configuration

NCCL_P2P_DISABLE=1 vllm serve openai/gpt-oss-120b \
    --tensor-parallel-size 2 \
    --enforce-eager \
    --tool-call-parser openai \
    --reasoning-effort low
Project 5: MiniMax M2.5 FP8-INT4-AWQ artifact Feb 21

Ran mratsim/MiniMax-M2.5-FP8-INT4-AWQ (~92.5 GB) through the full benchmark suite. Found a weird perplexity artifact baked into the AWQ weights.

Quality with Reasoning Parser

| Benchmark | With Parser | Without | Effect |
|---|---|---|---|
| HumanEval+ pass@1 | 29.9% | 23.2% | +38% relative |
| GSM8K | 96.59% | — | — |
| IFEval | 48.8% | — | — |
| MATH Hard | 45.8% | — | — |
| GPQA Diamond | 77.78% | — | — |
Perplexity Artifact

Measured perplexity: 362.56 (vs 12.46 for NVFP4 of the same model). Token 104256 dominates the logprob distribution. This is a weight quantization artifact — not a backend issue. Loglikelihood evaluations (MMLU, GPQA loglik) are unusable for this quantization.
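A cheap screen for this class of artifact is to request top logprobs for a probe completion and flag any distribution where a single token absorbs nearly all the probability mass. A sketch (the helper and the 0.99 threshold are mine, and the example values are made up):

```python
import math

def dominated(top_logprobs: dict, threshold: float = 0.99) -> bool:
    """Flag distributions where one token absorbs nearly all probability
    mass, the signature of the AWQ artifact described above."""
    probs = sorted((math.exp(lp) for lp in top_logprobs.values()), reverse=True)
    return probs[0] >= threshold

# Healthy vs. broken next-token distributions (values illustrative):
print(dominated({"Paris": -0.2, "Lyon": -2.5}))            # → False
print(dominated({"<tok104256>": -1e-5, "Paris": -12.0}))   # → True
```

Feed it the `top_logprobs` field from a `/v1/completions` response with `logprobs` enabled; a healthy model should almost never trip it on open-ended prompts.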

Project 6: B200 SM100 Comparison Feb 22

Ran the same models on a B200 (SM100) to see if the kernel issues were Blackwell-wide or specific to SM120. Turns out SM100 supports different and often better kernel paths. So the problems I hit are SM120-specific — workstation Blackwell, not data center Blackwell.

Takeaway

SM120 (workstation) has fewer optimized paths than SM100 (data center). If you're deploying on RTX PRO cards, expect to hit kernel issues that B200 users won't see.

Project 7: Marlin vs Triton MXFP4 Backends Feb 23

I found a vLLM PR (#31089) that enables Triton on SM120, so I tested it head-to-head against the stock Marlin backend using gpt-oss-120b. The Triton patch technically works, but performance collapses.

Output Throughput (tok/s)
Marlin
575.67
Triton
101.12
| Metric | Marlin | Triton | Ratio |
|---|---|---|---|
| Output tok/s | 575.67 | 101.12 | 5.7x |
| Mean TPOT | 16.64ms | 146.62ms | 8.8x slower |
| Median TTFT | 165.87ms | 11,079ms | 66.8x slower |
Verdict

Just use Marlin. The Triton SM120 path technically runs but it's not usable in practice.

Project 8: AWQ Logprob & Reasoning Parser A/B Feb 22

Question 1: Is the AWQ artifact backend-dependent?

| Configuration | Perplexity |
|---|---|
| Baseline (DeepGEMM) | 350.19 |
| No DeepGEMM + eager | 346.50 |
| FlashInfer | 345.51 |

All variants show ~350 perplexity. The artifact is in the model weights, not the backend.

Question 2: Reasoning parser per-benchmark effect

| Benchmark | With Parser | Without | Effect |
|---|---|---|---|
| HumanEval+ | 32.9% | 23.8% | +38% |
| GSM8K | 96.51% | 96.66% | Neutral |
| IFEval | 47.72% | 48.92% | Slight − |
| MATH Hard | 44.64% | 44.86% | Neutral |
Verdict

The reasoning parser helps a lot for code generation (+38% on HumanEval) but doesn't do much for anything else. Use it selectively.

Project 9: MoE vs Dense Architecture headline result Feb 24

This was the experiment I was most curious about. Same model family, two architectures: MoE (35B total, only 3B active per token) vs Dense (27B, all active). The MoE model worked way better than I expected.

Qwen3.5-35B-A3B

MoE
35B total · 3B active
vs

Qwen3.5-27B

Dense
27B total · 27B active

Speed Comparison

Both models served on 2x RTX PRO 6000 (TP=2) via vLLM. “rate=N” means N new requests arriving per second. Throughput is total output tokens per second across all in-flight requests. TPOT (time per output token) is per-request latency—how long a user waits between each streamed word. TTFT (time to first token) is how long until the first word appears after sending a prompt.

| Metric | MoE (35B-A3B) | Dense (27B) | MoE vs Dense |
|---|---|---|---|
| Throughput (rate=4: 4 req/s arriving) | 879.1 tok/s | 445.3 tok/s | 1.97x |
| Throughput (rate=16: 16 req/s arriving) | 2,332 tok/s | 1,369.8 tok/s | 1.70x |
| TPOT (rate=4: per-user streaming latency) | 26.66ms | 65.12ms | 2.4x lower |
| TTFT (rate=4: wait for first word) | 2,008ms | 26,792ms | 13.3x lower |

Simultaneous Full-Context Sessions

After loading weights, remaining VRAM determines how many 131K-token conversations can run at the same time. MoE activates fewer parameters per token, producing smaller KV cache entries and leaving more room for concurrent sessions.

131K-token sessions fitting in 192 GB VRAM
MoE (35B-A3B)
37 sessions
Dense (27B)
13 sessions
Bottom line

MoE with only 3B active parameters gives you about 2x the batched throughput and nearly 3x as many simultaneous conversations in the same VRAM. Despite having more total parameters, it wins on every metric that matters for serving.

Project 10: Qwen3.5 Thinking Mode Quality Fix fixes IFEval Feb 25

The quality scores from Project 9 were suspiciously bad (~29% IFEval, ~9% MATH Hard). Turns out thinking mode was silently corrupting all the benchmark results. I'm still not fully sure if this is a fundamental limitation of the benchmark tooling with reasoning models, or just a configuration problem — but disabling thinking fixed everything.

The Problem

  1. Chat template injects <think> as a prefix (not generated by the model)
  2. Model only generates </think> as a closing tag
  3. vLLM's --reasoning-parser qwen3 expects both tags → fails silently
  4. Thinking content remains in content field, corrupting extraction
  5. 96% of MATH samples had untagged reasoning that received no stripping
  6. Greedy decoding with 1024-token cap causes thinking to exhaust budget before answer

The Fix

--default-chat-template-kwargs '{"enable_thinking": false}'

Results

| Variant | Config | MATH Hard | IFEval |
|---|---|---|---|
| V0 (baseline) | qwen3 parser, proxy, greedy | 9.0% | 29.0% |
| V1 (no-think) | enable_thinking=False, greedy | 58.2% | 84.8% |
| V3 (no-think-warm) | enable_thinking=False, temp=0.7 | 57.5% | 86.1% |
| V2 (deepseek_r1) | deepseek_r1 parser, proxy | 0.3% | 7.9% |
| V4 (full budget) | deepseek_r1 parser, 32768 tokens | 70.5% | 8.9% |
6.5x
MATH Hard improvement
9% → 58.2%
2.9x
IFEval improvement
29% → 84.8%
Project 11: Qwen3.5 Vision — 1 GPU vs 2 GPU Feb 27

Good to know: Qwen3.5-35B-A3B handles images natively (early-fusion, same weights, no separate adapter). I tested vision on 1 GPU vs 2 to see if splitting helps.

Text Quality (No Degradation)

| Benchmark | 1-GPU | 2-GPU |
|---|---|---|
| IFEval (strict) | 85.03% | 85.58% |
| MATH Hard | 57.40% | 56.80% |

Vision Quality

| Benchmark | 1-GPU | 2-GPU |
|---|---|---|
| MMMU (val) | 52.67% | 51.78% |
| MMBench EN | 87.88% | 87.88% |

Speed

| Metric | What It Measures | 1-GPU | 2-GPU | Speedup |
|---|---|---|---|---|
| Text throughput | Output tokens generated per second across concurrent text requests (rate=inf) | 1,119 tok/s | 1,630 tok/s | 1.46x |
| Vision sequential | Image+text requests completed per second, one at a time (20 requests) | 1.91 req/s | 2.24 req/s | 1.17x |
| Vision concurrent | Image+text requests completed per second with 8 in-flight at once (50 requests) | 3.44 req/s | 4.37 req/s | 1.27x |
| Vision first-request | Time from sending the first image+text request to receiving a complete response | 4.14s | 0.46s | 9x lower latency |
Verdict

Two GPUs help — 1.46x more text throughput and 9x lower vision first-request latency (4.14s → 0.46s) — without hurting quality on any benchmark.

Project 12: Quantization Comparison (Single GPU) Feb 28

Compared BF16, FP8, and AWQ-4bit of Qwen3.5-35B-A3B on a single card to see how much speed you gain and how much quality you lose.

Quality

| Quant | Perplexity | IFEval | MMMU |
|---|---|---|---|
| BF16 | 9.56 | 7.6%* | 25.0%* |
| FP8 | 9.58 | 7.6%* | 24.8%* |
| AWQ-4bit | — | 8.9%* | 26.1%* |

*Scores collapsed due to thinking mode being enabled. FP8 perplexity is essentially identical to BF16 (9.58 vs 9.56).

Speed

Throughput (tok/s) — higher is better
AWQ-4bit
1,955 (+77%)
FP8
1,581 (+43%)
BF16
1,103 (base)
| Quant | TPOT | TTFT | Throughput | vs BF16 |
|---|---|---|---|---|
| BF16 | 26.3ms | 287.4ms | 1,103 tok/s | baseline |
| FP8 | 17.8ms | 269.9ms | 1,581 tok/s | +43% |
| AWQ-4bit | 13.9ms | 278.6ms | 1,955 tok/s | +77% |
Verdict

FP8 is the sweet spot. Same quality as BF16, 43% more throughput on a single GPU (1,581 vs 1,103 tok/s), and at --gpu-memory-utilization 0.62 it only uses ~62 GB on a 96 GB card — leaving 34 GB for KV cache. AWQ-4bit is even quicker but the perplexity artifact makes it harder to trust.

Project 13: Final Verification — 6 Model Variants definitive Mar 1

Re-ran all models with consistent methodology: vllm bench serve with 512 input / 256 output tokens, same gpu_memory_utilization=0.90 and max_model_len=131072 across the board. Added Qwen3.5-122B-A10B (the bigger MoE) to the lineup.
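For reference, the measurement command looked roughly like this (a sketch; the prompt count is an assumption, and flag names should be verified against your vLLM version):

```shell
# 512-token random prompts, 256-token outputs, fixed arrival rate,
# run against an already-serving vLLM endpoint.
vllm bench serve \
    --model Qwen/Qwen3.5-35B-A3B-FP8 \
    --dataset-name random \
    --random-input-len 512 \
    --random-output-len 256 \
    --request-rate 16 \
    --num-prompts 200
```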

Speed (512in/256out, 2x RTX PRO 6000)

| Model | Params | Quant | TTFT | rate=1 | rate=4 | rate=16 |
|---|---|---|---|---|---|---|
| Qwen3.5-35B-A3B | 35B (3B active) | FP8 | 61ms | 245 | 918 | 2,623 |
| Qwen3.5-35B-A3B | 35B (3B active) | AWQ | 66ms | 247 | 941 | 2,817 |
| Qwen3.5-122B-A10B | 122B (10B active) | FP8 | 105ms | 233 | 744 | 1,569 |
| Qwen3.5-122B-A10B | 122B (10B active) | NVFP4 | 106ms | 236 | 789 | 1,694 |
| MiniMax M2.5 | 456B | NVFP4 | 147ms | 227 | 726 | 1,688 |
| gpt-oss-120b | 120B | MXFP4 | 70ms | 237 | 876 | 2,656 |

Quality

| Model | IFEval | MATH Hard | Perplexity |
|---|---|---|---|
| Qwen3.5-35B FP8 | 86.1% | —* | 9.59 |
| Qwen3.5-35B AWQ | 85.4% | —* | — |
| Qwen3.5-122B FP8 | 86.9% | 56.8% | 7.07 |
| Qwen3.5-122B NVFP4 | 85.6% | 57.3% | 7.48 |
| MiniMax M2.5 | 37.2% | 44.9% | 12.45 |
| gpt-oss-120b | 76.9% | 55.5% | 30.25 |

*MATH Hard failed for Qwen3.5-35B variants due to missing math-verify dependency (installed mid-run).

KV Cache & Concurrency

At gpu_memory_utilization=0.90 and max_model_len=131072, how many concurrent 128K-token conversations fit in the remaining VRAM after loading weights:

| Model | Quant | KV Memory | KV Tokens | 128K Sessions |
|---|---|---|---|---|
| Qwen3.5-35B-A3B | AWQ | 69.96 GiB | 1,833,216 | 54 |
| Qwen3.5-35B-A3B | FP8 | 63.96 GiB | 1,675,872 | 49 |
| Qwen3.5-122B-A10B | NVFP4 | 45.35 GiB | 989,312 | 28 |
| gpt-oss-120b | MXFP4 | 47.84 GiB | 1,393,504 | 20 |
| Qwen3.5-122B-A10B | FP8 | 22.66 GiB | 494,656 | 14 |
| MiniMax M2.5 | NVFP4 | 19.26 GiB | 325,728 | 2 |

Key Takeaways

Findings

Qwen3.5-122B-A10B FP8 is the quality winner — best IFEval (86.9%), best perplexity (7.07), and solid MATH Hard (56.8%). Still 233 tok/s single-user.

gpt-oss-120b is surprisingly fast — 2,656 tok/s at rate=16 and only 70ms TTFT, matching the 35B models despite being 4x larger.

NVFP4 doubles KV cache for 122B — 28 vs 14 concurrent 128K conversations, because model weights are ~half the size, freeing 23 GiB for KV cache.

MiniMax M2.5 disappoints on quality — 37.2% IFEval is very low, likely a prompt template issue with the think-tag proxy. Only 2 concurrent 128K sessions.

Project 14: SM120 Backend Deep Dive & NVFP4 Optimization breakthrough Mar 2

Two-part investigation: (1) why SM120 is stuck on the Triton fallback for FP8 MoE, and (2) whether NVFP4 quantization unlocks better performance via native SM120 kernels.

SM120 vs Datacenter Blackwell

The RTX PRO 6000 (SM120) shares the Blackwell brand with B200/B100 (SM100) but has a fundamentally different instruction set:

| Feature | SM90 (Hopper) | SM100 (B200) | SM120 (RTX PRO 6000) |
|---|---|---|---|
| wgmma (Warp Group MMA) | Yes | Yes | No |
| tcgen05/UMMA | No | Yes | No |
| FP8 tensor cores | Yes | Yes | Yes |
| FP4 tensor cores | No | Yes | Yes |
| Max shared memory | 228 KB | 232 KB | 99 KB |
| Memory bandwidth | 3.35 TB/s | 8 TB/s | 1.79 TB/s |

SM120 has FP8 and FP4 tensor cores but can only access them via mma.sync (SM89-era instruction), not the high-throughput wgmma or tcgen05 used by optimized backends.

Backend Feasibility for FP8 MoE on SM120

| Backend | Status | Blocker |
|---|---|---|
| FlashInfer TRT-LLM | Impossible | Pre-compiled SM100 cubins only, no source |
| DeepGEMM | Impossible | wgmma + 232KB shared mem required |
| FlashInfer CUTLASS FP8 | Very hard | SM90-gated block-scale kernels |
| vLLM CUTLASS | Impossible | Not compiled for SM120 |
| Triton | Works | Fallback, SM ≥ 8.9 |
| Marlin | Works | Weight-only dequant to BF16 |

The NVFP4 Path Forward

SM120 has one capability SM90 does not: FP4 tensor cores. FlashInfer CUTLASS already ships SM120-specific FP4 MoE kernels that use mma.sync with FP4 data types and fit within 99 KB shared memory.

NVFP4 Backend Comparison (Qwen3.5-122B-A10B)

| Config | MoE Backend | 1-stream tok/s | TPOT | rate=16 | GiB/GPU |
|---|---|---|---|---|---|
| FP8 Triton (baseline) | TRITON | 91.4 | 10.36ms | 1,569 | ~64 |
| NVFP4 CUTLASS | FLASHINFER_CUTLASS | 89.3 | 10.59ms | 1,737 | 35.8 |
| NVFP4 Marlin | MARLIN | 102.1 | 9.19ms | 1,759 | 35.8 |
| NVFP4 Marlin+EP | MARLIN | 98.5 | 9.57ms | 1,745 | 35.8 |
Findings

NVFP4 + Marlin is the optimal 122B configuration on SM120. 102 tok/s single-stream (12% faster than FP8), 9.19ms per-token latency, 1,759 tok/s batched (12% more throughput), and 44% less memory per GPU (35.8 vs 64 GiB).

Marlin beats native CUTLASS at single-stream (102 vs 89 tok/s). At batch size 1 the workload is memory-bound (streaming expert weights from HBM), and Marlin’s BF16 dequant-and-compute path is faster than CUTLASS’s native FP4. Under concurrency both converge (within 1.3% at rate=16).

Quality verified — 12-test evaluation (math, logic, code, translation, creative writing) showed no degradation from NVFP4 quantization.

Project 15: Agent Harness Testing practical Mar 2

Tested two open-source AI coding agents — opencode and hermes-agent — against local models on a practical task: clone a real project, understand it, install dependencies, and investigate vLLM integration.

Round 1: Small Models

| Agent | Model | Duration | Tool Calls | Quality |
|---|---|---|---|---|
| opencode | MiniMax M2.5 (hosted) | ~2 min | 8 | Excellent |
| opencode | Trinity Large (hosted) | ~2 min | 12+ | Excellent |
| hermes | Qwen3-30B-A3B (local) | ~16s | 2 | Poor |

Round 2: Qwen3.5-122B-A10B (2x GPU)

| Agent | Quant | Duration | Tool Calls | Quality |
|---|---|---|---|---|
| opencode | FP8 | 35s | 4+ | Good* |
| opencode | NVFP4 | 21s | 6+ | Good* |
| hermes | FP8 | 88s | 48 | Excellent |
| hermes | NVFP4 | 79s | 39 | Excellent |

*opencode hit context window limits (32K output default consumes most of 65K context).

FP8 vs NVFP4 for Agentic Use

| Metric | FP8 | NVFP4+Marlin | Delta |
|---|---|---|---|
| GPU memory (model) | 58.2 GiB | 35.75 GiB | -39% |
| vLLM ready time | ~175s | ~135s | -23% |
| opencode duration | 35s | 21s | -40% |
| hermes duration | 88s | 79s | -10% |
| Output quality | Excellent | Excellent | = |
Findings

Model size is the dominant factor for agentic quality. Qwen3-30B-A3B (3B active) made 2 tool calls. Qwen3.5-122B (10B active) made 39–48 tool calls with correct conclusions. The 3.3x increase in active parameters produced a qualitative leap.

hermes-agent is excellent with a capable model. Round 1’s poor results were entirely model-limited. With 122B, hermes produced the most thorough reports of any combination — reading docs, searching code, installing dependencies, and producing structured reports.

NVFP4+Marlin shows no quality difference from FP8 in agentic use, while being 39% lighter on memory and 10–40% faster end-to-end.

Cross-Cutting Lessons

Thinking tags are a nightmare

Every model does it differently. MiniMax injects <think> as a template prefix, Qwen3.5 does the same, gpt-oss-120b uses a completely separate field. There's no universal way to strip reasoning content. For benchmarks, just disable thinking mode entirely.

SM120 is not SM100 — it lacks wgmma

The RTX PRO 6000 shares the Blackwell brand with B200 but has a fundamentally different ISA. It lacks wgmma and tcgen05 instructions, blocking DeepGEMM, TRT-LLM MoE, and CUTLASS FP8 block-scale kernels. It also has only 99 KB shared memory vs 228+ KB on datacenter GPUs. The fallback is Triton for FP8, or Marlin for NVFP4/MXFP4 (which is actually faster for memory-bound single-stream workloads).

Quantization artifacts are in the weights, not the backend

The AWQ logprob anomaly (token 104256 dominating everything) showed up identically across DeepGEMM, eager mode, and FlashInfer. It's baked into the checkpoint. You have to validate quantized models independently.

MoE is the way to go

Qwen3.5-35B-A3B with only 3B active parameters gives you 2x the batched throughput and nearly 3x the concurrent sessions vs a dense 27B model. For self-hosted inference where you're paying for VRAM, MoE is a no-brainer.

NVFP4 + Marlin is the sweet spot on SM120

For Qwen3.5-122B on workstation Blackwell: NVFP4 with --moe-backend marlin gives 102 tok/s single-stream (12% faster than FP8 Triton), uses 44% less GPU memory, and shows no quality degradation in either benchmarks or agentic tool use. Marlin's BF16 dequant-and-compute beats CUTLASS native FP4 when memory-bound at batch size 1.

Reasoning models and benchmarks don't mix well

Default configs produce silently wrong results. Token budgets need to be 4096+, greedy decoding might be incompatible, and reasoning content leaks into answers unless you explicitly disable it. I'm not sure if this is a benchmark tooling problem or something more fundamental.

vLLM dependency hell is real

lm-eval, evalplus, and lmms-eval all pin different vLLM versions and will silently downgrade your install. Always install vLLM last with --upgrade.
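In practice that means an install order like this (a sketch; pin exact versions if you need reproducibility):

```shell
# Eval harnesses first; each may drag in its own pinned vLLM.
pip install lm-eval evalplus lmms-eval
# vLLM last, so its version wins.
pip install --upgrade vllm
# Confirm nothing downgraded it behind your back.
python -c "import vllm; print(vllm.__version__)"
```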

Production Recommendation

Recommended Production Configuration
vllm serve Qwen/Qwen3.5-35B-A3B-FP8 \
    --tensor-parallel-size 2 \
    --dtype auto \
    --max-model-len 131072 \
    --gpu-memory-utilization 0.90 \
    --seed 42 \
    --default-chat-template-kwargs '{"enable_thinking": false}' \
    --disable-custom-all-reduce \
    --port 8000
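Once the server is up, a quick smoke test against the OpenAI-compatible endpoint (response text will vary):

```shell
# One chat completion; confirms the template, parser, and quantization
# are producing sane output before pointing real clients at it.
curl -s http://localhost:8000/v1/chat/completions \
    -H 'Content-Type: application/json' \
    -d '{
          "model": "Qwen/Qwen3.5-35B-A3B-FP8",
          "messages": [{"role": "user", "content": "Say hello in one word."}],
          "max_tokens": 16
        }'
```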

Why this model

I'm going with Qwen3.5-35B-A3B in FP8. It's the best combination of speed and capability I found across all these experiments:

  • 2,623 output tok/s batched on 2x GPUs — the 1000+ tok/s range makes it a completely different game
  • 49 simultaneous 128K-token conversations in 192 GB VRAM
  • 86% IFEval with thinking disabled (sanity check: inference is correct)
  • Vision works natively through the same weights, no adapter needed
  • FP8 is quality-identical to BF16 (perplexity 9.59 vs 9.56) — 43% more throughput for free
  • Fits on a single 96 GB card with room to spare
  • Works really well in practice with opencode and other coding tools

For harder tasks: Qwen3.5-122B-A10B NVFP4

If you need more reasoning depth, Qwen3.5-122B-A10B in NVFP4 with --moe-backend marlin is a strong alternative — 102 tok/s single-stream (9.2ms/tok), best perplexity (7.48), 57.3% MATH Hard, and only 35.8 GiB/GPU leaving room for 28 concurrent 128K sessions. It also performs excellently as an agentic coding backbone with opencode and hermes-agent.

vllm serve Sehyo/Qwen3.5-122B-A10B-NVFP4 \
    --tensor-parallel-size 2 \
    --dtype auto \
    --max-model-len 131072 \
    --gpu-memory-utilization 0.90 \
    --moe-backend marlin \
    --default-chat-template-kwargs '{"enable_thinking": false}' \
    --enable-auto-tool-choice \
    --tool-call-parser qwen3_xml \
    --disable-custom-all-reduce \
    --port 8000