Transformers documentation

Writing kernels

You are viewing the main version, which requires installation from source. If you'd like a regular pip install, check out the latest stable version (v5.12.0).

This guide explains how to write kernels that go beyond a stateless forward replacement. It covers two capabilities the extended KernelConfig API supports:

  1. Parameter transformation: the kernel expects weights in a different layout than the original model (for example, renamed or merged parameters).
  2. Module fusion: the kernel replaces multiple adjacent modules with a single fused implementation.

For basic kernels (stateless forward replacements with no parameter changes), see the kernels library documentation.

Two-class pattern

Any kernel that carries its own parameters follows a two-class pattern.

  • KernelName: contains only the forward pass. The kernels library does not allow stateful kernel classes, so this is the class it uses to kernelize the model.
  • KernelNameLayout: an nn.Module that holds the parameters and is monkey-patched in place of the original module before the checkpoint is loaded. At runtime, kernelize replaces its forward with the forward from KernelName. You do not need to define forward on the layout class. Transformers injects one automatically with the same signature as KernelName.forward.

The naming convention is strict. The layout class must be named {KernelName}Layout and defined in the same module as KernelName.
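
The naming rule can be illustrated with a small lookup. This is a minimal sketch, not the lookup Transformers actually performs: find_layout_class and the my_kernel module object are hypothetical names introduced only for this example.

```python
import types


def find_layout_class(kernel_module, kernel_name):
    """Return the class named f"{kernel_name}Layout" from the module, or None.

    Hypothetical helper: it only illustrates the naming convention.
    """
    return getattr(kernel_module, f"{kernel_name}Layout", None)


# Stand-in for the Python module that a kernel repo would ship.
kernel_module = types.ModuleType("my_kernel")


class CustomRMSNorm:  # forward-only kernel class
    def forward(self, x):
        return x


class CustomRMSNormLayout:  # parameter-holding companion class
    pass


# Both classes live in the same module, as the convention requires.
kernel_module.CustomRMSNorm = CustomRMSNorm
kernel_module.CustomRMSNormLayout = CustomRMSNormLayout

layout_cls = find_layout_class(kernel_module, "CustomRMSNorm")
missing = find_layout_class(kernel_module, "OtherKernel")  # no such layout
```

A kernel with no matching layout class in its module (the missing case above) would be treated as a basic, stateless kernel under this reading of the convention.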

Parameter transformation

Use this pattern when the kernel expects weights under different names or in a different shape than the original model checkpoint.

The KernelNameLayout class has the same __init__ signature as the module it replaces. It also declares a conversion_mapping class attribute that tells Transformers how to remap checkpoint keys to the new parameter names (see Dynamic weight loading for more details).

import torch
import torch.nn as nn

class CustomRMSNormLayout(nn.Module):
    conversion_mapping = [...]  # rules that remap checkpoint keys to the new parameter names

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps


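class CustomRMSNorm(nn.Module):
    # Forward only: self.scale and self.variance_epsilon are defined on
    # CustomRMSNormLayout, the module this forward is attached to at runtime.
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        # Normalize in float32 for numerical stability, then cast back.
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.scale * hidden_states.to(input_dtype)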


class layers:
    CustomRMSNorm = CustomRMSNorm

The layers class is required by the kernels library to expose the kernel entry point.

Load this kernel by passing the repo and class name to KernelConfig. The key is the original module class name from the model. The value points to the KernelName class (not the Layout) in the repo.

from transformers import AutoModelForCausalLM, KernelConfig

kernel_config = KernelConfig({"RMSNorm": "owner/my-kernel:CustomRMSNorm"})
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-0.6B",
    use_kernels=True,
    kernel_config=kernel_config,
    device_map="cuda",
)

When the model loads, Transformers:

  1. Loads CustomRMSNorm from the repo and looks for CustomRMSNormLayout in the same module.
  2. Monkey-patches every RMSNorm in the model with CustomRMSNormLayout.
  3. Remaps checkpoint weights using conversion_mapping so they load into the new parameter names.
  4. Calls kernelize, which replaces CustomRMSNormLayout.forward with CustomRMSNorm.forward.
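
The steps above can be mimicked in plain Python to show how the pieces fit together. This is a minimal sketch under stated assumptions: lists stand in for tensors so it runs without torch, rename_keys is a hypothetical stand-in for whatever conversion_mapping expresses (it is not the Transformers mapping format), and types.MethodType stands in for what kernelize does.

```python
import types


class CustomRMSNormLayout:
    """Holds state under the kernel's parameter names; defines no forward."""

    def __init__(self, hidden_size, eps=1e-6):
        self.scale = [1.0] * hidden_size
        self.variance_epsilon = eps


class CustomRMSNorm:
    """Forward only; reads attributes that the layout instance provides."""

    def forward(self, hidden_states):
        mean_sq = sum(v * v for v in hidden_states) / len(hidden_states)
        inv_rms = (mean_sq + self.variance_epsilon) ** -0.5
        return [s * v * inv_rms for s, v in zip(self.scale, hidden_states)]


def rename_keys(state_dict, old_suffix, new_suffix):
    """Hypothetical stand-in for step 3: rename keys ending in old_suffix."""
    renamed = {}
    for key, value in state_dict.items():
        if key.endswith(old_suffix):
            key = key[: -len(old_suffix)] + new_suffix
        renamed[key] = value
    return renamed


# Step 2: the layout is constructed with the original module's arguments.
module = CustomRMSNormLayout(hidden_size=2)

# Step 3: a key saved for the original `weight` now targets `scale`.
checkpoint = {"model.norm.weight": [2.0, 2.0]}
remapped = rename_keys(checkpoint, ".weight", ".scale")
module.scale = remapped["model.norm.scale"]

# Step 4: attach the kernel's forward to the layout instance.
module.forward = types.MethodType(CustomRMSNorm.forward, module)
output = module.forward([3.0, 4.0])
```

The point of the split is visible in the last two lines: the forward comes from CustomRMSNorm, while every attribute it reads lives on the CustomRMSNormLayout instance.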

Module fusion

Use this pattern when a kernel replaces multiple adjacent modules with a single fused implementation. Because the fused module combines parameters from several original modules, KernelNameLayout.__init__ receives the instantiated child modules rather than their constructor arguments.

import torch
import torch.nn as nn

class RMSNormMLPLayout(nn.Module):
    conversion_mapping = [...]  # rules that remap checkpoint keys to the fused parameter names

    def __init__(self, norm, mlp):
        super().__init__()
        self.variance_epsilon = norm.variance_epsilon
        self.scale = nn.Parameter(torch.empty_like(norm.weight))
        self.gate_up_proj = nn.Linear(
            mlp.gate_proj.in_features,
            mlp.gate_proj.out_features + mlp.up_proj.out_features,
            bias=mlp.gate_proj.bias is not None,
            device=mlp.gate_proj.weight.device,
            dtype=mlp.gate_proj.weight.dtype,
        )
        self.down_proj = nn.Linear(
            mlp.down_proj.in_features,
            mlp.down_proj.out_features,
            bias=mlp.down_proj.bias is not None,
            device=mlp.down_proj.weight.device,
            dtype=mlp.down_proj.weight.dtype,
        )
        self.act_fn = mlp.act_fn


class RMSNormMLP(nn.Module):
    # Forward only: all attributes used here are defined on RMSNormMLPLayout.
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # RMSNorm, computed in float32 and cast back to the input dtype.
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        hidden_states = self.scale * hidden_states.to(input_dtype)
        # One fused matmul, then split the output into its gate and up halves.
        gate, up = self.gate_up_proj(hidden_states).chunk(2, dim=-1)
        return self.down_proj(self.act_fn(gate) * up)


class layers:
    RMSNormMLP = RMSNormMLP
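
For the chunk(2, dim=-1) call in the forward above to recover the original projections, the fused weight has to hold the gate rows first and the up rows second. The sketch below checks that with plain Python lists standing in for tensors. It assumes the elided conversion_mapping concatenates gate_proj and up_proj along the output dimension in that order, and the linear helper is hypothetical.

```python
def linear(weight, x):
    """Compute W @ x for a weight stored as rows of (out_features, in_features)."""
    return [sum(w * v for w, v in zip(row, x)) for row in weight]


# Toy weights with in_features=2 and out_features=2 for each projection.
gate_weight = [[1.0, 0.0], [0.0, 2.0]]
up_weight = [[3.0, 1.0], [1.0, -1.0]]

# Assumed merge: concatenate along the output dimension, gate rows first.
gate_up_weight = gate_weight + up_weight

x = [1.0, 2.0]
fused = linear(gate_up_weight, x)

# Equivalent of .chunk(2, dim=-1): split the fused output into two halves.
half = len(fused) // 2
gate, up = fused[:half], fused[half:]
```

If the mapping concatenated the rows in the opposite order, the forward would silently swap gate and up, so the order in the mapping and the order in the unpacking have to agree.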

To fuse modules, pass a tuple of (class_name, path_pattern) pairs as the key in KernelConfig instead of a plain string. All patterns must share the same parent module (Transformers fuses the children in that parent). The * wildcard matches any single path segment.
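
The two path rules above can be sketched as follows. This is an illustration of the stated behavior, not the matching code in Transformers: path_matches, parent_of, and share_parent are hypothetical helpers.

```python
def path_matches(pattern, path):
    """Match dotted paths segment by segment; '*' matches exactly one segment."""
    pattern_segments = pattern.split(".")
    path_segments = path.split(".")
    if len(pattern_segments) != len(path_segments):
        return False
    return all(p == "*" or p == s for p, s in zip(pattern_segments, path_segments))


def parent_of(pattern):
    """Drop the last segment: 'model.layers.*.mlp' -> 'model.layers.*'."""
    return pattern.rpartition(".")[0]


def share_parent(patterns):
    """True when every pattern points at a child of the same parent."""
    return len({parent_of(p) for p in patterns}) == 1


fusion_patterns = [
    "model.layers.*.post_attention_layernorm",
    "model.layers.*.mlp",
]
valid = share_parent(fusion_patterns)
matches_layer_3 = path_matches("model.layers.*.mlp", "model.layers.3.mlp")
matches_nested = path_matches("model.layers.*.mlp", "model.layers.3.mlp.gate_proj")
```

Because the wildcard covers a single segment, a pattern such as model.layers.*.mlp does not reach into nested children like model.layers.3.mlp.gate_proj.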

from transformers import AutoModelForCausalLM, KernelConfig

kernel_config = KernelConfig(
    {
        (
            ("RMSNorm", "model.layers.*.post_attention_layernorm"),
            ("MLP",     "model.layers.*.mlp"),
        ): "owner/my-kernel:RMSNormMLP",
    }
)
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-0.6B",
    use_kernels=True,
    kernel_config=kernel_config,
    device_map="cuda",
)

When the model loads, Transformers:

  1. Loads RMSNormMLP from the repo and finds RMSNormMLPLayout in the same module.
  2. Matches every decoder layer at model.layers.* and builds a fused parent class whose __init__ calls RMSNormMLPLayout(post_attention_layernorm, mlp).
  3. Replaces the remaining child (mlp) with nn.Identity() to preserve the parent module’s interface.
  4. Remaps checkpoint weights using conversion_mapping.
  5. Calls kernelize, which replaces RMSNormMLPLayout.forward with RMSNormMLP.forward.

The order of pairs in the fusion tuple determines the argument order passed to KernelNameLayout.__init__.
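
The ordering rule and the nn.Identity() replacement can be sketched on a toy parent object. This is a simplification under stated assumptions: Transformers builds a fused parent class, while the hypothetical fuse_children below mutates an instance. It also assumes the fused module takes the first child's slot, which the text above implies but does not state. Identity and RecordingLayout are stand-ins defined only for this example.

```python
from types import SimpleNamespace


class Identity:
    """Stand-in for nn.Identity: returns its input unchanged."""

    def __call__(self, x):
        return x


class RecordingLayout:
    """Hypothetical layout that records the children it was built from."""

    def __init__(self, *children):
        self.children = children


def fuse_children(parent, child_names, layout_cls):
    """Build the layout from the children, in the order the names are given."""
    children = [getattr(parent, name) for name in child_names]
    fused = layout_cls(*children)
    # Assumption: the fused module occupies the first child's slot.
    setattr(parent, child_names[0], fused)
    # The remaining children become pass-throughs so the parent's own
    # forward can keep calling them.
    for name in child_names[1:]:
        setattr(parent, name, Identity())
    return fused


layer = SimpleNamespace(post_attention_layernorm="norm-module", mlp="mlp-module")
fused = fuse_children(layer, ["post_attention_layernorm", "mlp"], RecordingLayout)
```

Swapping the two names in the list would hand the layout (mlp, norm) instead of (norm, mlp), which is why the pair order in KernelConfig has to match the parameter order of KernelNameLayout.__init__.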
