Source code for qwantize.nvfp4.kernels

import torch
import triton
import triton.language as tl

from .reference import build_fp8_e4m3_scales
from ..fp4 import fp4_sse_block, fp4_hessian_block, fp4_dequant_block
from ..fp8 import fp8_e4m3_snap_asm

Q_MAX = 6.0
D_0 = 0.25


# ---------------------------------------------------------------------------
# Naive kernel
# ---------------------------------------------------------------------------


@triton.jit
def nvfp4_naive_kernel(
    x_ptr,
    out_ptr,
    out_scale_ptr,
    total_blocks,
    block_stride,
    element_stride,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(0)
    if pid >= total_blocks:
        return

    offs = tl.arange(0, BLOCK_SIZE)
    x = tl.load(x_ptr + pid * block_stride + offs * element_stride).to(tl.float32)
    x_abs = tl.abs(x)
    amax = tl.max(x_abs, axis=0)

    # Snap amax/6 to FP8 E4M3 using inline ASM bitwise ops
    s_cont = amax / 6.0
    s_cont = tl.maximum(s_cont, 1e-12)
    # fp8_e4m3_snap_asm expects a tensor, broadcast scalar to size-1
    s_cont_vec = s_cont + tl.zeros([1], dtype=tl.float32)
    s0_vec = fp8_e4m3_snap_asm(s_cont_vec)
    s0 = tl.sum(s0_vec, axis=0)  # extract scalar

    # Dequantize with naive scale
    dq = fp4_dequant_block(x, s0, BLOCK_SIZE)

    tl.store(out_ptr + pid * BLOCK_SIZE + offs, dq)
    tl.store(out_scale_ptr + pid, s0)


# ---------------------------------------------------------------------------
# Optimal kernel
# ---------------------------------------------------------------------------


@triton.jit
def nvfp4_optimal_kernel(
    x_ptr,
    scale_table_ptr,
    out_ptr,
    out_scale_ptr,
    total_blocks,
    block_stride,
    element_stride,
    BLOCK_SIZE: tl.constexpr,
    NUM_SCALES: tl.constexpr,
):
    pid = tl.program_id(0)
    if pid >= total_blocks:
        return

    # Step 1: Load block
    offs = tl.arange(0, BLOCK_SIZE)
    x = tl.load(x_ptr + pid * block_stride + offs * element_stride).to(tl.float32)
    x_abs = tl.abs(x)
    amax = tl.max(x_abs, axis=0)

    # Step 2: Baseline scale via inline ASM FP8 snap
    s_cont = amax / 6.0
    s_cont = tl.maximum(s_cont, 1e-12)
    s_cont_vec = s_cont + tl.zeros([1], dtype=tl.float32)
    s0_vec = fp8_e4m3_snap_asm(s_cont_vec)
    s0 = tl.sum(s0_vec, axis=0)
    # Ensure s0 > 0
    s0 = tl.maximum(s0, 1e-12)

    # Step 3: Compute baseline error E0
    E0 = fp4_sse_block(x, x_abs, s0, BLOCK_SIZE)

    best_s = s0
    best_E = E0

    # Step 4: Edge case -- noise block
    total_energy = tl.sum(x * x, axis=0)
    is_noise = total_energy <= E0

    if is_noise == 0:
        # Step 5: Lower bound
        sqrt_E0 = tl.sqrt(E0)
        s_min = tl.maximum(0.0, (amax - sqrt_E0) / 6.0)

        # Step 6: Upper bound -- sort, cumsum, find k*
        sorted_abs = tl.sort(x_abs)
        sorted_sq = sorted_abs * sorted_abs
        cumsum_sq = tl.cumsum(sorted_sq, axis=0)

        # k_star = number of elements whose cumsum of squares <= E0
        k_mask = (cumsum_sq <= E0).to(tl.int32)
        k_star = tl.sum(k_mask, axis=0)

        # Extract sorted_abs[min(k_star, BLOCK_SIZE-1)]
        k_idx = tl.minimum(k_star, BLOCK_SIZE - 1)
        # Use masked reduction to pick element at index k_idx
        y_k = tl.sum(tl.where(offs == k_idx, sorted_abs, 0.0), axis=0)
        s_max = y_k / 0.25

        # Step 7: Bounded search over FP8 E4M3 scale candidates
        for i in range(NUM_SCALES):
            s_cand = tl.load(scale_table_ptr + i)

            in_range = (s_cand >= s_min) & (s_cand <= s_max)
            if in_range:
                # Fast-fail: clipping error H(s)
                clip_excess = tl.maximum(x_abs - 6.0 * s_cand, 0.0)
                H_s = tl.sum(clip_excess * clip_excess, axis=0)

                if H_s < best_E:
                    # Full SSE via inline ASM FP4 dequant
                    E_s = fp4_sse_block(x, x_abs, s_cand, BLOCK_SIZE)
                    if E_s < best_E:
                        best_E = E_s
                        best_s = s_cand

    # Step 8: Final dequantization with optimal scale
    dq = fp4_dequant_block(x, best_s, BLOCK_SIZE)

    tl.store(out_ptr + pid * BLOCK_SIZE + offs, dq)
    tl.store(out_scale_ptr + pid, best_s)


# ---------------------------------------------------------------------------
# Hessian-aware optimal kernel
# ---------------------------------------------------------------------------


@triton.jit
def nvfp4_optimal_hessian_kernel(
    x_ptr,
    scale_table_ptr,
    H_ptr,
    out_ptr,
    out_scale_ptr,
    total_blocks,
    num_col_blocks,
    block_stride,
    element_stride,
    BLOCK_SIZE: tl.constexpr,
    NUM_SCALES: tl.constexpr,
):
    pid = tl.program_id(0)
    if pid >= total_blocks:
        return

    # Column-block index for Hessian lookup
    j = pid % num_col_blocks

    # Step 1: Load block
    offs = tl.arange(0, BLOCK_SIZE)
    x = tl.load(x_ptr + pid * block_stride + offs * element_stride).to(tl.float32)
    x_abs = tl.abs(x)
    amax = tl.max(x_abs, axis=0)

    # Load Hessian block H[j]: (BLOCK_SIZE, BLOCK_SIZE)
    row_offs = tl.arange(0, BLOCK_SIZE)
    col_offs = tl.arange(0, BLOCK_SIZE)
    H_base = j * BLOCK_SIZE * BLOCK_SIZE
    H_block = tl.load(
        H_ptr + H_base + row_offs[:, None] * BLOCK_SIZE + col_offs[None, :]
    )

    # Step 2: Baseline scale via inline ASM FP8 snap
    s_cont = amax / 6.0
    s_cont = tl.maximum(s_cont, 1e-12)
    s_cont_vec = s_cont + tl.zeros([1], dtype=tl.float32)
    s0_vec = fp8_e4m3_snap_asm(s_cont_vec)
    s0 = tl.sum(s0_vec, axis=0)
    s0 = tl.maximum(s0, 1e-12)

    # Step 3: Baseline errors -- SSE for bounding, Hessian for selection
    E0_sse = fp4_sse_block(x, x_abs, s0, BLOCK_SIZE)
    E0_H = fp4_hessian_block(x, x_abs, s0, H_block, BLOCK_SIZE)

    best_s = s0
    best_E = E0_H

    # Step 4: Edge case -- noise block
    total_energy = tl.sum(x * x, axis=0)
    is_noise = total_energy <= E0_sse

    if is_noise == 0:
        # Step 5: Lower bound (SSE-based, still valid for pruning)
        sqrt_E0 = tl.sqrt(E0_sse)
        s_min = tl.maximum(0.0, (amax - sqrt_E0) / 6.0)

        # Step 6: Upper bound -- sort, cumsum, find k*
        sorted_abs = tl.sort(x_abs)
        sorted_sq = sorted_abs * sorted_abs
        cumsum_sq = tl.cumsum(sorted_sq, axis=0)

        k_mask = (cumsum_sq <= E0_sse).to(tl.int32)
        k_star = tl.sum(k_mask, axis=0)

        k_idx = tl.minimum(k_star, BLOCK_SIZE - 1)
        y_k = tl.sum(tl.where(offs == k_idx, sorted_abs, 0.0), axis=0)
        s_max = y_k / 0.25

        # Step 7: Bounded search -- evaluate Hessian error
        for i in range(NUM_SCALES):
            s_cand = tl.load(scale_table_ptr + i)

            in_range = (s_cand >= s_min) & (s_cand <= s_max)
            if in_range:
                # Fast-fail: SSE clipping error (cheap necessary condition)
                clip_excess = tl.maximum(x_abs - 6.0 * s_cand, 0.0)
                H_s = tl.sum(clip_excess * clip_excess, axis=0)

                # Only skip if clipping error alone exceeds baseline SSE
                if H_s < E0_sse:
                    # Full Hessian-weighted error
                    E_H = fp4_hessian_block(x, x_abs, s_cand, H_block, BLOCK_SIZE)
                    if E_H < best_E:
                        best_E = E_H
                        best_s = s_cand

    # Step 8: Final dequantization with optimal scale
    dq = fp4_dequant_block(x, best_s, BLOCK_SIZE)

    tl.store(out_ptr + pid * BLOCK_SIZE + offs, dq)
    tl.store(out_scale_ptr + pid, best_s)


# ---------------------------------------------------------------------------
# Python wrappers
# ---------------------------------------------------------------------------


[docs] def nvfp4_naive_triton(W, dim=-1, return_dequant=False): """Naive NVFP4 quantization using Triton kernel with inline PTX ASM. GPU-accelerated version of :func:`~qwantize.nvfp4.reference.nvfp4_naive`. Args: W: Input tensor. ``W.shape[dim]`` must be 16 or 32 (the block size). dim: Dimension along which to quantize (default: -1). return_dequant: If ``True``, also return the dequantized tensor. Returns: ``(scales, quants)`` by default, or ``(scales, quants, dequant)`` if *return_dequant* is ``True``. See :func:`~qwantize.nvfp4.reference.nvfp4_naive` for shape details. """ dim = dim % W.ndim block_size = W.shape[dim] assert block_size in (16, 32) x = W.float() batch_shape = x.shape[:dim] + x.shape[dim + 1:] num_blocks = 1 for s in batch_shape: num_blocks *= s x_contig = x.movedim(dim, -1).reshape(-1, block_size) block_stride = x_contig.stride(0) element_stride = x_contig.stride(1) out = torch.empty(num_blocks * block_size, device=W.device, dtype=torch.float32) out_scales = torch.empty(num_blocks, device=W.device, dtype=torch.float32) grid = (num_blocks,) nvfp4_naive_kernel[grid]( x_contig, out, out_scales, num_blocks, block_stride, element_stride, BLOCK_SIZE=block_size, ) quants = out.reshape(-1, block_size) / out_scales.unsqueeze(-1) result = ( out_scales.reshape(*batch_shape), quants.reshape(*batch_shape, block_size).movedim(-1, dim), ) if return_dequant: result = result + (out.reshape(*batch_shape, block_size).movedim(-1, dim),) return result
[docs] def nvfp4_optimal_triton(W, dim=-1, return_dequant=False): """Optimal NVFP4 quantization using Triton kernel with inline PTX ASM. GPU-accelerated version of :func:`~qwantize.nvfp4.reference.nvfp4_optimal`. Args: W: Input tensor. ``W.shape[dim]`` must be 16 or 32 (the block size). dim: Dimension along which to quantize (default: -1). return_dequant: If ``True``, also return the dequantized tensor. Returns: ``(scales, quants)`` by default, or ``(scales, quants, dequant)`` if *return_dequant* is ``True``. See :func:`~qwantize.nvfp4.reference.nvfp4_optimal` for shape details. """ dim = dim % W.ndim block_size = W.shape[dim] assert block_size in (16, 32) x = W.float() batch_shape = x.shape[:dim] + x.shape[dim + 1:] num_blocks = 1 for s in batch_shape: num_blocks *= s x_contig = x.movedim(dim, -1).reshape(-1, block_size) block_stride = x_contig.stride(0) element_stride = x_contig.stride(1) scale_table = build_fp8_e4m3_scales(device=W.device) num_scales = scale_table.shape[0] out = torch.empty(num_blocks * block_size, device=W.device, dtype=torch.float32) out_scales = torch.empty(num_blocks, device=W.device, dtype=torch.float32) grid = (num_blocks,) nvfp4_optimal_kernel[grid]( x_contig, scale_table, out, out_scales, num_blocks, block_stride, element_stride, BLOCK_SIZE=block_size, NUM_SCALES=num_scales, ) quants = out.reshape(-1, block_size) / out_scales.unsqueeze(-1) result = ( out_scales.reshape(*batch_shape), quants.reshape(*batch_shape, block_size).movedim(-1, dim), ) if return_dequant: result = result + (out.reshape(*batch_shape, block_size).movedim(-1, dim),) return result
[docs] def nvfp4_optimal_hessian_triton(W, dim=-1, return_dequant=False, X=None): """Hessian-aware optimal NVFP4 quantization using Triton kernel. GPU-accelerated version of :func:`~qwantize.nvfp4.reference.nvfp4_optimal_hessian`. Args: W: Input tensor. ``W.shape[dim]`` must be 16 or 32 (the block size). dim: Dimension along which to quantize (default: -1). return_dequant: If ``True``, also return the dequantized tensor. X: Activation tensor of shape ``(T, K)``. Required. Returns: ``(scales, quants)`` by default, or ``(scales, quants, dequant)`` if *return_dequant* is ``True``. See :func:`~qwantize.nvfp4.reference.nvfp4_optimal_hessian` for shape details. """ assert X is not None, "X (activations) required for Hessian-aware scale search" dim = dim % W.ndim block_size = W.shape[dim] assert block_size in (16, 32) x = W.float() batch_shape = x.shape[:dim] + x.shape[dim + 1:] num_blocks = 1 for s in batch_shape: num_blocks *= s x_contig = x.movedim(dim, -1).reshape(-1, block_size) block_stride = x_contig.stride(0) element_stride = x_contig.stride(1) # Compute block Hessians: H[j] = X_j^T @ X_j K_dim = X.shape[1] num_col_blocks = K_dim // block_size H = torch.empty(num_col_blocks, block_size, block_size, device=W.device, dtype=torch.float32) batch_t = 8192 for j in range(num_col_blocks): acc = torch.zeros(block_size, block_size, device=W.device, dtype=torch.float32) for t0 in range(0, X.shape[0], batch_t): Xj = X[t0 : t0 + batch_t, j * block_size : (j + 1) * block_size].float() acc.addmm_(Xj.T, Xj) H[j] = acc H = H.contiguous() scale_table = build_fp8_e4m3_scales(device=W.device) num_scales = scale_table.shape[0] out = torch.empty(num_blocks * block_size, device=W.device, dtype=torch.float32) out_scales = torch.empty(num_blocks, device=W.device, dtype=torch.float32) grid = (num_blocks,) nvfp4_optimal_hessian_kernel[grid]( x_contig, scale_table, H, out, out_scales, num_blocks, num_col_blocks, block_stride, element_stride, BLOCK_SIZE=block_size, NUM_SCALES=num_scales, ) quants = out.reshape(-1, block_size) / out_scales.unsqueeze(-1) result = ( out_scales.reshape(*batch_shape), quants.reshape(*batch_shape, block_size).movedim(-1, dim), ) if return_dequant: result = result + (out.reshape(*batch_shape, block_size).movedim(-1, dim),) return result