"""
MXFP4 quantization with UE8M0 scales -- Triton kernel with inline PTX ASM.
UE8M0 scales are pure powers of 2. The optimal scale is always either the
naive scale s0 or 2*s0 (one step up). So the "search" is just: compute
SSE for both, pick the winner.
Uses standard FP4 E2M1 codebook {0, 0.5, 1, 1.5, 2, 3, 4, 6}.
"""
import torch
import triton
import triton.language as tl
from ..fp4 import fp4_sse_block, fp4_dequant_block
# Largest magnitude in the FP4 E2M1 codebook {0, 0.5, 1, 1.5, 2, 3, 4, 6}.
Q_MAX = 6.0 # max codebook value
# 0.25 is the midpoint between the two smallest codebook magnitudes (0 and
# 0.5); presumably the |x| / scale threshold below which a value rounds to
# zero -- TODO confirm against the helpers that consume it.
# NOTE(review): neither constant is referenced by the code visible in this file.
D_0 = 0.25 # decision boundary for zero
# ---------------------------------------------------------------------------
# Inline ASM helpers
# ---------------------------------------------------------------------------
@triton.jit
def ue8m0_snap_asm(val):
    """
    Snap a positive float32 to the naive UE8M0 scale 2^(floor(log2(val)) - 2).

    The computation is purely bitwise on the IEEE-754 float32 encoding:

    1. ``bfe`` extracts the 8-bit biased exponent (bits 23..30); the sign bit
       and the mantissa are ignored.
    2. 2 is subtracted from the biased exponent.
    3. The new exponent is shifted back to bit 23, leaving sign and mantissa
       at zero, so the result is an exact power of two.

    For a normal input, ``biased_exp = floor(log2(val)) + 127``, hence

        result = 2^(biased_exp - 2 - 127) = 2^(biased_exp - 129)
               = 2^(floor(log2(val)) - 2)

    With this scale, ``val / result`` lies in [4, 8).

    Edge cases, as implemented by the PTX below:

    * If ``biased_exp - 2 < 1`` (very small or subnormal input) the result is
      clamped to 0x00800000, i.e. 2^(-126) (biased exponent 1, mantissa 0).
    * A biased exponent of 255 (inf / NaN) is not special-cased; it yields
      biased exponent 253, i.e. 2^126.

    Args:
        val: float32 values passed to the elementwise inline-asm call;
            expected to be positive.

    Returns:
        float32 value(s) holding the snapped power-of-two scale.
    """
    # NOTE: the PTX text is a runtime string handed to the compiler; its
    # contents (including the `//` comments inside it) are left untouched.
    result = tl.inline_asm_elementwise(
        asm="""
{
.reg .b32 bits, exp32, new_exp, res;
.reg .pred p_uf;
mov.b32 bits, $1;
// Extract biased exponent (bits 23..30)
bfe.u32 exp32, bits, 23, 8;
// new_exp = exp32 - 2 (we want 2^(exp32-127-2) = float with exp = exp32-2)
add.s32 new_exp, exp32, -2;
// Underflow check: if new_exp < 1, clamp to 0 (smallest positive is 2^(1-127))
setp.lt.s32 p_uf, new_exp, 1;
// Reconstruct: new_exp << 23, mantissa = 0
shl.b32 res, new_exp, 23;
// Underflow -> return smallest UE8M0 scale: 2^(1-127) = 2^(-126)
// which is 0x00800000 as float32 (biased_exp=1, mant=0)
@p_uf mov.b32 res, 0x00800000;
mov.b32 $0, res;
}
""",
        # One output register ($0) and one input register ($1).
        constraints="=r,r",
        args=[val],
        dtype=tl.float32,
        is_pure=True,
        pack=1,
    )
    return result
# ---------------------------------------------------------------------------
# Naive kernel
# ---------------------------------------------------------------------------
@triton.jit
def mxfp4_naive_kernel(
    x_ptr,
    out_ptr,
    out_scale_ptr,
    total_blocks,
    block_stride,
    element_stride,
    BLOCK_SIZE: tl.constexpr,
):
    """Process one block per program using the naive UE8M0 scale.

    Program ``i`` loads ``BLOCK_SIZE`` elements starting at
    ``x_ptr + i * block_stride`` with step ``element_stride``, derives the
    scale from the block's absolute maximum via :func:`ue8m0_snap_asm`, and
    stores

    * the output of ``fp4_dequant_block`` (presumably the dequantized
      values -- defined outside this file) at ``out_ptr + i * BLOCK_SIZE``,
    * the scale at ``out_scale_ptr + i``.

    Programs with an index >= ``total_blocks`` exit without touching memory.
    """
    block_id = tl.program_id(0)
    if block_id >= total_blocks:
        return

    lanes = tl.arange(0, BLOCK_SIZE)
    src_ptrs = x_ptr + block_id * block_stride + lanes * element_stride
    vals = tl.load(src_ptrs).to(tl.float32)

    # Peak magnitude of the block, floored at 1e-30 before the exponent
    # extraction so an all-zero block still gets a well-defined scale.
    peak = tl.maximum(tl.max(tl.abs(vals), axis=0), 1e-30)

    # The inline-asm helper is elementwise: lift the scalar into a
    # one-element tensor, snap it, then reduce back to a scalar.
    peak_1 = peak + tl.zeros([1], dtype=tl.float32)
    scale = tl.sum(ue8m0_snap_asm(peak_1), axis=0)

    tl.store(
        out_ptr + block_id * BLOCK_SIZE + lanes,
        fp4_dequant_block(vals, scale, BLOCK_SIZE),
    )
    tl.store(out_scale_ptr + block_id, scale)
# ---------------------------------------------------------------------------
# Optimal kernel -- just compare s0 vs 2*s0
# ---------------------------------------------------------------------------
@triton.jit
def mxfp4_optimal_kernel(
    x_ptr,
    out_ptr,
    out_scale_ptr,
    total_blocks,
    block_stride,
    element_stride,
    BLOCK_SIZE: tl.constexpr,
):
    """Process one block per program, choosing between scales s0 and 2*s0.

    Program ``i`` loads ``BLOCK_SIZE`` elements starting at
    ``x_ptr + i * block_stride`` with step ``element_stride``. It computes
    the naive UE8M0 scale ``s0`` from the block's absolute maximum via
    :func:`ue8m0_snap_asm`, evaluates ``fp4_sse_block`` for ``s0`` and for
    ``s1 = 2 * s0``, and keeps ``s1`` only when its error is strictly lower.

    Outputs:
        * ``out_ptr + i * BLOCK_SIZE``: output of ``fp4_dequant_block`` for
          the chosen scale.
        * ``out_scale_ptr + i``: the chosen scale.

    Programs with an index >= ``total_blocks`` exit without touching memory.
    """
    pid = tl.program_id(0)
    if pid >= total_blocks:
        return

    offs = tl.arange(0, BLOCK_SIZE)
    x = tl.load(x_ptr + pid * block_stride + offs * element_stride).to(tl.float32)
    x_abs = tl.abs(x)
    amax = tl.max(x_abs, axis=0)

    # Naive scale via inline ASM. The helper is elementwise, so the scalar
    # is lifted into a one-element tensor and reduced back afterwards.
    amax_safe = tl.maximum(amax, 1e-30)
    amax_vec = amax_safe + tl.zeros([1], dtype=tl.float32)
    s0 = tl.sum(ue8m0_snap_asm(amax_vec), axis=0)
    # NOTE: s0 is deliberately NOT clamped again. The snap already returns a
    # strictly positive power of two (>= 2^-126). A further
    # ``maximum(s0, 1e-30)`` would replace 2^-102 .. 2^-100 (reached for
    # all-zero or near-zero blocks) with 1e-30, which is not a power of two
    # and therefore not a valid UE8M0 scale, and would make this kernel
    # disagree with the naive kernel on such blocks.

    # Error at the naive scale.
    E0 = fp4_sse_block(x, x_abs, s0, BLOCK_SIZE)

    # Candidate: one UE8M0 step up.
    s1 = s0 * 2.0
    E1 = fp4_sse_block(x, x_abs, s1, BLOCK_SIZE)

    # Ties keep the smaller scale s0.
    best_s = s0
    if E1 < E0:
        best_s = s1

    dq = fp4_dequant_block(x, best_s, BLOCK_SIZE)
    tl.store(out_ptr + pid * BLOCK_SIZE + offs, dq)
    tl.store(out_scale_ptr + pid, best_s)
# ---------------------------------------------------------------------------
# Python wrappers
# ---------------------------------------------------------------------------
[docs]
def mxfp4_naive_triton(W, dim=-1, return_dequant=False):
    """Naive MXFP4 quantization using Triton kernel with inline PTX ASM.

    GPU-accelerated version of :func:`~qwantize.mxfp4.reference.mxfp4_naive`.

    Every slice of ``W`` along ``dim`` is treated as one block. The kernel
    writes dequantized values and one scale per block; the quantized codes
    are recovered on the host as ``dequant / scale``.

    Args:
        W: Input tensor. ``W.shape[dim]`` must be 16 or 32 (the block size).
            A 1-D tensor (a single block) and tensors with zero blocks are
            accepted.
        dim: Dimension along which to quantize (default: -1).
        return_dequant: If ``True``, also return the dequantized tensor.

    Returns:
        ``(scales, quants)`` by default, or ``(scales, quants, dequant)``
        if *return_dequant* is ``True``. ``scales`` has the shape of ``W``
        with ``dim`` removed; ``quants`` and ``dequant`` have the shape of
        ``W``. See :func:`~qwantize.mxfp4.reference.mxfp4_naive` for details.

    Raises:
        AssertionError: If ``W.shape[dim]`` is not 16 or 32.
    """
    dim = dim % W.ndim
    block_size = W.shape[dim]
    assert block_size in (16, 32), "W.shape[dim] must be 16 or 32"

    # Bring the block axis last and flatten everything else into rows.
    x = W.float().movedim(dim, -1)
    batch_shape = tuple(x.shape[:-1])
    # Explicit row count (instead of -1) keeps reshape valid for empty input.
    num_blocks = x.numel() // block_size
    x_blocks = x.reshape(num_blocks, block_size)

    out = torch.empty(num_blocks * block_size, device=W.device, dtype=torch.float32)
    out_scales = torch.empty(num_blocks, device=W.device, dtype=torch.float32)

    # Nothing to launch when there are no blocks.
    if num_blocks > 0:
        mxfp4_naive_kernel[(num_blocks,)](
            x_blocks, out, out_scales, num_blocks,
            x_blocks.stride(0), x_blocks.stride(1), BLOCK_SIZE=block_size,
        )

    dequant = out.reshape(num_blocks, block_size)
    quants = dequant / out_scales.unsqueeze(-1)

    # Pass batch_shape as a tuple: unpacking an empty shape (1-D input)
    # would call reshape() without arguments, which raises.
    result = (
        out_scales.reshape(batch_shape),
        quants.reshape(*batch_shape, block_size).movedim(-1, dim),
    )
    if return_dequant:
        result = result + (dequant.reshape(*batch_shape, block_size).movedim(-1, dim),)
    return result
[docs]
def mxfp4_optimal_triton(W, dim=-1, return_dequant=False):
    """Optimal MXFP4 quantization using Triton kernel with inline PTX ASM.

    GPU-accelerated version of :func:`~qwantize.mxfp4.reference.mxfp4_optimal`.
    Compares naive scale ``s0`` with ``2*s0`` (one UE8M0 step up) and picks
    whichever has lower SSE.

    Every slice of ``W`` along ``dim`` is treated as one block. The kernel
    writes dequantized values and one scale per block; the quantized codes
    are recovered on the host as ``dequant / scale``.

    Args:
        W: Input tensor. ``W.shape[dim]`` must be 16 or 32 (the block size).
            A 1-D tensor (a single block) and tensors with zero blocks are
            accepted.
        dim: Dimension along which to quantize (default: -1).
        return_dequant: If ``True``, also return the dequantized tensor.

    Returns:
        ``(scales, quants)`` by default, or ``(scales, quants, dequant)``
        if *return_dequant* is ``True``. ``scales`` has the shape of ``W``
        with ``dim`` removed; ``quants`` and ``dequant`` have the shape of
        ``W``. See :func:`~qwantize.mxfp4.reference.mxfp4_optimal` for
        details.

    Raises:
        AssertionError: If ``W.shape[dim]`` is not 16 or 32.
    """
    dim = dim % W.ndim
    block_size = W.shape[dim]
    assert block_size in (16, 32), "W.shape[dim] must be 16 or 32"

    # Bring the block axis last and flatten everything else into rows.
    x = W.float().movedim(dim, -1)
    batch_shape = tuple(x.shape[:-1])
    # Explicit row count (instead of -1) keeps reshape valid for empty input.
    num_blocks = x.numel() // block_size
    x_blocks = x.reshape(num_blocks, block_size)

    out = torch.empty(num_blocks * block_size, device=W.device, dtype=torch.float32)
    out_scales = torch.empty(num_blocks, device=W.device, dtype=torch.float32)

    # Nothing to launch when there are no blocks.
    if num_blocks > 0:
        mxfp4_optimal_kernel[(num_blocks,)](
            x_blocks, out, out_scales, num_blocks,
            x_blocks.stride(0), x_blocks.stride(1), BLOCK_SIZE=block_size,
        )

    dequant = out.reshape(num_blocks, block_size)
    quants = dequant / out_scales.unsqueeze(-1)

    # Pass batch_shape as a tuple: unpacking an empty shape (1-D input)
    # would call reshape() without arguments, which raises.
    result = (
        out_scales.reshape(batch_shape),
        quants.reshape(*batch_shape, block_size).movedim(-1, dim),
    )
    if return_dequant:
        result = result + (dequant.reshape(*batch_shape, block_size).movedim(-1, dim),)
    return result
# ---------------------------------------------------------------------------
# Torch reference implementations (for benchmarking)
# ---------------------------------------------------------------------------
[docs]
def mxfp4_naive_torch(W, dim=-1, return_dequant=False):
    """Naive MXFP4 quantization using pure PyTorch operations.

    Functionally identical to :func:`~qwantize.mxfp4.reference.mxfp4_naive`
    but uses vectorized ``argmin`` over the full signed codebook instead of
    ``bucketize``.

    Per block, the scale is ``2^(floor(log2(amax)) - 2)`` with the biased
    exponent clamped to ``[1, 254]``; each element is then mapped to the
    codebook value whose scaled version is closest to it.

    Args:
        W: Input tensor. ``W.shape[dim]`` must be 16 or 32 (the block size).
            A 1-D tensor (a single block) and tensors with zero blocks are
            accepted.
        dim: Dimension along which to quantize (default: -1).
        return_dequant: If ``True``, also return the dequantized tensor.

    Returns:
        ``(scales, quants)`` by default, or ``(scales, quants, dequant)``
        if *return_dequant* is ``True``. ``scales`` has the shape of ``W``
        with ``dim`` removed; ``quants`` and ``dequant`` have the shape of
        ``W``. See :func:`~qwantize.mxfp4.reference.mxfp4_naive` for details.

    Raises:
        AssertionError: If ``W.shape[dim]`` is not 16 or 32.
    """
    dim = dim % W.ndim
    block_size = W.shape[dim]
    assert block_size in (16, 32), "W.shape[dim] must be 16 or 32"

    # Signed FP4 E2M1 codebook: 8 non-negative entries, then their negations.
    codebook = torch.tensor(
        [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
         0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
        device=W.device, dtype=torch.float32,
    )

    x = W.float().movedim(dim, -1)
    batch_shape = tuple(x.shape[:-1])
    # Explicit row count (instead of -1) keeps reshape valid for empty input.
    x = x.reshape(x.numel() // block_size, block_size)  # (N, block_size)

    # UE8M0 scale: biased exponent floor(log2(amax)) - 2 + 127, kept in [1, 254].
    scale_exponent = x.abs().amax(dim=-1).clamp(min=1e-30).log2().add(-2 + 127).floor()
    scale_exponent = scale_exponent.clamp(min=1, max=254)
    scale = torch.pow(2.0, scale_exponent - 127.0)  # (N,)

    # Quantize: nearest scaled codebook entry; argmin keeps the first on ties.
    possible = (scale.unsqueeze(-1) * codebook.view(1, 16)).unsqueeze(-2)  # (N, 1, 16)
    deltas = x.unsqueeze(-1) - possible  # (N, bs, 16)
    quants = codebook[deltas.abs().argmin(dim=-1)]  # (N, bs)

    # Pass batch_shape as a tuple: unpacking an empty shape (1-D input)
    # would call reshape() without arguments, which raises.
    result = (
        scale.reshape(batch_shape),
        quants.reshape(*batch_shape, block_size).movedim(-1, dim),
    )
    if return_dequant:
        dequants = scale.unsqueeze(-1) * quants
        result = result + (dequants.reshape(*batch_shape, block_size).movedim(-1, dim),)
    return result
[docs]
def mxfp4_optimal_torch(W, dim=-1, return_dequant=False):
    """Optimal MXFP4 quantization using pure PyTorch operations.

    Tries naive scale ``s0`` and ``2*s0``, picks whichever has lower SSE.
    On a tie the naive scale ``s0`` is kept.

    Args:
        W: Input tensor. ``W.shape[dim]`` must be 16 or 32 (the block size).
            A 1-D tensor (a single block) and tensors with zero blocks are
            accepted.
        dim: Dimension along which to quantize (default: -1).
        return_dequant: If ``True``, also return the dequantized tensor.

    Returns:
        ``(scales, quants)`` by default, or ``(scales, quants, dequant)``
        if *return_dequant* is ``True``. ``scales`` has the shape of ``W``
        with ``dim`` removed; ``quants`` and ``dequant`` have the shape of
        ``W``. See :func:`~qwantize.mxfp4.reference.mxfp4_optimal` for
        details.

    Raises:
        AssertionError: If ``W.shape[dim]`` is not 16 or 32.
    """
    dim = dim % W.ndim
    block_size = W.shape[dim]
    assert block_size in (16, 32), "W.shape[dim] must be 16 or 32"

    # Signed FP4 E2M1 codebook: 8 non-negative entries, then their negations.
    codebook = torch.tensor(
        [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
         0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
        device=W.device, dtype=torch.float32,
    )

    x = W.float().movedim(dim, -1)
    batch_shape = tuple(x.shape[:-1])
    # Explicit row count (instead of -1) keeps reshape valid for empty input.
    x = x.reshape(x.numel() // block_size, block_size)  # (N, block_size)

    # s0: naive UE8M0 scale, biased exponent kept in [1, 254].
    scale_exponent = x.abs().amax(dim=-1).clamp(min=1e-30).log2().add(-2 + 127).floor()
    scale_exponent = scale_exponent.clamp(min=1, max=254)
    s0 = torch.pow(2.0, scale_exponent - 127.0)  # (N,)

    # s1: one UE8M0 step up, saturating at the largest exponent.
    s1_exponent = (scale_exponent + 1).clamp(max=254)
    s1 = torch.pow(2.0, s1_exponent - 127.0)

    def quant_dequant_sse(blocks, scale):
        """Return (dequantized values, codes, per-block SSE) for *scale*."""
        possible = (scale.unsqueeze(-1) * codebook.view(1, 16)).unsqueeze(-2)  # (N, 1, 16)
        deltas = blocks.unsqueeze(-1) - possible  # (N, bs, 16)
        q = codebook[deltas.abs().argmin(dim=-1)]
        dequants = scale.unsqueeze(-1) * q
        sse = (blocks - dequants).pow(2).sum(dim=-1)
        return dequants, q, sse

    dq0, q0, sse0 = quant_dequant_sse(x, s0)
    dq1, q1, sse1 = quant_dequant_sse(x, s1)

    # Pick the better scale per block; strict "<" keeps s0 on ties.
    s1_wins = sse1 < sse0  # (N,)
    use_s1 = s1_wins.unsqueeze(-1)  # (N, 1), broadcasts over the block
    quants = torch.where(use_s1, q1, q0)
    best_scale = torch.where(s1_wins, s1, s0)

    # Pass batch_shape as a tuple: unpacking an empty shape (1-D input)
    # would call reshape() without arguments, which raises.
    result = (
        best_scale.reshape(batch_shape),
        quants.reshape(*batch_shape, block_size).movedim(-1, dim),
    )
    if return_dequant:
        dequants = torch.where(use_s1, dq1, dq0)
        result = result + (dequants.reshape(*batch_shape, block_size).movedim(-1, dim),)
    return result