Source code for qwantize.nvfp4.reference

import torch

@torch.compiler.disable
def _fp8_e4m3_snap(x):
    """Snap float32 values to nearest FP8 E4M3 representable value.

    Excluded from torch.compile because inductor emits float8e4nv which
    is unsupported on pre-Ada GPUs.
    """
    return x.to(torch.float8_e4m3fn).to(torch.float32)


# FP4 E2M1 codebook: the magnitudes the format can represent (sign handled separately)
FP4_CODEBOOK = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
# Decision boundaries: the midpoint of each pair of neighbouring codebook values,
# i.e. [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0]
FP4_BOUNDARIES = (FP4_CODEBOOK[:-1] + FP4_CODEBOOK[1:]) / 2
Q_MAX = 6.0  # largest codebook magnitude
D_0 = 0.25  # decision boundary for rounding to zero


@torch.compiler.disable
def build_fp8_e4m3_scales(device="cpu"):
    """Return sorted tensor of all 126 positive FP8 E4M3 representable values.

    Args:
        device: Torch device for the output tensor.

    Returns:
        Tensor of shape ``(126,)`` with sorted positive FP8 E4M3 values as float32.
    """
    all_bytes = torch.arange(256, dtype=torch.uint8, device=device)
    fp8_vals = all_bytes.view(torch.float8_e4m3fn).to(torch.float32)
    pos = fp8_vals[(fp8_vals > 0) & (~fp8_vals.isnan())]
    return pos.unique().sort().values


def fp4_quantize(x, s):
    """Round values to the signed FP4 E2M1 codebook for a given per-block scale.

    Each element's magnitude is divided by the scale and replaced by the
    codebook entry ``{0, 0.5, 1, 1.5, 2, 3, 4, 6}`` whose decision interval
    contains it; the sign of the input is then re-applied. With the default
    ``torch.bucketize`` semantics, a normalized magnitude lying exactly on
    a decision boundary maps to the smaller of the two codebook values.

    Args:
        x: Input tensor of shape ``(..., block_size)``.
        s: Per-block scale of shape ``(..., 1)``, broadcastable to *x*.

    Returns:
        Signed codebook values with the same shape as *x*.
    """
    magnitude = torch.abs(x) / s  # scale-normalized magnitude
    thresholds = FP4_BOUNDARIES.to(x.device)
    levels = FP4_CODEBOOK.to(x.device)
    # Index of the decision interval each magnitude falls into.
    level_idx = torch.bucketize(magnitude, thresholds)
    return torch.sign(x) * levels[level_idx]
def fp4_dequantize(quants, s):
    """Map FP4 codebook values back to real values by multiplying with the scale.

    Args:
        quants: Signed codebook values of shape ``(..., block_size)``.
        s: Per-block scale of shape ``(..., 1)``, broadcastable to *quants*.

    Returns:
        ``quants * s``, with the same shape as *quants*.
    """
    dequant = quants * s
    return dequant
def compute_block_sse(x, s):
    """Sum of squared FP4 round-trip errors for every block.

    Each block is quantized with its scale, dequantized again, and the
    squared difference to the original values is summed over the block.

    Args:
        x: Block values of shape ``(num_blocks, block_size)``.
        s: Per-block scales of shape ``(num_blocks,)`` or ``(num_blocks, 1)``.

    Returns:
        Tensor of shape ``(num_blocks,)`` with the SSE for each block.
    """
    # A flat scale vector needs a trailing axis to broadcast over the block.
    scale = s.unsqueeze(-1) if s.dim() == 1 else s
    codes = fp4_quantize(x, scale)
    residual = x - fp4_dequantize(codes, scale)
    return residual.pow(2).sum(dim=-1)
def nvfp4_naive(W, dim=-1, return_dequant=False):
    """Baseline NVFP4 quantization with ``s = FP8_E4M3(max|x_i| / 6)`` per block.

    Args:
        W: Input tensor. ``W.shape[dim]`` must be 16 or 32 (the block size).
        dim: Dimension along which to quantize (default: -1).
        return_dequant: If ``True``, also return the dequantized tensor.

    Returns:
        ``(scales, quants)`` by default, or ``(scales, quants, dequant)``
        if *return_dequant* is ``True``.

        - **scales**: Per-block FP8 E4M3 scales. Shape is *W.shape* with
          dimension *dim* removed.
        - **quants**: Signed FP4 codebook values. Same shape as *W*.
        - **dequant**: ``quants * scales`` broadcast. Same shape as *W*.
    """
    dim = dim % W.ndim
    block_size = W.shape[dim]
    assert block_size in (16, 32)

    # Put the block axis last, then flatten all other axes into a block index.
    moved = W.float().movedim(dim, -1)
    batch_shape = moved.shape[:-1]
    blocks = moved.reshape(-1, block_size)

    # Scale that maps each block's largest magnitude onto the top codebook value.
    block_amax = blocks.abs().amax(dim=-1)
    scales = _fp8_e4m3_snap((block_amax / Q_MAX).clamp(min=1e-12))
    column_scales = scales.unsqueeze(-1)
    quants = fp4_quantize(blocks, column_scales)

    def _restore(flat):
        # Undo the flatten and move the block axis back to where it came from.
        return flat.reshape(*batch_shape, block_size).movedim(-1, dim)

    outputs = [scales.reshape(*batch_shape), _restore(quants)]
    if return_dequant:
        outputs.append(_restore(fp4_dequantize(quants, column_scales)))
    return tuple(outputs)
def nvfp4_optimal(W, dim=-1, return_dequant=False):
    """SSE-optimal NVFP4 quantization via a bounded search over FP8 E4M3 scales.

    Starting from the naive scale and its error, per-block lower (clipping)
    and upper (dead-zone) bounds on the scale are derived. Each positive
    FP8 E4M3 candidate is then considered only for blocks whose bounds
    admit it, with a cheap clipping-error check before the full SSE is
    used. See :doc:`../optimal_scale_search` for the algorithm.

    Args:
        W: Input tensor. ``W.shape[dim]`` must be 16 or 32 (the block size).
        dim: Dimension along which to quantize (default: -1).
        return_dequant: If ``True``, also return the dequantized tensor.

    Returns:
        ``(scales, quants)`` by default, or ``(scales, quants, dequant)``
        if *return_dequant* is ``True``.

        - **scales**: Per-block optimal FP8 E4M3 scales. Shape is *W.shape*
          with dimension *dim* removed.
        - **quants**: Signed FP4 codebook values. Same shape as *W*.
        - **dequant**: ``quants * scales`` broadcast. Same shape as *W*.
    """
    dim = dim % W.ndim
    block_size = W.shape[dim]
    assert block_size in (16, 32)

    moved = W.float().movedim(dim, -1)
    batch_shape = moved.shape[:-1]
    x = moved.reshape(-1, block_size)  # (N, block_size)

    candidates = build_fp8_e4m3_scales(device=x.device)

    # Step 1: naive scale and its error are the baseline to beat.
    abs_x = x.abs()
    amax = abs_x.amax(dim=-1)  # (N,)
    s0 = _fp8_e4m3_snap((amax / Q_MAX).clamp(min=1e-12))  # (N,)
    E0 = compute_block_sse(x, s0)  # (N,)
    best_s = s0.clone()
    best_E = E0.clone()

    # Step 2: "noise" blocks, whose total energy does not exceed the baseline error.
    noise_mask = x.pow(2).sum(dim=-1) <= E0

    # Step 3a: lower bound. Below it the largest element alone clips by more than E0.
    s_min = ((amax - E0.sqrt()) / Q_MAX).clamp(min=0.0)  # (N,)

    # Step 3b: upper bound. k* counts how many of the smallest magnitudes can be
    # zeroed before their cumulative squared size exceeds E0.
    sorted_abs, _ = abs_x.sort(dim=-1)  # (N, block_size)
    cumsum_sq = sorted_abs.pow(2).cumsum(dim=-1)  # (N, block_size)
    k_star = (cumsum_sq <= E0.unsqueeze(-1)).sum(dim=-1)  # (N,)
    # Every element affordable to zero out -> also treated as noise.
    noise_mask = noise_mask | (k_star >= block_size)
    k_star_idx = k_star.clamp(max=block_size - 1)
    first_unaffordable = sorted_abs.gather(dim=-1, index=k_star_idx.unsqueeze(-1)).squeeze(-1)
    s_max = first_unaffordable / D_0  # (N,)

    # Step 4: bounded search restricted to the non-noise blocks.
    active = ~noise_mask
    if active.any():
        x_active = x[active]
        abs_active = x_active.abs()
        s_min_a = s_min[active]
        s_max_a = s_max[active]
        best_E_a = best_E[active].clone()
        best_s_a = best_s[active].clone()
        num_active = x_active.shape[0]

        for s_f in candidates.tolist():
            # Which blocks' [s_min, s_max] interval contains this candidate?
            in_range = (s_min_a <= s_f) & (s_max_a >= s_f)
            if not in_range.any():
                continue

            # Fast-fail: clipping error sum(max(|x_i| - 6*s, 0)^2) is a lower
            # bound on the SSE, so it must already beat the current best.
            clip_err = (abs_active - Q_MAX * s_f).clamp(min=0).pow(2).sum(dim=-1)
            evaluate = in_range & (clip_err < best_E_a)
            if not evaluate.any():
                continue

            # Full SSE with the same scale for every active block.
            scale_vec = torch.full(
                (num_active,), s_f, device=x_active.device, dtype=x_active.dtype
            )
            E_s = compute_block_sse(x_active, scale_vec)

            improved = evaluate & (E_s < best_E_a)
            best_E_a[improved] = E_s[improved]
            best_s_a[improved] = s_f

        best_E[active] = best_E_a
        best_s[active] = best_s_a

    # Final quantization with the selected scales.
    column_scales = best_s.unsqueeze(-1)
    quants = fp4_quantize(x, column_scales)

    def _restore(flat):
        # Undo the flatten and move the block axis back to its original position.
        return flat.reshape(*batch_shape, block_size).movedim(-1, dim)

    outputs = [best_s.reshape(*batch_shape), _restore(quants)]
    if return_dequant:
        outputs.append(_restore(fp4_dequantize(quants, column_scales)))
    return tuple(outputs)
def _compute_block_hessian_error(x, s, H_blocks, M_dim, num_col_blocks, block_size):
    """Hessian-weighted FP4 round-trip error ``r^T H r`` for every block.

    The residual ``r = x - dequant(quant(x, s), s)`` of block ``n`` is
    weighted by the Hessian of its column block, where blocks are laid out
    row-major as ``(M_dim, num_col_blocks)``.

    Args:
        x: Block values of shape ``(N, block_size)`` where
            ``N = M_dim * num_col_blocks``.
        s: Per-block scales of shape ``(N,)``.
        H_blocks: Block Hessians of shape
            ``(num_col_blocks, block_size, block_size)``.
        M_dim: Number of rows (M dimension).
        num_col_blocks: Number of column blocks.
        block_size: Block size.

    Returns:
        Tensor of shape ``(N,)`` with per-block Hessian-weighted error.
    """
    scale = s.unsqueeze(-1)
    residual = x - fp4_dequantize(fp4_quantize(x, scale), scale)  # (N, bs)

    # (M, ncb, bs) layout so column block j lines up with H_blocks[j].
    residual_3d = residual.reshape(M_dim, num_col_blocks, block_size)
    weighted = torch.einsum("jab,mjb->mja", H_blocks, residual_3d)  # H_j @ r
    per_block = (residual_3d * weighted).sum(dim=-1)  # (M, ncb)
    return per_block.reshape(-1)  # (N,)
def nvfp4_optimal_hessian(W, dim=-1, return_dequant=False, X=None, H_blocks=None):
    """Hessian-aware optimal NVFP4 scale search.

    Like :func:`nvfp4_optimal`, searches over FP8 E4M3 scale candidates
    using SSE bounds for pruning, but selects the scale minimizing the
    Hessian-weighted error ``(x - sq)^T H (x - sq)`` instead of raw SSE.
    See :doc:`../hessian_scale_search` for the math.

    Blocks are assumed to be ordered row-major as ``(M, num_col_blocks)``
    after the block axis is moved last and the rest is flattened; block
    ``n`` uses Hessian ``n % num_col_blocks``.

    Args:
        W: Input tensor. ``W.shape[dim]`` must be 16 or 32 (the block size).
        dim: Dimension along which to quantize (default: -1).
        return_dequant: If ``True``, also return the dequantized tensor.
        X: Activation tensor of shape ``(T, K)``. H computed as
            ``X_j^T @ X_j`` per column block; columns beyond
            ``(K // block_size) * block_size`` are not used. May live on a
            different device than *W*.
        H_blocks: Pre-computed block Hessians of shape
            ``(num_col_blocks, bs, bs)``. If provided, X is ignored. Cast
            to float32 on the device of *W*.

    Returns:
        ``(scales, quants)`` by default, or ``(scales, quants, dequant)``
        if *return_dequant* is ``True``.

        - **scales**: Per-block optimal FP8 E4M3 scales. Shape is *W.shape*
          with dimension *dim* removed.
        - **quants**: Signed FP4 codebook values. Same shape as *W*.
        - **dequant**: ``quants * scales`` broadcast. Same shape as *W*.

    Raises:
        AssertionError: If neither *X* nor *H_blocks* is given, the block
            size is not 16 or 32, or the number of blocks is not a
            multiple of the number of column blocks.
    """
    assert X is not None or H_blocks is not None, "X or H_blocks required"
    dim = dim % W.ndim
    block_size = W.shape[dim]
    assert block_size in (16, 32)
    x = W.float().movedim(dim, -1)
    batch_shape = x.shape[:-1]
    x = x.reshape(-1, block_size)
    N = x.shape[0]

    # --- Block Hessians ---
    if H_blocks is not None:
        # Match dtype as well as device: einsum/bmm below reject mixed dtypes.
        H = H_blocks.to(device=x.device, dtype=x.dtype)
        num_col_blocks = H.shape[0]
    else:
        K_dim = X.shape[1]
        num_col_blocks = K_dim // block_size
        H = torch.empty(
            num_col_blocks, block_size, block_size, device=x.device, dtype=x.dtype
        )
        batch_t = 8192  # rows of X per accumulation step, bounds peak memory
        for j in range(num_col_blocks):
            cols = slice(j * block_size, (j + 1) * block_size)
            acc = torch.zeros(block_size, block_size, device=x.device, dtype=x.dtype)
            for t0 in range(0, X.shape[0], batch_t):
                # Move each slice to the weight device so X may stay elsewhere.
                Xj = X[t0 : t0 + batch_t, cols].to(device=x.device, dtype=x.dtype)
                acc.addmm_(Xj.T, Xj)
            H[j] = acc

    M_dim = N // num_col_blocks
    assert N == M_dim * num_col_blocks

    all_scales = build_fp8_e4m3_scales(device=x.device)

    # Step 1: Baseline (naive) scale; SSE is kept for the bounds, the
    # Hessian-weighted error is what the search actually minimizes.
    amax = x.abs().amax(dim=-1)  # (N,)
    s_cont = (amax / Q_MAX).clamp(min=1e-12)
    s0 = _fp8_e4m3_snap(s_cont)  # (N,)
    E0_sse = compute_block_sse(x, s0)  # (N,)
    E0_H = _compute_block_hessian_error(x, s0, H, M_dim, num_col_blocks, block_size)

    best_s = s0.clone()
    best_E = E0_H.clone()

    # Step 2: Edge case - noise blocks (sum(x^2) <= E0_sse)
    total_energy = x.pow(2).sum(dim=-1)
    noise_mask = total_energy <= E0_sse

    # Step 3: SSE-based bounds on the scale.
    # NOTE(review): these bounds are derived for plain SSE; using them to
    # prune a Hessian-weighted objective is a heuristic - confirm intent.
    sqrt_E0 = E0_sse.sqrt()
    s_min = ((amax - sqrt_E0) / Q_MAX).clamp(min=0.0)

    sorted_abs, _ = x.abs().sort(dim=-1)
    cumsum_sq = sorted_abs.pow(2).cumsum(dim=-1)
    k_star = (cumsum_sq <= E0_sse.unsqueeze(-1)).sum(dim=-1)
    noise_mask = noise_mask | (k_star >= block_size)
    k_star_idx = k_star.clamp(max=block_size - 1)
    y_k_plus_1 = sorted_abs.gather(dim=-1, index=k_star_idx.unsqueeze(-1)).squeeze(-1)
    s_max = y_k_plus_1 / D_0

    # Step 4: Bounded search, evaluate Hessian error
    active = ~noise_mask
    if active.any():
        best_E_a = best_E[active].clone()
        best_s_a = best_s[active].clone()
        s_min_a = s_min[active]
        s_max_a = s_max[active]
        x_active = x[active]
        abs_active = x_active.abs()

        # Flat block index n maps to column block n % num_col_blocks.
        active_indices = active.nonzero(as_tuple=True)[0]  # (N_active,)
        active_j = active_indices % num_col_blocks
        # Loop-invariant: gather each active block's Hessian once instead of
        # once per candidate scale.
        H_active = H[active_j]  # (N_active, bs, bs)

        # s_val is the 0-dim device tensor used for tensor math; s_f is the
        # same value as a Python float, read back with a single tolist().
        for s_val, s_f in zip(all_scales, all_scales.tolist()):
            # Per-block range check (SSE bounds)
            in_range = (s_min_a <= s_f) & (s_max_a >= s_f)
            if not in_range.any():
                continue

            # Fast-fail on the SSE clipping error.
            # NOTE(review): compares an unweighted SSE against 10x the
            # Hessian-weighted best error; the factor is a heuristic.
            clip_err = (abs_active - Q_MAX * s_f).clamp(min=0).pow(2).sum(dim=-1)
            evaluate = in_range & (clip_err < best_E_a * 10)
            if not evaluate.any():
                continue

            # Hessian-weighted error r^T H_j r for every active block.
            quants_s = fp4_quantize(x_active, s_val)
            r = x_active - fp4_dequantize(quants_s, s_val)  # (N_active, bs)
            Hr = torch.bmm(H_active, r.unsqueeze(-1)).squeeze(-1)  # (N_active, bs)
            E_H_s = (r * Hr).sum(dim=-1)  # (N_active,)

            improved = evaluate & (E_H_s < best_E_a)
            best_E_a[improved] = E_H_s[improved]
            best_s_a[improved] = s_f

        best_E[active] = best_E_a
        best_s[active] = best_s_a

    # Final quantization with optimal scales
    quants = fp4_quantize(x, best_s.unsqueeze(-1))

    result = (
        best_s.reshape(*batch_shape),
        quants.reshape(*batch_shape, block_size).movedim(-1, dim),
    )
    if return_dequant:
        dq = fp4_dequantize(quants, best_s.unsqueeze(-1))
        result = result + (dq.reshape(*batch_shape, block_size).movedim(-1, dim),)
    return result
def nvfp4_admm(W, dim=-1, return_dequant=False, X=None, n_outer=3, n_inner=10):
    """ADMM-based NVFP4 quantization using the input Hessian.

    Starts from the naive per-block scale ``FP8_E4M3(max|x_i| / 6)`` and
    alternates between two steps:

    1. an inner ADMM loop that refines the FP4 codebook values to reduce
       the Hessian-weighted error ``(x - sq)^T H (x - sq)`` per block,
       where ``H = X_block^T @ X_block``;
    2. a closed-form scale update ``s = FP8(x^T H q / (q^T H q))`` for the
       codes ``q`` just found.

    The outer loop stops early once no scale changes. The returned scales
    are the last updated ones, paired with the codes they were fitted to.

    Blocks are assumed to be ordered row-major as ``(M, K // block_size)``
    after the block axis is moved last and the rest is flattened.

    Args:
        W: Input tensor. ``W.shape[dim]`` must be 16 or 32 (the block size).
        dim: Dimension along which to quantize (default: -1).
        return_dequant: If ``True``, also return the dequantized tensor.
        X: Activation tensor of shape ``(T, K)``. Required. May live on a
            different device than *W*.
        n_outer: Outer loop iterations (scale updates). Default 3. With a
            value ``<= 0`` no ADMM pass runs and the result is the plain
            quantization at the initial scale.
        n_inner: Inner ADMM iterations per outer step. Default 10.

    Returns:
        ``(scales, quants)`` by default, or ``(scales, quants, dequant)``
        if *return_dequant* is ``True``.

        - **scales**: Per-block FP8 E4M3 scales. Shape is *W.shape* with
          dimension *dim* removed.
        - **quants**: Signed FP4 codebook values. Same shape as *W*.
        - **dequant**: ``quants * scales`` broadcast. Same shape as *W*.

    Raises:
        AssertionError: If *X* is missing, the block size is not 16 or 32,
            or the number of weight blocks is not a multiple of
            ``K // block_size``.
    """
    assert X is not None, "X (activations) required for ADMM quantization"
    dim = dim % W.ndim
    block_size = W.shape[dim]
    assert block_size in (16, 32)
    x = W.float().movedim(dim, -1)
    batch_shape = x.shape[:-1]
    x = x.reshape(-1, block_size)  # (N, bs)
    N = x.shape[0]

    # --- Block Hessians from X ---
    K_dim = X.shape[1]
    num_col_blocks = K_dim // block_size
    M_dim = N // num_col_blocks
    assert N == M_dim * num_col_blocks, (
        f"Weight blocks ({N}) must equal M * K//bs ({M_dim} * {num_col_blocks})"
    )

    # Compute H_j = X_j^T @ X_j in batches to avoid OOM on large X
    H = torch.empty(
        num_col_blocks, block_size, block_size, device=x.device, dtype=x.dtype
    )
    batch_t = 8192
    for j in range(num_col_blocks):
        cols = slice(j * block_size, (j + 1) * block_size)
        acc = torch.zeros(block_size, block_size, device=x.device, dtype=x.dtype)
        for t0 in range(0, X.shape[0], batch_t):
            # Move each slice to the weight device so X may stay elsewhere.
            Xj = X[t0 : t0 + batch_t, cols].to(device=x.device, dtype=x.dtype)
            acc.addmm_(Xj.T, Xj)
        H[j] = acc

    # --- ADMM constants ---
    # Normalize H by its largest eigenvalue (eigvalsh returns them ascending)
    # so a single rho = 1 is on the same scale for every column block.
    eigvals = torch.linalg.eigvalsh(H)
    max_eig = eigvals[:, -1].clamp(min=1e-6)  # (ncb,)
    H_norm = H / max_eig.reshape(-1, 1, 1)

    rho = torch.ones(num_col_blocks, device=x.device, dtype=x.dtype)
    eye = torch.eye(block_size, device=x.device, dtype=x.dtype).unsqueeze(0)
    M_inv = torch.linalg.inv(H_norm + rho.reshape(-1, 1, 1) * eye)  # (ncb, bs, bs)
    rho_b = rho.reshape(1, -1, 1)  # broadcasts over (M, ncb, bs)

    # --- 3D layout: (M, ncb, bs) ---
    x_3d = x.reshape(M_dim, num_col_blocks, block_size)

    # --- Initialize scale via naive: s = FP8(max|x| / 6) ---
    amax = x.abs().amax(dim=-1)  # (N,)
    s_cont = (amax / Q_MAX).clamp(min=1e-12)
    s0 = _fp8_e4m3_snap(s_cont)
    s = s0.reshape(M_dim, num_col_blocks)

    # --- Outer loop: alternate scale and quants ---
    # Unit scale: inside ADMM the codes live in already-normalized space.
    ones = torch.ones((*x_3d.shape[:-1], 1), device=x.device, dtype=x.dtype)
    q_star = None

    for _ in range(n_outer):
        w = x_3d / s.unsqueeze(-1)  # (M, ncb, bs)

        # z = FP4 projection of w (quantize with s=1)
        z = fp4_quantize(w, ones)
        u = torch.zeros_like(w)

        # Precompute H_norm @ w (use normalized H in ADMM)
        Hw = torch.einsum("jab,mjb->mja", H_norm, w)

        # Inner ADMM loop: continuous solve, projection onto the codebook,
        # then dual update.
        for _ in range(n_inner):
            rhs = Hw + rho_b * (z - u)
            q = torch.einsum("jab,mjb->mja", M_inv, rhs)
            z = fp4_quantize(q + u, ones)
            u = u + q - z

        q_star = z  # discrete FP4 codebook values

        # --- Scale update: s = FP8(x^T H q* / (q*^T H q*)) ---
        Hq = torch.einsum("jab,mjb->mja", H, q_star)
        numer = (x_3d * Hq).sum(dim=-1)  # (M, ncb)
        denom = (q_star * Hq).sum(dim=-1)  # (M, ncb)
        # Keep the previous scale where the closed form is not usable.
        valid = (numer > 0) & (denom > 1e-12)
        s_cont = (numer / denom.clamp(min=1e-12)).clamp(min=1e-12)
        s_snapped = _fp8_e4m3_snap(s_cont)
        s_new = torch.where(valid, s_snapped, s)

        if (s_new == s).all():
            break
        s = s_new

    if q_star is None:
        # n_outer <= 0: no ADMM pass ran, so fall back to plain quantization
        # at the initial scale instead of failing on a None result.
        q_star = fp4_quantize(x_3d, s.unsqueeze(-1))

    # --- Final output ---
    best_s = s.reshape(N)
    quants = q_star.reshape(N, block_size)

    result = (
        best_s.reshape(*batch_shape),
        quants.reshape(*batch_shape, block_size).movedim(-1, dim),
    )
    if return_dequant:
        dq = fp4_dequantize(quants, best_s.unsqueeze(-1))
        result = result + (dq.reshape(*batch_shape, block_size).movedim(-1, dim),)
    return result
def nvfp4_dequantize(scales, quants, dim=-1):
    """Reconstruct values from NVFP4 codes: ``dequant = quants * scales``.

    The block axis that is missing from *scales* is re-inserted at *dim*
    as a size-1 axis so the scales broadcast across each block.

    Args:
        scales: Per-block scales. Shape is the original *W.shape* with
            dimension *dim* removed.
        quants: Signed FP4 codebook values. Same shape as the original *W*.
        dim: Block dimension in *quants* (default: -1).

    Returns:
        Dequantized tensor with the same shape as *quants*.
    """
    expanded_scales = scales.unsqueeze(dim)
    return quants * expanded_scales