MXFP4

Reference Implementation

qwantize.mxfp4.reference.mxfp4_naive(W, dim=-1, return_dequant=False)

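Naive MXFP4 quantization: each block uses the scale s = 2^(floor(log2(amax)) - 2), where amax is the block's maximum absolute value. This places amax/s in [4, 8), so block values above 6*s saturate to the largest FP4 code.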

Parameters:
  • W – Input tensor. W.shape[dim] must be 16 or 32 (the block size).

  • dim – Dimension along which to quantize (default: -1).

  • return_dequant – If True, also return the dequantized tensor.

Returns:

(scales, quants) by default, or (scales, quants, dequant) if return_dequant is True.

  • scales: Per-block UE8M0 (power-of-2) scales. Shape is W.shape with dimension dim removed.

  • quants: Signed FP4 codebook values. Same shape as W.

  • dequant: quants * scales, with scales broadcast along dim. Same shape as W.
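The naive rule can be illustrated for a single block with a minimal pure-Python sketch. It is not the qwantize implementation: the names mxfp4_naive_sketch and fp4_round are hypothetical, exact ties on a decision boundary are assumed to round to the smaller magnitude, and the scale chosen for an all-zero block is an assumption because the docs do not specify it.

```python
import bisect
import math

# FP4 E2M1 magnitudes and the midpoints between neighbours (see Constants).
FP4_CODEBOOK = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
FP4_BOUNDARIES = [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0]


def fp4_round(v):
    """Round v (already divided by the scale) to the nearest signed FP4 value.

    Assumption: exact ties go to the smaller magnitude (bisect_left).
    Magnitudes above 5.0 land in the last bucket, so they saturate to 6.
    """
    mag = FP4_CODEBOOK[bisect.bisect_left(FP4_BOUNDARIES, abs(v))]
    return math.copysign(mag, v)


def mxfp4_naive_sketch(block):
    """Quantize one block (list of floats) with s = 2^(floor(log2(amax)) - 2)."""
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        # Assumption: an all-zero block gets scale 1.0; the docs do not say.
        return 1.0, [0.0] * len(block)
    # frexp gives amax = m * 2^e with 0.5 <= m < 1, so floor(log2(amax)) == e - 1.
    _, e = math.frexp(amax)
    s = math.ldexp(1.0, (e - 1) - 2)
    return s, [fp4_round(v / s) for v in block]
```

For example, the block [3.0, -0.7, 0.1, 5.5] has amax = 5.5, so s = 1.0 and the codes are [3.0, -0.5, 0.0, 6.0]; 5.5 saturates to 6 because it lies past the last boundary.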

qwantize.mxfp4.reference.mxfp4_optimal(W, dim=-1, return_dequant=False)

Optimal MXFP4 quantization via bounded search over UE8M0 scales.

Uses the same algorithm as nvfp4_optimal(), adapted for UE8M0 power-of-2 scales. Because consecutive UE8M0 scales differ by a factor of 2, the SSE-optimal scale is always within one UE8M0 step of the naive scale, so the search only needs to evaluate those neighbors.

Parameters:
  • W – Input tensor. W.shape[dim] must be 16 or 32 (the block size).

  • dim – Dimension along which to quantize (default: -1).

  • return_dequant – If True, also return the dequantized tensor.

Returns:

(scales, quants) by default, or (scales, quants, dequant) if return_dequant is True.

  • scales: Per-block optimal UE8M0 scales. Shape is W.shape with dimension dim removed.

  • quants: Signed FP4 codebook values. Same shape as W.

  • dequant: quants * scales, with scales broadcast along dim. Same shape as W.
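A minimal sketch of the bounded search for one block follows. It assumes the candidate set is the naive scale s0 and its two UE8M0 neighbors (s0 / 2 and 2 * s0), a reading of "within one step of naive"; the exact candidate set and pruning used by qwantize are not shown here. The names mxfp4_optimal_sketch, block_sse and fp4_round are hypothetical, and the tie-breaking rule is an assumption.

```python
import bisect
import math

FP4_CODEBOOK = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
FP4_BOUNDARIES = [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0]


def fp4_round(v):
    # Nearest FP4 magnitude (ties to the smaller one, an assumption), sign kept.
    return math.copysign(FP4_CODEBOOK[bisect.bisect_left(FP4_BOUNDARIES, abs(v))], v)


def block_sse(block, s):
    # Sum of squared error between the block and its FP4 reconstruction at scale s.
    return sum((v - s * fp4_round(v / s)) ** 2 for v in block)


def mxfp4_optimal_sketch(block):
    """Pick the lowest-SSE scale among s0 and its two UE8M0 neighbours.

    Assumes a nonzero block. Candidate order makes ties keep s0 (min keeps the first).
    """
    _, e = math.frexp(max(abs(v) for v in block))
    s0 = math.ldexp(1.0, e - 3)  # 2^(floor(log2(amax)) - 2)
    candidates = [s0, 2 * s0, s0 / 2]
    best = min(candidates, key=lambda s: block_sse(block, s))
    return best, [fp4_round(v / best) for v in block]
```

In this sketch, the block [4.0] + [0.7] * 31 (naive scale s0 = 1) selects s0 / 2: saturating 4.0 to 3.0 costs 1.0 of squared error, less than the roughly 1.24 the 31 small values lose when rounded at s0.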

qwantize.mxfp4.reference.mxfp4_optimal_hessian(W, dim=-1, return_dequant=False, X=None, H_blocks=None)

Hessian-aware optimal MXFP4 scale search.

Like mxfp4_optimal(), it searches over UE8M0 scale candidates and uses SSE bounds for pruning, but it selects the scale that minimizes the Hessian-weighted error (x - s*q)^T H (x - s*q) instead of the raw SSE. Here x is the block, s its scale, q its FP4 codes, and H the block Hessian.

Parameters:
  • W – Input tensor. W.shape[dim] must be 16 or 32 (the block size).

  • dim – Dimension along which to quantize (default: -1).

  • return_dequant – If True, also return the dequantized tensor.

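  • X – Activation tensor of shape (T, K). For column block j, the block Hessian is computed as H_j = X_j^T @ X_j, where X_j is the (T, bs) slice of X covering that block's columns.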

  • H_blocks – Precomputed block Hessians of shape (num_col_blocks, bs, bs), where bs is the block size. If provided, X is ignored.

Returns:

(scales, quants) by default, or (scales, quants, dequant) if return_dequant is True. See mxfp4_optimal() for shape details.
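The selection criterion can be sketched in pure Python using the identity e^T (X_j^T X_j) e = ||X_j e||^2, which avoids forming H explicitly. This is an illustration under stated assumptions, not the qwantize code: the names mxfp4_hessian_sketch and hessian_error are hypothetical, the SSE-bound pruning is omitted (all three candidates s0, 2*s0 and s0 / 2 are scored), and the rounding tie rule is assumed.

```python
import bisect
import math

FP4_CODEBOOK = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
FP4_BOUNDARIES = [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0]


def fp4_round(v):
    # Nearest FP4 magnitude (ties to the smaller one, an assumption), sign kept.
    return math.copysign(FP4_CODEBOOK[bisect.bisect_left(FP4_BOUNDARIES, abs(v))], v)


def hessian_error(block, s, X_j):
    """(x - s*q)^T H (x - s*q) with H = X_j^T @ X_j, computed as ||X_j (x - s*q)||^2."""
    err = [v - s * fp4_round(v / s) for v in block]
    return sum(sum(a * d for a, d in zip(row, err)) ** 2 for row in X_j)


def mxfp4_hessian_sketch(block, X_j):
    """Return the candidate scale with the lowest Hessian-weighted error.

    X_j: list of T activation rows, each of length len(block).
    Assumes a nonzero block; SSE-bound pruning is omitted in this sketch.
    """
    _, e = math.frexp(max(abs(v) for v in block))
    s0 = math.ldexp(1.0, e - 3)  # naive scale 2^(floor(log2(amax)) - 2)
    return min([s0, 2 * s0, s0 / 2], key=lambda s: hessian_error(block, s, X_j))
```

With X_j equal to the identity matrix, H is the identity and the criterion reduces to plain SSE; activations that only excite the first column make the error on that element dominate, which can change the chosen scale.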

qwantize.mxfp4.reference.mxfp4_dequantize(scales, quants, dim=-1)

Dequantize MXFP4: dequant = quants * scales, with scales broadcast along dim.

Parameters:
  • scales – Per-block scales. Shape is the original W.shape with dimension dim removed.

  • quants – Signed FP4 codebook values. Same shape as the original W.

  • dim – Block dimension in quants (default: -1).

Returns:

Dequantized tensor with the same shape as quants.
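How the scales broadcast along dim can be shown for a 2-D input with plain nested lists. This sketch and its name mxfp4_dequantize_2d_sketch are hypothetical; it only mirrors the documented shapes, where scales has the shape of quants with dimension dim removed.

```python
def mxfp4_dequantize_2d_sketch(scales, quants, dim=-1):
    """Multiply 2-D FP4 codes by per-block scales, broadcasting along `dim`.

    dim=-1 (or 1): each row is a block, so scales[i] applies to row i.
    dim=0 (or -2): each column is a block, so scales[j] applies to column j.
    """
    if dim in (-1, 1):
        return [[q * scales[i] for q in row] for i, row in enumerate(quants)]
    if dim in (0, -2):
        return [[q * scales[j] for j, q in enumerate(row)] for row in quants]
    raise ValueError("dim must be 0, 1, -1 or -2 for a 2-D input")
```

For quants [[1.0, -0.5], [6.0, 2.0]] and scales [2.0, 0.25], dim=-1 scales the rows and gives [[2.0, -1.0], [1.5, 0.5]], while dim=0 scales the columns and gives [[2.0, -0.125], [12.0, 0.5]].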

qwantize.mxfp4.reference.build_ue8m0_scales(device='cpu')

Return a sorted tensor of the 254 UE8M0 scale values 2^-126, ..., 2^127.

UE8M0 scale: 2^(e - 127) for e in {1, ..., 254}. The encoding e = 255 is reserved for NaN, and e = 0 (which the OCP MX specification defines as 2^-127) is excluded.

Parameters:

device – Torch device for the output tensor.

Returns:

Tensor of shape (254,) with sorted UE8M0 power-of-2 scales as float32.
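The scale table can be reproduced with plain Python floats; every value is an exact power of two, and the range 2^-126 to 2^127 fits in float32 (2^-126 is its smallest normal value). Both functions below are hypothetical sketches; ue8m0_encode, which maps a scale back to its biased exponent byte, is an illustrative helper rather than part of qwantize.

```python
import math


def build_ue8m0_scales_sketch():
    """All 254 scales 2^(e - 127) for e = 1..254, in ascending order."""
    return [math.ldexp(1.0, e - 127) for e in range(1, 255)]


def ue8m0_encode(s):
    """Hypothetical helper: recover the biased exponent e of a power-of-2 scale s."""
    m, e = math.frexp(s)  # s = m * 2^e, with m == 0.5 exactly for a power of two
    if m != 0.5:
        raise ValueError("UE8M0 scales must be exact powers of two")
    return (e - 1) + 127
```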

qwantize.mxfp4.reference.fp4_quantize(x, s)

Quantize to FP4 E2M1 codebook values given a per-block scale.

Divides each element by s and maps the result to the nearest value in {0, 0.5, 1, 1.5, 2, 3, 4, 6}, preserving the sign; magnitudes above Q_MAX = 6 saturate to 6.

Parameters:
  • x – Input tensor of shape (..., block_size).

  • s – Per-block scale of shape (..., 1), broadcastable to x.

Returns:

Signed codebook values with the same shape as x.
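A minimal sketch of the bucketing for one block, using Python's bisect module in place of a tensor bucketize. The name fp4_quantize_sketch is hypothetical, and the rule that a value exactly on a boundary goes to the lower bucket is an assumption, since the docs do not state how ties are resolved.

```python
import bisect
import math

FP4_CODEBOOK = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
FP4_BOUNDARIES = [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0]


def fp4_quantize_sketch(x, s):
    """Map each element of one block x (list of floats) to a signed FP4 codebook value.

    |x_i / s| is bucketed against FP4_BOUNDARIES; anything above 5.0 lands in the
    last bucket, so magnitudes beyond Q_MAX = 6 saturate to 6.
    Assumption: a value exactly on a boundary goes to the lower bucket (bisect_left).
    """
    out = []
    for v in x:
        idx = bisect.bisect_left(FP4_BOUNDARIES, abs(v / s))
        out.append(math.copysign(FP4_CODEBOOK[idx], v))
    return out
```

With s = 2.0, the block [0.2, -1.3, 2.6, 13.0, 0.5] becomes [0.0, -0.5, 1.5, 6.0, 0.0]; the last element sits exactly on D_0 = 0.25 after scaling and rounds to 0 under the assumed tie rule.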

qwantize.mxfp4.reference.fp4_dequantize(quants, s)

Dequantize FP4 codebook values back to float: dequant = quants * s.

Parameters:
  • quants – Signed codebook values of shape (..., block_size).

  • s – Per-block scale of shape (..., 1), broadcastable to quants.

Returns:

Dequantized tensor with the same shape as quants.

qwantize.mxfp4.reference.compute_block_sse(x, s)

Compute the per-block sum of squared quantization error (SSE) between x and its FP4 reconstruction at scale s.

Parameters:
  • x – Block values of shape (num_blocks, block_size).

  • s – Per-block scales of shape (num_blocks,) or (num_blocks, 1).

Returns:

Tensor of shape (num_blocks,) with the SSE for each block.
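The per-block SSE can be sketched over nested lists, accepting the scales either flat or wrapped to mirror the documented (num_blocks,) and (num_blocks, 1) shapes. The name compute_block_sse_sketch is hypothetical, and the rounding tie rule is an assumption.

```python
import bisect
import math

FP4_CODEBOOK = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
FP4_BOUNDARIES = [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0]


def compute_block_sse_sketch(x, s):
    """Per-block SSE between each block and its FP4 reconstruction at that block's scale.

    x: num_blocks lists of floats. s: num_blocks scales, either flat ([s0, s1, ...])
    or wrapped ([[s0], [s1], ...]) to mirror the (num_blocks,) / (num_blocks, 1) shapes.
    Assumption: boundary ties round to the smaller magnitude (bisect_left).
    """
    sse = []
    for block, scale in zip(x, s):
        if isinstance(scale, (list, tuple)):
            scale = scale[0]  # unwrap the (num_blocks, 1) form
        total = 0.0
        for v in block:
            idx = bisect.bisect_left(FP4_BOUNDARIES, abs(v / scale))
            q = math.copysign(FP4_CODEBOOK[idx], v)
            total += (v - q * scale) ** 2
        sse.append(total)
    return sse
```

For blocks [[1.0, 2.0], [0.7, 7.0]] at scale 1.0, the first block is exact (SSE 0) and the second loses about 0.04 on 0.7 and 1.0 on the saturated 7.0.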

Triton Kernels

qwantize.mxfp4.kernels.mxfp4_naive_triton(W, dim=-1, return_dequant=False)

Naive MXFP4 quantization using a Triton kernel with inline PTX assembly.

GPU-accelerated version of mxfp4_naive().

Parameters:
  • W – Input tensor. W.shape[dim] must be 16 or 32 (the block size).

  • dim – Dimension along which to quantize (default: -1).

  • return_dequant – If True, also return the dequantized tensor.

Returns:

(scales, quants) by default, or (scales, quants, dequant) if return_dequant is True. See mxfp4_naive() for shape details.

qwantize.mxfp4.kernels.mxfp4_optimal_triton(W, dim=-1, return_dequant=False)

Optimal MXFP4 quantization using a Triton kernel with inline PTX assembly.

GPU-accelerated version of mxfp4_optimal(). Compares naive scale s0 with 2*s0 (one UE8M0 step up) and picks whichever has lower SSE.

Parameters:
  • W – Input tensor. W.shape[dim] must be 16 or 32 (the block size).

  • dim – Dimension along which to quantize (default: -1).

  • return_dequant – If True, also return the dequantized tensor.

Returns:

(scales, quants) by default, or (scales, quants, dequant) if return_dequant is True. See mxfp4_optimal() for shape details.

qwantize.mxfp4.kernels.mxfp4_naive_torch(W, dim=-1, return_dequant=False)

Naive MXFP4 quantization using pure PyTorch operations.

Functionally identical to mxfp4_naive(), but uses a vectorized argmin over the full signed codebook instead of bucketize.

Parameters:
  • W – Input tensor. W.shape[dim] must be 16 or 32 (the block size).

  • dim – Dimension along which to quantize (default: -1).

  • return_dequant – If True, also return the dequantized tensor.

Returns:

(scales, quants) by default, or (scales, quants, dequant) if return_dequant is True. See mxfp4_naive() for shape details.
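The argmin approach can be illustrated element by element: build the signed codebook once and pick the entry closest to each scaled value. This is a hypothetical, non-vectorized sketch (fp4_argmin and mxfp4_naive_argmin_sketch are illustrative names). Because min() keeps the first of equally close candidates in the sorted list, exact ties resolve toward the more negative value here; the library's tie rule is not documented.

```python
import math

FP4_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

# Full signed codebook: 15 distinct values from -6 to 6 (0.0 and -0.0 collapse in the set).
SIGNED_FP4 = sorted({sign * m for m in FP4_MAGNITUDES for sign in (1.0, -1.0)})


def fp4_argmin(v):
    """Nearest signed FP4 value by exhaustive search (argmin of |v - c| over the codebook)."""
    return min(SIGNED_FP4, key=lambda c: abs(v - c))


def mxfp4_naive_argmin_sketch(block):
    """Naive scale s = 2^(floor(log2(amax)) - 2), then argmin rounding. Assumes a nonzero block."""
    _, e = math.frexp(max(abs(v) for v in block))
    s = math.ldexp(1.0, e - 3)
    return s, [fp4_argmin(v / s) for v in block]
```

Away from exact ties this gives the same codes as boundary bucketing; for example, [3.0, -0.7, 0.1, 5.5] again maps to scale 1.0 and codes [3.0, -0.5, 0.0, 6.0].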

qwantize.mxfp4.kernels.mxfp4_optimal_torch(W, dim=-1, return_dequant=False)

Optimal MXFP4 quantization using pure PyTorch operations.

Tries the naive scale s0 and 2*s0 and picks whichever gives the lower SSE.

Parameters:
  • W – Input tensor. W.shape[dim] must be 16 or 32 (the block size).

  • dim – Dimension along which to quantize (default: -1).

  • return_dequant – If True, also return the dequantized tensor.

Returns:

(scales, quants) by default, or (scales, quants, dequant) if return_dequant is True. See mxfp4_optimal() for shape details.

Constants

  • qwantize.mxfp4.Q_MAX = 6.0 – Maximum FP4 codebook value

  • qwantize.mxfp4.D_0 = 0.25 – Decision boundary for rounding to zero (the midpoint between 0 and 0.5)

  • qwantize.mxfp4.FP4_CODEBOOK = [0, 0.5, 1, 1.5, 2, 3, 4, 6] – Standard FP4 E2M1 codebook (non-negative magnitudes)

  • qwantize.mxfp4.FP4_BOUNDARIES = [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0] – Decision boundaries (midpoints between adjacent codebook values)
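The constants are related, and a short sketch can derive them from the E2M1 bit layout: 2 exponent bits with bias 1 (exponent field 0 being subnormal) and 1 mantissa bit. The helper e2m1_value is hypothetical; the derived lists reproduce the values documented above.

```python
def e2m1_value(exp_bits, man_bit):
    """Magnitude of an E2M1 code: exponent bias 1, exp_bits == 0 is subnormal (0 or 0.5)."""
    if exp_bits == 0:
        return 0.5 * man_bit
    return 2.0 ** (exp_bits - 1) * (1.0 + 0.5 * man_bit)


# Enumerate all 8 non-negative codes in increasing order.
FP4_CODEBOOK = [e2m1_value(e, m) for e in range(4) for m in range(2)]

# Each decision boundary is the midpoint between adjacent codebook values.
FP4_BOUNDARIES = [(a + b) / 2 for a, b in zip(FP4_CODEBOOK, FP4_CODEBOOK[1:])]
D_0 = FP4_BOUNDARIES[0]   # scaled magnitudes below this round to 0
Q_MAX = FP4_CODEBOOK[-1]  # largest representable magnitude
```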