Is it possible to make a smaller NVFP4 quant at 340-360GB to fit in 4x96GB?
Hi, is it possible to make a smaller NVFP4 quant at 340-360GB so it fits in 4x96GB? I've never done a quant before but I'm willing to try. I'm wondering if we can quantize more layers to get the size down a tad more?
You could try quantizing the indexer, but my intuition says you probably don't want to. I think this is about as small as you can get with NVFP4 without really hurting model performance. If you give up on GPU acceleration, you could go smaller with llama.cpp-style quantization. A sketch of how you might experiment with the quantized-layer list is below.
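For anyone wanting to experiment, here is a rough sketch of a weight-only FP4 quant with a tunable ignore list, using llm-compressor. This is illustrative only: the NVFP4A16 scheme name, the data-free oneshot path, and the model paths are my assumptions, not how the existing quant was necessarily produced.

```python
# Hypothetical sketch: quantize more Linear layers to shave a few GB.
# Shrinking the `ignore` list quantizes more modules (smaller file, worse
# quality); NVFP4A16 here means weight-only FP4 (w4a16).
from transformers import AutoModelForCausalLM
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier

model = AutoModelForCausalLM.from_pretrained("path/to/base-model", torch_dtype="auto")

recipe = QuantizationModifier(
    targets="Linear",
    scheme="NVFP4A16",
    ignore=["lm_head"],  # add patterns here (e.g. the indexer) to keep modules in high precision
)

oneshot(model=model, recipe=recipe)
model.save_pretrained("path/to/output-nvfp4", save_compressed=True)
```

Per the comment above, keeping the indexer in the ignore list is probably the safer default.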
It should work in vLLM on sm100; unfortunately, due to how NVIDIA decided to segment their consumer vs. datacenter Blackwell cards, much of the code in Triton/DeepGEMM/etc. doesn't properly support sm120. The vLLM hackery was mostly straightforward, but DeepGEMM (https://github.com/deepseek-ai/DeepGEMM) required extensive work to even get something running, and it is still a ways off from something I would try to get merged. This is why I only provided the CPU reference impl for validation and experimentation with this model. Hopefully, with time, sm120 (RTX PRO 6000 Blackwell) will get better support from projects like DeepGEMM/Triton/vLLM/SGLang/etc.
I uploaded https://hub.docker.com/repository/docker/eous/vllm-sm120/general, which has my sm120 hacks; it is very MVP/research-grade and will probably not work. That said, I just tested the model, and with a smaller context you should be able to fit it on 4x 96GB GPUs.
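If you want to try it, something like this offline-API invocation is what I'd expect to work; max_model_len and gpu_memory_utilization are guesses to tune, not tested values:

```python
# Sketch: fit the NVFP4 quant on 4x 96GB GPUs by shrinking the context.
from vllm import LLM, SamplingParams

llm = LLM(
    model="path/to/nvfp4-quant",
    tensor_parallel_size=4,       # one shard per 96GB GPU
    max_model_len=8192,           # smaller context so weights + KV cache fit
    gpu_memory_utilization=0.95,
)

out = llm.generate(["Hello"], SamplingParams(max_tokens=32))
print(out[0].outputs[0].text)
```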
@eousphoros almost!
```
NVCC compilation failed:
/root/.cache/vllm/deep_gemm/cache/kernel.smxx_fp8_mqa_logits.6170cd6e0de7e861f56139277bd6b709/kernel.cu:2:10: fatal error: deep_gemm/impls/sm120_fp8_mqa_logits.cuh: No such file or directory
    2 | #include <deep_gemm/impls/sm120_fp8_mqa_logits.cuh>
      |          ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
compilation terminated.
```
Is this maybe because I used the AWQ variant instead of NVFP4? It's a bit smaller, that's why.
edit: oooh, I need to install DeepGEMM, I see. OK, I also needed to edit the install script to use python3 instead of plain python and to add --force-reinstall.
ok: Successfully installed deep-gemm-2.2.0+local
edit2: still getting the NVCC failure:

```
NVCC compilation failed: /root/.cache/vllm/deep_gemm/cache/kernel.smxx_fp8_mqa_logits.6170cd6e0de7e861f56139277bd6b709/kernel.cu:2:10: fatal error: deep_gemm/impls/sm120_fp8_mqa_logits.cuh: No such file or directory
```
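For anyone else hitting this, a quick sanity check that the sm120 headers actually made it into the installed package (assuming DeepGEMM's JIT resolves includes from the package tree, which is what the cache path above suggests):

```python
# Look for the sm120 .cuh kernels inside the installed deep_gemm package.
import pathlib
import deep_gemm

root = pathlib.Path(deep_gemm.__file__).parent
hits = sorted(str(p) for p in root.rglob("sm120*.cuh"))
print("\n".join(hits) if hits else "no sm120 headers found; reinstall or copy them in")
```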
Ah whoops, forgot to copy the decode kernel into the container. I pushed a new container up. Also, no idea if this will work with AWQ; it barely works with my NVFP4 quant.
(APIServer pid=1) INFO 12-05 16:23:56 [loggers.py:248] Engine 000: Avg prompt throughput: 0.7 tokens/s, Avg generation throughput: 0.1 tokens/s, Running: 1 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%
Don't expect this to be fast, but it is faster than CPU inference.
Run 2: ✅ OK (0/16777216) | backend=flashinfer-cutlass, uva=no, pre_uva_nan=0
Sample values: [-0.00078582763671875, 0.0001964569091796875, 0.000957489013671875, 9.441375732421875e-05, -0.0003261566162109375]
Stats: min=-0.3125, max=0.3125, mean=0.0000
Scales: input_scale={'value': 1.0, 'device': 'cuda:0', 'ptr': 123644922561536}
Checkpoint: input_scale=False, weight_scale=True, weight_scale_2=True
Think I am getting close to having vllm sorted out.
Clean FP4 GEMM run; nice to see the cutlass weight-decode path is now NaN-free.
The sparse-attention kernel (FlashMLA prefill) is a separate issue: its pattern of 50% zeros, -inf values, and 16-element alignment still matches a tile-granularity underflow guard. Whatever the exact symbol (sum_exp, den == 0, all_masked, ...), the fix is the same: move from whole-tile zero-fill to per-lane scaling and the zeros should drop to 0%. 🤞
Disassembling the closed-source .so shows a REDUX (warp-sum) immediately followed by STL.128 [R1+offset], RZ: the kernel deliberately stores 128-bit zeros for an entire 16-element tile whenever the denominator underflows. That produces the exact 50% zeros and -inf in max_logits we measured for every d_v >= 32.
Fix
Replace the whole-tile memset with per-lane scaling:
```c
out[i] = acc_v[i] * (sum == 0.0f ? 0.0f : 1.0f / sum);
```
Only the masked lanes become zero; valid lanes keep their correct value, eliminating the 50 % pattern without breaking numerical safety.
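For clarity, here are the intended semantics as a PyTorch reference. This is only a numerical model of the fix; the actual change has to land in the kernel epilogue, and the acc_v/sum_exp names are my shorthand:

```python
# Per-lane normalization: zero only the rows whose softmax denominator
# underflowed, instead of memset-ing the whole 16-element tile.
import torch

def normalize_per_lane(acc_v: torch.Tensor, sum_exp: torch.Tensor) -> torch.Tensor:
    # acc_v: [rows, d_v] un-normalized weighted-V accumulator
    # sum_exp: [rows] softmax denominator per row
    inv = torch.where(sum_exp == 0, torch.zeros_like(sum_exp), 1.0 / sum_exp)
    return acc_v * inv.unsqueeze(-1)
```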
So one of the facepalm issues I've discovered: ModelOpt support in vLLM is w4a4 only, and this model is w4a16, so the input scales are basically random noise. The saga continues.
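A quick way to see which flavor a checkpoint actually is: count the scale tensors it ships. A w4a16 export carries weight scales but no meaningful activation (input) scales. The key suffixes below follow the input_scale/weight_scale/weight_scale_2 naming from my validation log earlier and may differ between exporters:

```python
# Count quantization-scale tensors in a safetensors shard to tell
# w4a16 (weight scales only) from w4a4 (input scales too).
from safetensors import safe_open

def scale_kinds(path: str) -> dict:
    kinds = {"input_scale": 0, "weight_scale": 0, "weight_scale_2": 0}
    with safe_open(path, framework="pt") as f:
        for key in f.keys():
            for suffix in kinds:
                if key.endswith(suffix):
                    kinds[suffix] += 1
    return kinds

print(scale_kinds("path/to/model-shard.safetensors"))
```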
@eousphoros can you share fwd.cu so I can check for the bug? I have tested in your Docker image on both sm120 and sm90, and look:
sm90:
Details:
Minimum Configuration: ✅ -inf=False, zeros=0.0%
Small Batch: ✅ -inf=False, zeros=0.0%
Medium Sequence: ✅ -inf=False, zeros=0.0%
sm120:
Details:
Minimum Configuration: ✅ -inf=False, zeros=50.0%
Small Batch: ✅ -inf=False, zeros=50.0%
Medium Sequence: ✅ -inf=False, zeros=50.0%
We have 50% zeros on sm120; a minimal repro sketch is below, followed by the full analysis output.
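For reference, the harness boils down to the snippet below. Big caveat: I'm assuming the sparse prefill entry point is flash_mla_sparse_fwd(q, kv, indices, sm_scale, d_v) returning (out, max_logits, lse); adjust to whatever the port in the image actually exposes. Shapes mirror the MINIMAL test:

```python
# Minimal repro for the 50%-zeros check on the FlashMLA sparse prefill port.
import torch
from flash_mla import flash_mla_sparse_fwd  # assumed entry point

s_q, h_q, s_kv, topk, d_qk, d_v = 1, 64, 256, 128, 576, 512
q = torch.randn(s_q, h_q, d_qk, dtype=torch.bfloat16, device="cuda")
kv = torch.randn(s_kv, 1, d_qk, dtype=torch.bfloat16, device="cuda")
indices = torch.randint(0, s_kv, (s_q, 1, topk), dtype=torch.int32, device="cuda")

out, max_logits, lse = flash_mla_sparse_fwd(q, kv, indices, d_qk ** -0.5, d_v)
zeros = (out == 0).sum().item()
print(f"-inf in max_logits: {torch.isinf(max_logits).any().item()}, "
      f"zeros: {zeros}/{out.numel()} ({100 * zeros / out.numel():.1f}%)")
```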
=== DEEP ZERO ANALYSIS FOR FLASHMLA PORT ===
GPU: NVIDIA RTX PRO 6000 Blackwell Workstation Edition
Architecture: SM120
✅ FlashMLA loaded
Running comprehensive tests...
Kernel constants: B_H=64, B_TOPK=64
================================================================================
TEST: MINIMAL
Params: s_q=1, h_q=64, s_kv=256, topk=128
Tensor shapes:
q: torch.Size([1, 64, 576]), norm=19.108
kv: torch.Size([256, 1, 576]), norm=38.478
indices: torch.Size([1, 1, 128])
Running kernel...
🔍 DEEP ZERO ANALYSIS:
❌❌❌ CRITICAL: Found -inf in max_logits (3 positions)
This causes zeros in output!
❌❌❌ OUTPUT ZEROS: 16384/32768 (50.0%)
📊 ZERO PATTERN ANALYSIS:
Zeros per sequence (s_q): [16384]
Zeros per head (h_q): [256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256
256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256
256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256 256
256 256 256 256 256 256 256 256 256 256]
Zeros in first 256 dims: 8192
Zeros in last 256 dims: 8192
Mixed zero pattern
📈 MAX_LOGITS ANALYSIS:
Range: [-inf, 7060358122880455717866798789141987328.000000]
Mean: -inf, Std: nan
⚠️ 3 values < -10.0
📊 LSE ANALYSIS:
Range: [-inf, 7060358122880455717866798789141987328.000000]
📊 OUTPUT STATISTICS:
Output range: [nan, nan]
Output mean: nan, std: nan
❌❌❌ FOUND NaN IN OUTPUT!
================================================================================
🚨🚨🚨 50% ZEROS BUG DETECTED! 🚨🚨🚨
POSSIBLE CAUSES for 50% zeros:
- ❌ Warp/thread divergence in SM120 kernel
- ❌ Incorrect shared memory layout
- ❌ Wrong tensor core configuration
- ❌ Half the threads/warps not executing
- ❌ Memory access pattern mismatch
DIAGNOSTIC CHECKS:
Zero positions in first output: [16 17 18 19 20 21 22 23 24 25]...
max_logits for first sequence: [9.80560730e+33 6.35428123e+32 3.37208696e-02 3.50317731e-02
4.45985422e-02 2.94296816e-02 1.21655772e+33 7.06035812e+36
2.04398166e-02 2.14101858e-02 2.52381824e-02 1.96358506e-02
1.30142569e+34 2.62514475e+35 3.70108150e-02 2.75616795e-02
3.68889682e-02 3.81367281e-02 1.11068119e+35 4.26542056e+27
1.90967824e-02 4.40312810e-02 2.16482412e-02 3.41578498e-02
2.54433506e+26 7.73387411e+32 2.93120816e-02 3.45220268e-02
2.87235789e-02 3.64855230e-02 1.78032069e+34 1.68245522e+34
2.19204146e-02 3.65953557e-02 1.80366822e-02 2.42102426e-02
-inf -inf 2.26458944e-02 4.03770730e-02
2.28515510e-02 2.65495256e-02 7.94248354e+31 4.80606429e+32
1.50184585e-02 3.43337618e-02 3.70853320e-02 2.57805903e-02
9.47893461e+34 -inf 2.28228103e-02 6.06317595e-02
4.62973192e-02 3.49170715e-02 2.96597891e-02 3.28789949e-02
3.38554978e-02 3.88729051e-02 3.95172797e-02 2.95421686e-02
2.66985707e-02 1.33980261e-02 2.05924455e-02 3.29767801e-02]
================================================================================
TEST: SMALL
Params: s_q=4, h_q=128, s_kv=1024, topk=256
Tensor shapes:
q: torch.Size([4, 128, 576]), norm=54.296
kv: torch.Size([1024, 1, 576]), norm=76.692
indices: torch.Size([4, 1, 256])
Running kernel...
🔍 DEEP ZERO ANALYSIS:
❌❌❌ CRITICAL: Found -inf in max_logits (20 positions)
This causes zeros in output!
❌❌❌ OUTPUT ZEROS: 180228/262144 (68.8%)
📊 ZERO PATTERN ANALYSIS:
Zeros per sequence (s_q): [32772 49152 49152 49152]
Zeros per head (h_q): [1696 1792 1696 1792 1792 1792 1024 1024 1024 1024 1024 1024 1792 1792
1792 1024 1600 1792 1024 1024 1024 1024 1024 1024 1792 1792 1792 1792
1696 1792 1024 1024 1024 1024 1024 1024 1792 1793 1792 1792 1792 1792
1024 1024 1024 1024 1024 1024 1792 1792 1792 1024 1792 1792 1024 1024
1024 1024 1600 1792 1024 1696 1024 1792 1024 1024 1024 1024 1792 1792
1793 1792 1696 1792 1024 1024 1024 1024 1792 1793 1792 1792 1504 1792
1024 1024 1024 1024 1792 1792 1792 1792 1696 1024 1024 1024 1024 1024
1696 1792 1792 1792 1600 1792 1024 1024 1024 1024 1792 1600 1792 1792
1792 1600 1024 1024 1024 1024 1600 1792 1792 1792 1793 1600 1024 1024
1024 1024]
Zeros in first 256 dims: 90112
Zeros in last 256 dims: 90116
Mixed zero pattern
📈 MAX_LOGITS ANALYSIS:
Range: [-inf, 9678993406145318093485196971912200192.000000]
Mean: -inf, Std: nan
⚠️ 20 values < -10.0
📊 LSE ANALYSIS:
Range: [-inf, 9678993406145318093485196971912200192.000000]
📊 OUTPUT STATISTICS:
Output range: [nan, nan]
Output mean: nan, std: nan
❌❌❌ FOUND NaN IN OUTPUT!
================================================================================
TEST: MEDIUM
Params: s_q=8, h_q=64, s_kv=2048, topk=128
Tensor shapes:
q: torch.Size([8, 64, 576]), norm=54.296
kv: torch.Size([2048, 1, 576]), norm=108.569
indices: torch.Size([8, 1, 128])
Running kernel...
🔍 DEEP ZERO ANALYSIS:
❌❌❌ CRITICAL: Found -inf in max_logits (56 positions)
This causes zeros in output!
❌❌❌ OUTPUT ZEROS: 219553/262144 (83.8%)
📊 ZERO PATTERN ANALYSIS:
Zeros per sequence (s_q): [16385 29024 29024 29024 29024 29024 29024 29024]
Zeros per head (h_q): [2944 2048 3168 3840 3840 3168 3840 2944 3392 3392 3840 3840 3840 3840
3616 3840 3840 3616 3392 3392 3840 2048 3840 3392 3840 3616 3840 3840
3392 3840 2944 2048 3840 2944 3840 2048 3392 3392 3840 2048 3840 3840
3840 2944 3616 3392 3840 3392 3840 3840 3840 2048 3840 2944 3616 3840
3840 3840 3840 2048 2945 3392 3840 3840]
Zeros in first 256 dims: 109776
Zeros in last 256 dims: 109777
Mixed zero pattern
📈 MAX_LOGITS ANALYSIS:
Range: [-inf, 8896381419781215651424374320602808320.000000]
Mean: -inf, Std: nan
⚠️ 56 values < -10.0
📊 LSE ANALYSIS:
Range: [-inf, 8896381419781215651424374320602808320.000000]
📊 OUTPUT STATISTICS:
Output range: [nan, nan]
Output mean: nan, std: nan
❌❌❌ FOUND NaN IN OUTPUT!
================================================================================
TEST: LARGE_HEADS
Params: s_q=2, h_q=256, s_kv=512, topk=128
Tensor shapes:
q: torch.Size([2, 256, 576]), norm=54.296
kv: torch.Size([512, 1, 576]), norm=54.277
indices: torch.Size([2, 1, 128])
Running kernel...
🔍 DEEP ZERO ANALYSIS:
❌❌❌ CRITICAL: Found -inf in max_logits (54 positions)
This causes zeros in output!
❌❌❌ OUTPUT ZEROS: 166538/262144 (63.5%)
📊 ZERO PATTERN ANALYSIS:
Zeros per sequence (s_q): [ 65546 100992]
Zeros per head (h_q): [768 608 768 512 768 705 736 672 704 768 512 672 512 512 704 768 768 768
768 768 768 736 768 704 512 768 512 512 512 768 704 768 768 768 512 768
768 512 673 736 512 512 512 512 672 672 608 768 768 672 512 768 704 768
768 672 512 512 512 512 704 640 769 512 704 768 768 768 768 737 512 512
512 512 512 512 736 768 768 768 768 512 736 768 768 512 512 768 512 512
512 512 768 576 736 512 768 768 768 641 768 768 768 768 512 512 512 512
768 768 768 736 768 768 768 768 768 768 704 512 512 512 512 512 736 768
768 512 768 768 768 768 768 768 512 512 512 512 512 512 736 768 672 768
768 512 672 704 704 768 737 768 512 512 512 512 768 704 768 512 768 736
512 705 768 768 768 736 512 512 512 512 768 704 768 768 768 768 768 704
768 768 768 768 512 512 512 512 512 512 672 512 672 512 512 512 736 512
512 512 736 768 768 641 704 768 512 512 512 512 705 768 512 512 768 512
768 512 512 736 768 512 512 512 512 768 768 768 512 512 704 512 768 768
512 705 768 768 768 768 512 512 512 512 512 768 768 768 512 512 512 512
512 512 512 736]
Zeros in first 256 dims: 83272
Zeros in last 256 dims: 83266
Mixed zero pattern
📈 MAX_LOGITS ANALYSIS:
Range: [-inf, 13776871091051405123495222820638556160.000000]
Mean: -inf, Std: nan
⚠️ 54 values < -10.0
📊 LSE ANALYSIS:
Range: [-inf, 13776871091051405123495222820638556160.000000]
📊 OUTPUT STATISTICS:
Output range: [nan, nan]
Output mean: nan, std: nan
❌❌❌ FOUND NaN IN OUTPUT!
================================================================================
TEST: LARGE_SEQ
Params: s_q=16, h_q=64, s_kv=4096, topk=128
Tensor shapes:
q: torch.Size([16, 64, 576]), norm=76.801
kv: torch.Size([4096, 1, 576]), norm=153.542
indices: torch.Size([16, 1, 128])
Running kernel...
🔍 DEEP ZERO ANALYSIS:
❌❌❌ CRITICAL: Found -inf in max_logits (64 positions)
This causes zeros in output!
❌❌❌ OUTPUT ZEROS: 355778/524288 (67.9%)
📊 ZERO PATTERN ANALYSIS:
Zeros per sequence (s_q): [16386 29760 29760 29760 29760 29760 29760 29760 16384 16384 16384 16384
16384 16384 16384 16384]
Zeros per head (h_q): [5888 5888 5888 5888 5888 5664 5888 5888 5888 5888 4096 5888 5888 4992
5888 5664 5889 5888 4096 5888 5888 5664 4992 5888 5888 5888 5664 4992
5888 5888 5888 5888 5888 5888 5888 4096 5888 5888 4096 5888 4096 5216
5888 4096 5888 5441 4768 5440 5888 5888 5888 5888 5440 4096 5888 5888
5888 5888 4096 5888 5888 5888 5888 5888]
Zeros in first 256 dims: 177889
Zeros in last 256 dims: 177889
Mixed zero pattern
📈 MAX_LOGITS ANALYSIS:
Range: [-inf, 10161962580404572468423134157996032000.000000]
Mean: -inf, Std: nan
⚠️ 64 values < -10.0
📊 LSE ANALYSIS:
Range: [-inf, 10161962580404572468423134157996032000.000000]
📊 OUTPUT STATISTICS:
Output range: [nan, nan]
Output mean: nan, std: nan
❌❌❌ FOUND NaN IN OUTPUT!
================================================================================
SUMMARY FOR SM120
Test -inf Zeros Status
MINIMAL ❌ Yes ❌❌❌ 50.0% 🚨
SMALL ❌ Yes ❌ 68.8%
MEDIUM ❌ Yes ❌ 83.8%
LARGE_HEADS ❌ Yes ❌ 63.5%
LARGE_SEQ ❌ Yes ❌ 67.9%
✅ Detailed results saved to: /tmp/flashmla_SM120_deep_analysis.json
=== DEEP ZERO ANALYSIS FOR FLASHMLA PORT ===
GPU: NVIDIA H200
Architecture: SM90
✅ FlashMLA loaded
Running comprehensive tests...
Kernel constants: B_H=64, B_TOPK=64
================================================================================
TEST: MINIMAL
Params: s_q=1, h_q=64, s_kv=256, topk=128
Tensor shapes:
q: torch.Size([1, 64, 576]), norm=19.108
kv: torch.Size([256, 1, 576]), norm=38.478
indices: torch.Size([1, 1, 128])
Running kernel...
🔍 DEEP ZERO ANALYSIS:
✅ No -inf in max_logits
✅ OUTPUT ZEROS: 0/32768 (0.0%)
📈 MAX_LOGITS ANALYSIS:
Range: [0.059856, 0.141630]
Mean: 0.087430, Std: 0.015346
📊 LSE ANALYSIS:
Range: [6.993042, 7.006902]
📊 OUTPUT STATISTICS:
Output range: [-0.028931, 0.033691]
Output mean: 0.000287, std: 0.010222
================================================================================
TEST: SMALL
Params: s_q=4, h_q=128, s_kv=1024, topk=256
Tensor shapes:
q: torch.Size([4, 128, 576]), norm=54.347
kv: torch.Size([1024, 1, 576]), norm=76.623
indices: torch.Size([4, 1, 256])
Running kernel...
🔍 DEEP ZERO ANALYSIS:
✅ No -inf in max_logits
✅ OUTPUT ZEROS: 0/262144 (0.0%)
📈 MAX_LOGITS ANALYSIS:
Range: [0.065561, 0.150781]
Mean: 0.096242, Std: 0.013935
📊 LSE ANALYSIS:
Range: [7.992278, 8.007222]
📊 OUTPUT STATISTICS:
Output range: [-0.021118, 0.025024]
Output mean: -0.000004, std: 0.007091
================================================================================
TEST: MEDIUM
Params: s_q=8, h_q=64, s_kv=2048, topk=128
Tensor shapes:
q: torch.Size([8, 64, 576]), norm=54.347
kv: torch.Size([2048, 1, 576]), norm=108.510
indices: torch.Size([8, 1, 128])
Running kernel...
🔍 DEEP ZERO ANALYSIS:
✅ No -inf in max_logits
✅ OUTPUT ZEROS: 0/262144 (0.0%)
📈 MAX_LOGITS ANALYSIS:
Range: [0.052895, 0.142099]
Mean: 0.088744, Std: 0.015318
📊 LSE ANALYSIS:
Range: [6.991348, 7.012632]
📊 OUTPUT STATISTICS:
Output range: [-0.035889, 0.034180]
Output mean: 0.000061, std: 0.009220
================================================================================
TEST: LARGE_HEADS
Params: s_q=2, h_q=256, s_kv=512, topk=128
Tensor shapes:
q: torch.Size([2, 256, 576]), norm=54.347
kv: torch.Size([512, 1, 576]), norm=54.252
indices: torch.Size([2, 1, 128])
Running kernel...
🔍 DEEP ZERO ANALYSIS:
✅ No -inf in max_logits
✅ OUTPUT ZEROS: 0/262144 (0.0%)
📈 MAX_LOGITS ANALYSIS:
Range: [0.044454, 0.157338]
Mean: 0.088348, Std: 0.015180
📊 LSE ANALYSIS:
Range: [6.988080, 7.011192]
📊 OUTPUT STATISTICS:
Output range: [-0.030884, 0.029419]
Output mean: -0.000310, std: 0.009618
================================================================================
TEST: LARGE_SEQ
Params: s_q=16, h_q=64, s_kv=4096, topk=128
Tensor shapes:
q: torch.Size([16, 64, 576]), norm=76.818
kv: torch.Size([4096, 1, 576]), norm=153.521
indices: torch.Size([16, 1, 128])
Running kernel...
🔍 DEEP ZERO ANALYSIS:
✅ No -inf in max_logits
✅ OUTPUT ZEROS: 0/524288 (0.0%)
📈 MAX_LOGITS ANALYSIS:
Range: [0.057230, 0.155825]
Mean: 0.089577, Std: 0.014456
📊 LSE ANALYSIS:
Range: [6.991382, 7.008635]
📊 OUTPUT STATISTICS:
Output range: [-0.033936, 0.034668]
Output mean: 0.000122, std: 0.008953
================================================================================
SUMMARY FOR SM90
Test -inf Zeros Status
MINIMAL ✅ No ✅ 0%
SMALL ✅ No ✅ 0%
MEDIUM ✅ No ✅ 0%
LARGE_HEADS ✅ No ✅ 0%
LARGE_SEQ ✅ No ✅ 0%
✅ Detailed results saved to: /tmp/flashmla_SM90_deep_analysis.json
Let's go!!!!!

copying build/lib.linux-x86_64-cpython-312/flash_mla/cuda.cpython-312-x86_64-linux-gnu.so -> flash_mla