MXFP4
Reference Implementation
- qwantize.mxfp4.reference.mxfp4_naive(W, dim=-1, return_dequant=False)
Naive MXFP4 quantization: each block gets the scale s = 2^(floor(log2(amax)) - 2), where amax is the largest absolute value in the block. The offset of 2 is the largest E2M1 exponent (6 = 1.5 * 2^2), so amax / s falls in [4, 8) and elements whose scaled magnitude exceeds 6 round to ±6.
- Parameters:
W – Input tensor. W.shape[dim] must be 16 or 32 (the block size).
dim – Dimension along which to quantize (default: -1).
return_dequant – If True, also return the dequantized tensor.
- Returns:
(scales, quants) by default, or (scales, quants, dequant) if return_dequant is True.
scales: Per-block UE8M0 (power-of-2) scales. Shape is W.shape with dimension dim removed.
quants: Signed FP4 codebook values. Same shape as W.
dequant: quants * scales, with scales broadcast along dim. Same shape as W.
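A minimal pure-Python sketch of the naive scale rule for a single block, assuming plain lists instead of tensors. The helper names (naive_scale, nearest_fp4, mxfp4_naive_block) are hypothetical, ties round toward the smaller magnitude, and the fallback scale for an all-zero block is an assumption rather than documented behavior.

```python
import math

# FP4 E2M1 magnitudes, matching FP4_CODEBOOK.
FP4_CODEBOOK = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def naive_scale(block):
    """s = 2^(floor(log2(amax)) - 2) for one block."""
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        return 1.0  # assumption: fallback for an all-zero block
    return 2.0 ** (math.floor(math.log2(amax)) - 2)


def nearest_fp4(y):
    """Nearest signed codebook value; ties go to the smaller magnitude."""
    mag = min(FP4_CODEBOOK, key=lambda c: abs(abs(y) - c))
    return math.copysign(mag, y)


def mxfp4_naive_block(block):
    """Return (scale, quants, dequant) for one block."""
    s = naive_scale(block)
    quants = [nearest_fp4(v / s) for v in block]
    dequant = [q * s for q in quants]
    return s, quants, dequant


# amax = 7.0, floor(log2(7.0)) = 2, so s = 2^0 = 1.0; 7.0 saturates to 6.0.
s, quants, dequant = mxfp4_naive_block([0.3, -1.2, 4.6, 7.0])
print(s, quants)  # 1.0 [0.5, -1.0, 4.0, 6.0]
```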
- qwantize.mxfp4.reference.mxfp4_optimal(W, dim=-1, return_dequant=False)
Optimal MXFP4 quantization via a bounded search over UE8M0 scales for the scale that minimizes each block's sum of squared error (SSE).
Same algorithm as nvfp4_optimal(), adapted for UE8M0 power-of-2 scales. Because consecutive UE8M0 scales differ by a factor of 2, the optimal scale is always within one step of the naive scale.
- Parameters:
W – Input tensor. W.shape[dim] must be 16 or 32 (the block size).
dim – Dimension along which to quantize (default: -1).
return_dequant – If True, also return the dequantized tensor.
- Returns:
(scales, quants) by default, or (scales, quants, dequant) if return_dequant is True.
scales: Per-block optimal UE8M0 scales. Shape is W.shape with dimension dim removed.
quants: Signed FP4 codebook values. Same shape as W.
dequant: quants * scales, with scales broadcast along dim. Same shape as W.
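A simplified single-block sketch of the one-step neighborhood described above, assuming plain Python lists: it scores only s0/2, s0 and 2*s0 by SSE and keeps the best. The library's SSE-bound pruning is not reproduced, and all names are hypothetical.

```python
import math

FP4_CODEBOOK = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def nearest_fp4(y):
    """Nearest signed codebook value; ties go to the smaller magnitude."""
    mag = min(FP4_CODEBOOK, key=lambda c: abs(abs(y) - c))
    return math.copysign(mag, y)


def block_sse(block, s):
    """Sum of squared errors after quantizing the block at scale s."""
    return sum((v - s * nearest_fp4(v / s)) ** 2 for v in block)


def optimal_scale(block):
    """Lowest-SSE scale among the naive scale s0 and its UE8M0 neighbors."""
    amax = max(abs(v) for v in block)  # assumes a block that is not all zeros
    s0 = 2.0 ** (math.floor(math.log2(amax)) - 2)
    candidates = [s0 / 2, s0, 2 * s0]  # consecutive UE8M0 scales differ by 2x
    return min(candidates, key=lambda s: block_sse(block, s))


block = [7.5, 6.8, 0.6]
# s0 = 1.0 clips 7.5 and 6.8 to 6.0 (SSE about 2.90); 2.0 avoids clipping (SSE about 1.05).
print(optimal_scale(block))  # 2.0
```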
- qwantize.mxfp4.reference.mxfp4_optimal_hessian(W, dim=-1, return_dequant=False, X=None, H_blocks=None)
Hessian-aware optimal MXFP4 scale search.
Like mxfp4_optimal(), searches over UE8M0 scale candidates and uses SSE bounds for pruning, but selects the scale that minimizes the Hessian-weighted error (x - sq)^T H (x - sq) instead of the raw SSE, where x is the block, s the candidate scale, q the resulting FP4 codes, and H the block Hessian.
- Parameters:
W – Input tensor. W.shape[dim] must be 16 or 32 (the block size).
dim – Dimension along which to quantize (default: -1).
return_dequant – If True, also return the dequantized tensor.
X – Activation tensor of shape (T, K). The Hessian of column block j is computed as H_j = X_j^T @ X_j, where X_j holds the columns of X that belong to block j.
H_blocks – Pre-computed block Hessians of shape (num_col_blocks, bs, bs), where bs is the block size. If provided, X is ignored.
- Returns:
(scales, quants) by default, or (scales, quants, dequant) if return_dequant is True.
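A pure-Python sketch of the Hessian-weighted criterion for one block, assuming the same one-step candidate set as the SSE search and omitting the library's pruning. Names are hypothetical, and X_j stands for the activation columns belonging to the block. With an identity Hessian the criterion reduces to plain SSE; a Hessian that weights the small third element heavily flips the choice to the scale that represents that element best.

```python
import math

FP4_CODEBOOK = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def nearest_fp4(y):
    """Nearest signed codebook value; ties go to the smaller magnitude."""
    mag = min(FP4_CODEBOOK, key=lambda c: abs(abs(y) - c))
    return math.copysign(mag, y)


def block_hessian(X_j):
    """H = X_j^T @ X_j for an activation slice X_j of shape (T, bs)."""
    bs = len(X_j[0])
    return [[sum(row[a] * row[b] for row in X_j) for b in range(bs)] for a in range(bs)]


def hessian_error(block, s, H):
    """(x - s*q)^T H (x - s*q) for one block quantized at scale s."""
    e = [v - s * nearest_fp4(v / s) for v in block]
    n = len(e)
    return sum(e[a] * H[a][b] * e[b] for a in range(n) for b in range(n))


def hessian_optimal_scale(block, H):
    """Pick the candidate scale with the smallest Hessian-weighted error."""
    amax = max(abs(v) for v in block)  # assumes a block that is not all zeros
    s0 = 2.0 ** (math.floor(math.log2(amax)) - 2)
    candidates = [s0 / 2, s0, 2 * s0]  # assumption: same one-step neighborhood
    return min(candidates, key=lambda s: hessian_error(block, s, H))


block = [7.5, 6.8, 0.6]
identity = block_hessian([[1, 0, 0], [0, 1, 0], [0, 0, 1]])  # H = I, i.e. plain SSE
weighted = block_hessian([[1, 0, 0], [0, 0, 10]])  # H = diag(1, 0, 100)
print(hessian_optimal_scale(block, identity))  # 2.0
print(hessian_optimal_scale(block, weighted))  # 1.0
```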
- qwantize.mxfp4.reference.mxfp4_dequantize(scales, quants, dim=-1)
Dequantize MXFP4: dequant = quants * scales, with scales broadcast along dim.
- Parameters:
scales – Per-block scales. Shape is the original W.shape with dimension dim removed.
quants – Signed FP4 codebook values. Same shape as the original W.
dim – Block dimension in quants (default: -1).
- Returns:
Dequantized tensor with the same shape as quants.
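A pure-Python sketch of the broadcasting rule for a 2-D tensor, assuming nested lists: scales holds one entry per block, so it is indexed by row when blocks run along dim=-1 and by column when they run along dim=0. Real blocks have 16 or 32 elements; the 2 x 2 example is only for readability, and dequantize_2d is a hypothetical name.

```python
def dequantize_2d(scales, quants, dim=-1):
    """dequant = quants * scales, broadcasting one scale per block along `dim`."""
    rows, cols = len(quants), len(quants[0])
    if dim in (-1, 1):
        # Blocks are rows: scales has shape (rows,).
        return [[quants[i][j] * scales[i] for j in range(cols)] for i in range(rows)]
    if dim in (0, -2):
        # Blocks are columns: scales has shape (cols,).
        return [[quants[i][j] * scales[j] for j in range(cols)] for i in range(rows)]
    raise ValueError("dim must be one of -2, -1, 0, 1 for a 2-D tensor")


quants = [[1.0, -0.5], [6.0, 1.5]]
print(dequantize_2d([2.0, 0.5], quants, dim=-1))  # [[2.0, -1.0], [3.0, 0.75]]
print(dequantize_2d([2.0, 0.5], quants, dim=0))   # [[2.0, -0.25], [12.0, 0.75]]
```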
- qwantize.mxfp4.reference.build_ue8m0_scales(device='cpu')
Return a sorted tensor of the 254 UE8M0 scale values 2^(e - 127) for e in {1, ..., 254}, i.e. 2^-126 through 2^127.
The encodings e = 0 and e = 255 are excluded: this implementation reserves e = 0 for zero, and e = 255 is the NaN encoding (E8M0 has no infinity).
- Parameters:
device – Torch device for the output tensor.
- Returns:
Tensor of shape (254,) with the sorted UE8M0 power-of-2 scales as float32.
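A stdlib-only sketch of the same table and its inverse, assuming Python floats rather than a float32 tensor; ue8m0_scales and ue8m0_encode are hypothetical helpers. Because every entry is a power of two, encoding reduces to reading the binary exponent.

```python
import math


def ue8m0_scales():
    """All scales 2^(e - 127) for e = 1..254, in ascending order."""
    return [2.0 ** (e - 127) for e in range(1, 255)]


def ue8m0_encode(s):
    """Exponent byte e for a power-of-two scale s in [2^-126, 2^127]."""
    mantissa, exp = math.frexp(s)  # s = mantissa * 2**exp with mantissa in [0.5, 1)
    if mantissa != 0.5:
        raise ValueError("UE8M0 can only represent powers of two")
    e = exp - 1 + 127
    if not 1 <= e <= 254:
        raise ValueError("scale outside the 2^-126 .. 2^127 range")
    return e


scales = ue8m0_scales()
print(len(scales), scales[0] == 2.0 ** -126, scales[-1] == 2.0 ** 127)  # 254 True True
print(ue8m0_encode(1.0), ue8m0_encode(0.25))  # 127 125
```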
- qwantize.mxfp4.reference.fp4_quantize(x, s)
Quantize to FP4 E2M1 codebook values given a per-block scale.
Maps each scaled element x / s to the nearest value in {0, 0.5, 1, 1.5, 2, 3, 4, 6}, preserving its sign; scaled magnitudes above 6 therefore map to 6.
- Parameters:
x – Input tensor of shape (..., block_size).
s – Per-block scale of shape (..., 1), broadcastable to x.
- Returns:
Signed codebook values with the same shape as x.
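A stdlib-only sketch of nearest-value rounding using the decision boundaries from the Constants section, assuming a scalar scale for one block instead of a (..., 1) tensor. fp4_quantize_block is a hypothetical name, and rounding an exact boundary down is this sketch's choice, not documented library behavior.

```python
import bisect
import math

FP4_CODEBOOK = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
FP4_BOUNDARIES = [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0]  # codebook midpoints


def fp4_quantize_block(x, s):
    """Signed codebook values for x / s, using the decision boundaries."""
    out = []
    for v in x:
        # bisect_left counts the boundaries strictly below |v| / s, which is the
        # codebook index; a value exactly on a boundary rounds down (assumption).
        idx = bisect.bisect_left(FP4_BOUNDARIES, abs(v) / s)
        out.append(math.copysign(FP4_CODEBOOK[idx], v))
    return out


print(fp4_quantize_block([0.3, -1.2, 4.6, 7.0], 1.0))  # [0.5, -1.0, 4.0, 6.0]
print(fp4_quantize_block([0.3, -1.2, 4.6, 7.0], 0.5))  # [0.5, -2.0, 6.0, 6.0]
```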
- qwantize.mxfp4.reference.fp4_dequantize(quants, s)
Dequantize FP4 codebook values back to float: dequant = quants * s.
- Parameters:
quants – Signed codebook values of shape (..., block_size).
s – Per-block scale of shape (..., 1), broadcastable to quants.
- Returns:
Dequantized tensor with the same shape as quants.
- qwantize.mxfp4.reference.compute_block_sse(x, s)
Compute the per-block sum of squared quantization error: for each block, the sum of (x - s * q)^2, where q are the FP4 codebook values of x at scale s.
- Parameters:
x – Block values of shape (num_blocks, block_size).
s – Per-block scales of shape (num_blocks,) or (num_blocks, 1).
- Returns:
Tensor of shape (num_blocks,) with the SSE for each block.
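A stdlib-only sketch of the per-block SSE over several blocks, assuming lists of blocks and one scalar scale per block; compute_block_sse_lists and fp4_code are hypothetical names.

```python
import bisect
import math

FP4_CODEBOOK = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
FP4_BOUNDARIES = [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0]


def fp4_code(y):
    """Nearest signed codebook value for an already-scaled element y."""
    idx = bisect.bisect_left(FP4_BOUNDARIES, abs(y))
    return math.copysign(FP4_CODEBOOK[idx], y)


def compute_block_sse_lists(blocks, scales):
    """Per-block SSE: sum of (x - s * q)^2 with q the FP4 code of x / s."""
    return [
        sum((v - s * fp4_code(v / s)) ** 2 for v in block)
        for block, s in zip(blocks, scales)
    ]


blocks = [[7.5, 6.8, 0.6], [0.5, -1.0, 4.0]]
# The first block clips 7.5 and 6.8 to 6.0; the second is exactly representable.
print(compute_block_sse_lists(blocks, [1.0, 1.0]))  # approximately [2.9, 0.0]
```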
Triton Kernels
- qwantize.mxfp4.kernels.mxfp4_naive_triton(W, dim=-1, return_dequant=False)
Naive MXFP4 quantization using a Triton kernel with inline PTX assembly.
GPU-accelerated version of mxfp4_naive().
- Parameters:
W – Input tensor. W.shape[dim] must be 16 or 32 (the block size).
dim – Dimension along which to quantize (default: -1).
return_dequant – If True, also return the dequantized tensor.
- Returns:
(scales, quants) by default, or (scales, quants, dequant) if return_dequant is True. See mxfp4_naive() for shape details.
- qwantize.mxfp4.kernels.mxfp4_optimal_triton(W, dim=-1, return_dequant=False)
Optimal MXFP4 quantization using a Triton kernel with inline PTX assembly.
GPU-accelerated version of mxfp4_optimal(). Compares the naive scale s0 with 2*s0 (one UE8M0 step up) and keeps whichever gives the lower SSE.
- Parameters:
W – Input tensor. W.shape[dim] must be 16 or 32 (the block size).
dim – Dimension along which to quantize (default: -1).
return_dequant – If True, also return the dequantized tensor.
- Returns:
(scales, quants) by default, or (scales, quants, dequant) if return_dequant is True. See mxfp4_optimal() for shape details.
- qwantize.mxfp4.kernels.mxfp4_naive_torch(W, dim=-1, return_dequant=False)
Naive MXFP4 quantization using pure PyTorch operations.
Functionally identical to mxfp4_naive(), but uses a vectorized argmin over the full signed codebook instead of bucketize.
- Parameters:
W – Input tensor. W.shape[dim] must be 16 or 32 (the block size).
dim – Dimension along which to quantize (default: -1).
return_dequant – If True, also return the dequantized tensor.
- Returns:
(scales, quants) by default, or (scales, quants, dequant) if return_dequant is True. See mxfp4_naive() for shape details.
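A stdlib-only sketch of why the argmin formulation matches the bucketize one: away from exact decision boundaries, the nearest entry of the 15-value signed codebook is the same value the boundary lookup selects. At an exact tie the two rules can pick different neighbors, so the check samples random points; the helper names are hypothetical and the library's tie handling is not asserted here.

```python
import bisect
import math
import random

FP4_CODEBOOK = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
FP4_BOUNDARIES = [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0]
# Full signed codebook: 15 distinct values (zero appears once).
SIGNED_CODEBOOK = sorted({sign * c for c in FP4_CODEBOOK for sign in (1.0, -1.0)})


def quantize_argmin(y):
    """Nearest signed codebook value by exhaustive argmin."""
    return min(SIGNED_CODEBOOK, key=lambda c: abs(y - c))


def quantize_bucketize(y):
    """Nearest value via the magnitude boundaries, with the sign restored afterwards."""
    idx = bisect.bisect_left(FP4_BOUNDARIES, abs(y))
    return math.copysign(FP4_CODEBOOK[idx], y)


rng = random.Random(0)
samples = [rng.uniform(-8.0, 8.0) for _ in range(10_000)]
mismatches = sum(quantize_argmin(y) != quantize_bucketize(y) for y in samples)
print(len(SIGNED_CODEBOOK), mismatches)  # 15 0
```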
- qwantize.mxfp4.kernels.mxfp4_optimal_torch(W, dim=-1, return_dequant=False)
Optimal MXFP4 quantization using pure PyTorch operations.
Tries the naive scale s0 and 2*s0 and keeps whichever gives the lower SSE.
- Parameters:
W – Input tensor. W.shape[dim] must be 16 or 32 (the block size).
dim – Dimension along which to quantize (default: -1).
return_dequant – If True, also return the dequantized tensor.
- Returns:
(scales, quants) by default, or (scales, quants, dequant) if return_dequant is True. See mxfp4_optimal() for shape details.
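The two-candidate rule described for mxfp4_optimal_torch and mxfp4_optimal_triton, sketched per block in pure Python. This only illustrates the comparison, not the vectorized or Triton code; names are hypothetical, and keeping s0 on an SSE tie is an assumption. The upward step pays off when the naive scale clips elements above 6.

```python
import math

FP4_CODEBOOK = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def nearest_fp4(y):
    """Nearest signed codebook value; ties go to the smaller magnitude."""
    mag = min(FP4_CODEBOOK, key=lambda c: abs(abs(y) - c))
    return math.copysign(mag, y)


def block_sse(block, s):
    """Sum of squared errors after quantizing the block at scale s."""
    return sum((v - s * nearest_fp4(v / s)) ** 2 for v in block)


def two_candidate_scale(block):
    """Keep s0 or 2*s0, whichever gives the lower SSE (ties keep s0)."""
    amax = max(abs(v) for v in block)  # assumes a block that is not all zeros
    s0 = 2.0 ** (math.floor(math.log2(amax)) - 2)
    return s0 if block_sse(block, s0) <= block_sse(block, 2 * s0) else 2 * s0


print(two_candidate_scale([7.5, 6.8, 0.6]))  # 2.0: doubling the scale avoids clipping at 6
print(two_candidate_scale([0.5, -1.0, 4.0]))  # 1.0: s0 already represents every element exactly
```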
Constants
- qwantize.mxfp4.Q_MAX = 6.0 – Maximum FP4 codebook value
- qwantize.mxfp4.D_0 = 0.25 – Decision boundary for rounding to zero
- qwantize.mxfp4.FP4_CODEBOOK = [0, 0.5, 1, 1.5, 2, 3, 4, 6] – Standard FP4 E2M1 codebook
- qwantize.mxfp4.FP4_BOUNDARIES = [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0] – Decision boundaries (midpoints between consecutive codebook values)
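A short plain-Python check of how the constants relate: the codebook is the set of E2M1 magnitudes, each boundary is the midpoint of two neighboring codebook values, D_0 is the first boundary, and Q_MAX is the last codebook entry. e2m1_magnitude is a hypothetical helper written for this sketch.

```python
FP4_CODEBOOK = [0, 0.5, 1, 1.5, 2, 3, 4, 6]
FP4_BOUNDARIES = [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0]
Q_MAX = 6.0
D_0 = 0.25


def e2m1_magnitude(code):
    """Decode a 3-bit E2M1 magnitude: 2 exponent bits, 1 mantissa bit, bias 1."""
    exp, man = code >> 1, code & 1
    if exp == 0:
        return 0.5 * man  # subnormal: 0 or 0.5
    return 2.0 ** (exp - 1) * (1 + 0.5 * man)


decoded = [e2m1_magnitude(code) for code in range(8)]
midpoints = [(a + b) / 2 for a, b in zip(FP4_CODEBOOK, FP4_CODEBOOK[1:])]
print(decoded == FP4_CODEBOOK, midpoints == FP4_BOUNDARIES)  # True True
print(D_0 == FP4_BOUNDARIES[0], Q_MAX == FP4_CODEBOOK[-1])  # True True
```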