| """
|
| Functional API for BitLinear operations.
|
|
|
| This module provides the core functional implementations that will be called
|
| by the nn.Module wrappers. These functions implement the mathematical operations
|
| described in the BitNet and ternary neural network papers.
|
| """
|
|
|
| import torch
|
| import torch.nn.functional as F
|
| from typing import Optional, Tuple
|
|
|
|
|
def bitlinear_python(
    x: torch.Tensor,
    W: torch.Tensor,
    gamma: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Pure PyTorch reference implementation of the BitLinear forward pass.

    Computes ``output = x @ W^T * gamma + bias`` where ``W`` is a ternary
    weight matrix ({-1, 0, +1}) and ``gamma`` holds per-output-channel
    scaling factors that compensate for the quantization.

    Args:
        x: Input tensor of shape [..., in_features].
        W: Ternary weight matrix of shape [out_features, in_features]
            with values in {-1, 0, +1}.
        gamma: Scaling factors of shape [out_features] or [1, out_features].
        bias: Optional bias tensor of shape [out_features].

    Returns:
        Output tensor of shape [..., out_features].

    Notes:
        - This is the reference implementation used for correctness checks;
          CUDA kernels are expected to optimize the ternary matmul.
        - Gamma scaling is applied per output channel.
    """
    # Dense matmul against the transposed weights: [..., in] @ [in, out].
    out = x @ W.t()

    # Give a 1-D gamma an explicit leading axis so it broadcasts over the
    # batch dimensions; a 2-D gamma already has that axis.
    scale = gamma.unsqueeze(0) if gamma.dim() == 1 else gamma
    out = out * scale

    if bias is not None:
        out = out + bias

    return out
|
|
|
|
|
def greedy_ternary_decomposition(
    W: torch.Tensor,
    k: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Greedy ternary decomposition of a weight matrix.

    Decomposes a dense weight matrix W into a sum of k ternary matrices:
        W ≈ sum_{i=1}^k gamma_i * W_i^ternary

    This follows the greedy residual minimization approach:
        1. Quantize W to ternary → W_1, compute gamma_1
        2. Compute residual R_1 = W - gamma_1 * W_1
        3. Quantize R_1 to ternary → W_2, compute gamma_2
        4. Repeat for k iterations

    Args:
        W: Dense weight matrix of shape [out_features, in_features].
        k: Number of ternary components (typically 2-4 for BitNet).
            Must be >= 1.

    Returns:
        W_ternary: Stacked ternary matrices of shape [k, out_features, in_features].
        gammas: Scaling factors of shape [k, out_features].

    Raises:
        ValueError: If ``k`` is not a positive integer.

    Notes:
        - Each iteration reduces the residual error.
        - Larger k provides better approximation but more computation.
        - This is used in MultiTernaryLinear for improved expressiveness.

    References:
        - BitNet paper: "BitNet: Scaling 1-bit Transformers for Large Language Models"
        - JMLR paper: https://jmlr.org/papers/volume26/24-2050/24-2050.pdf
    """
    from .quantization import weight_to_ternary

    # Without this guard, k=0 would surface as an opaque torch.stack error
    # on an empty list; fail fast with a clear message instead.
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}")

    # Work on a copy so the caller's weight tensor is never mutated.
    residual = W.clone()

    ternary_weights = []
    gammas = []

    for _ in range(k):
        # Quantize what is still unexplained by the previous components.
        # NOTE(review): per_channel=True is assumed to yield one gamma per
        # output row, shape [out_features] — confirm against weight_to_ternary.
        W_t, gamma = weight_to_ternary(residual, per_channel=True)

        ternary_weights.append(W_t)
        gammas.append(gamma)

        # Remove this component's contribution so the next iteration fits
        # only the remaining residual error.
        residual = residual - gamma.unsqueeze(1) * W_t

    W_ternary = torch.stack(ternary_weights, dim=0)
    gammas_stacked = torch.stack(gammas, dim=0)

    return W_ternary, gammas_stacked
|
|
|
|
|
|
|
def multi_ternary_linear_python(
    x: torch.Tensor,
    W_ternary: torch.Tensor,
    gammas: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Forward pass for a multi-component ternary linear layer.

    Computes the sum of k ternary linear transformations:
        output = sum_{i=1}^k (x @ W_i^T * gamma_i) + bias

    Args:
        x: Input tensor of shape [..., in_features].
        W_ternary: Stacked ternary weights of shape [k, out_features, in_features].
        gammas: Scaling factors of shape [k, out_features].
        bias: Optional bias tensor of shape [out_features].

    Returns:
        Output tensor of shape [..., out_features].
    """
    out_features = W_ternary.size(1)

    # Accumulator matching the broadcasted output shape, on the same
    # device/dtype as the input.
    accum = torch.zeros(
        (*x.shape[:-1], out_features), dtype=x.dtype, device=x.device
    )

    # Iterating a stacked tensor walks its leading (component) dimension,
    # pairing each ternary matrix with its scaling vector.
    for W_i, gamma_i in zip(W_ternary, gammas):
        accum = accum + bitlinear_python(x, W_i, gamma_i, bias=None)

    # The bias is applied once, after all components are summed.
    return accum if bias is None else accum + bias
|
|
|
|
|
def activation_quant(x: torch.Tensor, bits: int = 8) -> torch.Tensor:
    """
    Quantize activations for BitLinear.

    BitNet uses activation quantization in addition to weight quantization.
    This implements per-token absmax quantization over the last dimension:
    each token is scaled by its absolute maximum, rounded onto the symmetric
    integer grid [-(2^(bits-1) - 1), 2^(bits-1) - 1], then rescaled back.

    Args:
        x: Input activations of shape [..., features].
        bits: Number of bits for quantization (default: 8).

    Returns:
        Quantized activations (as float, not int).

    Notes:
        - Uses absmax scaling per token (last dimension).
        - Returns a float tensor for compatibility with autograd.
        - Simulates quantization effects without actual INT8 storage.
    """
    # Symmetric grid limit, e.g. 127 for 8 bits.
    q_max = (1 << (bits - 1)) - 1

    # Per-token absolute maximum; floored so an all-zero token cannot
    # cause a division by zero.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp_min(1e-5)

    # Round onto the integer grid and clamp into [-q_max, q_max].
    q = (x / scale * q_max).round().clamp(-q_max, q_max)

    # Map the grid values back to the original dynamic range.
    return q / q_max * scale
|
|
|