
Subclassing Trainer methods


Subclass Trainer and override individual methods to change training behavior without rewriting the entire training loop. An override replaces one step of the loop, such as how batches are loaded or how the loss is computed from the forward pass.

Before subclassing, decide whether you need to change what Trainer computes or when (and whether) it acts. For timing and conditional logic, use a Callback instead. Callbacks control when things happen (logging, evaluation, early stopping), while subclassing changes what happens (loss computation, data loading, optimization).

See the Trainer API docs for a complete list of methods you can subclass. Private methods (prefixed with _) like _save_checkpoint or _evaluate can also be overridden, but these may change without notice.
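The split between the two extension points can be sketched with a toy training loop. `ToyTrainer`, `MaxLossTrainer`, and `StopAtStep` below are hypothetical plain-Python stand-ins, not the Transformers API; they only mimic the shape of the real hooks. The loop calls an overridable `compute_loss` (what is computed) and notifies callbacks at an event (when something else happens).

```python
class ToyTrainer:
    """Hypothetical miniature trainer: a fixed loop with an overridable hook."""

    def __init__(self, batches, callbacks=()):
        self.batches = batches
        self.callbacks = list(callbacks)
        self.losses = []

    def compute_loss(self, batch):
        # Default "what": the batch mean stands in for a model's loss
        return sum(batch) / len(batch)

    def train(self):
        for step, batch in enumerate(self.batches, start=1):
            loss = self.compute_loss(batch)  # a subclass can change this
            self.losses.append(loss)
            for cb in self.callbacks:  # callbacks only react to the event
                if cb.on_step_end(step, loss) == "stop":
                    return self.losses
        return self.losses


class MaxLossTrainer(ToyTrainer):
    # Subclassing changes WHAT the loop computes
    def compute_loss(self, batch):
        return max(batch)


class StopAtStep:
    # A callback changes WHEN the loop acts (here: stop early)
    def __init__(self, last_step):
        self.last_step = last_step

    def on_step_end(self, step, loss):
        return "stop" if step >= self.last_step else None


batches = [[1.0, 3.0], [2.0, 4.0], [5.0, 7.0]]
print(ToyTrainer(batches).train())                   # [2.0, 3.0, 6.0]
print(MaxLossTrainer(batches).train())               # [3.0, 4.0, 7.0]
print(ToyTrainer(batches, [StopAtStep(2)]).train())  # [2.0, 3.0]
```

In Transformers itself, callbacks receive a `TrainerControl` object and set flags such as `should_training_stop` instead of returning a string, but the division of labor is the same.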

get_train_dataloader

The standard get_train_dataloader() method returns a DataLoader that yields one batch of self._train_batch_size examples at a time: Trainer loads a batch, trains on it, discards it, and loads the next one.

def get_train_dataloader(self):
    return self._get_dataloader(
        batch_size=self._train_batch_size,
        ...
    )

GRPO is an online reinforcement learning algorithm: it generates completions with the current model and then trains on them. Generating completions every step is expensive because decoding is autoregressive. A 512-token completion requires ~512 sequential forward passes, compared to one forward pass for a training step. GRPOTrainer overrides get_train_dataloader() to batch generation across multiple steps.
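The cost asymmetry can be made concrete by counting model calls. The sketch below uses a hypothetical `ToyModel` that returns a dummy next token and records how often it is called. Real generation adds a key-value cache and sampling, but it still needs one sequential forward pass per new token; batching more prompts does not shorten that chain, it only makes each pass process more rows.

```python
class ToyModel:
    """Hypothetical stand-in for a language model that counts forward passes."""

    def __init__(self):
        self.forward_passes = 0

    def forward(self, sequences):
        # One call processes every sequence in the batch in parallel
        self.forward_passes += 1
        return [seq[-1] + 1 for seq in sequences]  # dummy "next token"


def generate(model, prompts, max_new_tokens):
    sequences = [list(p) for p in prompts]
    for _ in range(max_new_tokens):
        # Token t+1 depends on token t, so these calls must run one after another
        next_tokens = model.forward(sequences)
        for seq, tok in zip(sequences, next_tokens):
            seq.append(tok)
    return sequences


def train_step(model, batch):
    # Training sees each full sequence at once (teacher forcing): one forward pass
    model.forward(batch)


model = ToyModel()
completions = generate(model, prompts=[[0], [10], [20], [30]], max_new_tokens=512)
print(model.forward_passes)  # 512

model = ToyModel()
train_step(model, completions)
print(model.forward_passes)  # 1
```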

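trl.GRPOTrainer.get_train_dataloader loads generation prompts for several training steps at once by multiplying the batch size by the steps_per_generation argument. If train_batch_size=4 and steps_per_generation=8, the dataloader produces batches of 32 prompts, so generation runs once every 8 steps instead of every step. The total number of completions is unchanged, but the sequential decoding passes are paid once per 8 steps rather than 8 times, cutting the number of generation rounds by 8x.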

def get_train_dataloader(self):
    dataloader_params = {
        "batch_size": self._train_batch_size * self.args.steps_per_generation,  # this is the only change
        ...
    }
    ...  # build and return the DataLoader from dataloader_params
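The effect of steps_per_generation on the schedule can be simulated without TRL. In this sketch, `generate_completions` is a hypothetical placeholder for the expensive generation call. The loop draws one large batch of `train_batch_size * steps_per_generation` prompts, generates for all of them at once, then feeds `train_batch_size` completions to each of the next `steps_per_generation` training steps. It illustrates the idea only and does not reproduce TRL's buffering code.

```python
def generate_completions(prompts, log):
    # Hypothetical expensive call; record how many prompts each call handled
    log.append(len(prompts))
    return [f"completion for {p}" for p in prompts]


def run(prompts, train_batch_size, steps_per_generation):
    generation_log = []
    step_batches = []
    big = train_batch_size * steps_per_generation  # dataloader batch size
    for start in range(0, len(prompts), big):
        completions = generate_completions(prompts[start:start + big], generation_log)
        # Slice the buffered completions into per-step training batches
        for i in range(0, len(completions), train_batch_size):
            step_batches.append(completions[i:i + train_batch_size])
    return generation_log, step_batches


prompts = [f"p{i}" for i in range(64)]
gen_log, steps = run(prompts, train_batch_size=4, steps_per_generation=8)
print(gen_log)        # [32, 32]: two generation calls
print(len(steps))     # 16 training steps
print(len(steps[0]))  # 4 completions per step

gen_log_baseline, _ = run(prompts, train_batch_size=4, steps_per_generation=1)
print(len(gen_log_baseline))  # 16 generation calls without buffering
```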

compute_loss

By default, compute_loss() runs the forward pass and returns the loss the model computes from the labels in inputs; for language models this is cross-entropy.

def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
    ...
    outputs = model(**inputs)
    ...
    loss = outputs["loss"] # get loss from model

    return (loss, outputs) if return_outputs else loss
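For a causal language model, the value read from outputs["loss"] is token-level cross-entropy: labels are shifted one position so each position predicts the next token, and positions labeled -100 are skipped. The plain-Python sketch below recomputes that quantity from toy logits to show what the number means. `causal_lm_loss` is a hypothetical helper; real models compute this with PyTorch over whole tensors.

```python
import math


def log_softmax(row):
    # Numerically stable log-softmax over one vocabulary row
    m = max(row)
    log_z = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_z for x in row]


def causal_lm_loss(logits, labels, ignore_index=-100):
    """Mean next-token cross-entropy for one sequence (hypothetical helper)."""
    total, count = 0.0, 0
    # Position t predicts labels[t + 1]: drop the last logit row and the first label
    for row, target in zip(logits[:-1], labels[1:]):
        if target == ignore_index:
            continue  # e.g. prompt or padding tokens
        total -= log_softmax(row)[target]
        count += 1
    return total / count


# Vocabulary of 3 tokens, 4 positions; the second label is masked out
logits = [[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0], [1.0, 1.0, 1.0]]
labels = [0, -100, 1, 2]
print(causal_lm_loss(logits, labels))
```

Both scored positions give logit 2.0 to the correct token, so the result equals log(2 + e^2) - 2 ≈ 0.24. When num_items_in_batch is passed, the model divides a summed loss by that count instead of averaging per batch, which keeps the normalization correct under gradient accumulation.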

DPO measures how strongly the policy model prefers a chosen response over a rejected one, relative to a reference model. DPOTrainer overrides compute_loss() because its loss differs from standard cross-entropy in several ways:

  • the model receives no labels; it only returns logits, from which DPO computes log-probs
  • chosen and rejected responses are concatenated into a single batch
  • a frozen reference model computes its own log-probs for the same inputs
  • the loss is a function of four log-probabilities: log π(chosen), log π(rejected), log π_ref(chosen), log π_ref(rejected)

None of the above fits the standard Trainer.compute_loss() method.
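Written out, the per-example objective these four log-probabilities feed into is the standard sigmoid form of DPO. Here y_c and y_r are the chosen and rejected responses to prompt x, π_θ is the policy being trained, π_ref is the frozen reference model, and β controls how far the policy may drift from the reference (symbols chosen for illustration). The two bracketed differences correspond to chosen_scores and rejected_scores in the code below, and β scales the whole margin between them.

$$
\mathcal{L}_{\text{DPO}} = -\log \sigma\Big(\beta\big[(\log \pi_\theta(y_c \mid x) - \log \pi_{\text{ref}}(y_c \mid x)) - (\log \pi_\theta(y_r \mid x) - \log \pi_{\text{ref}}(y_r \mid x))\big]\Big)
$$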

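from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel


def compute_loss(
    self,
    model: PreTrainedModel | nn.Module,
    inputs: dict[str, torch.Tensor | Any],
    return_outputs=False,
    num_items_in_batch=None,
) -> torch.Tensor | tuple[torch.Tensor, dict[str, float]]:
    ...
    # Policy forward pass: no labels, only logits
    outputs = model(**inputs)
    logits = outputs.logits
    # get_logps (not shown) is assumed to return one summed log-prob per sequence
    logps = get_logps(logits, inputs)
    chosen_logps, rejected_logps = logps.chunk(2, dim=0)  # batch is [chosen, rejected]
    # The reference model is frozen, so skip gradient tracking
    with torch.no_grad():
        ref_logits = self.ref_model(**inputs).logits
    ref_logps = get_logps(ref_logits, inputs)
    ref_chosen_logps, ref_rejected_logps = ref_logps.chunk(2, dim=0)  # batch is [chosen, rejected]
    # How much more the policy favors each response than the reference does
    chosen_scores = chosen_logps - ref_chosen_logps
    rejected_scores = rejected_logps - ref_rejected_logps
    # beta scales the margin between the two scores, not only the chosen score
    per_sequence_loss = -F.logsigmoid(self.beta * (chosen_scores - rejected_scores))
    loss = per_sequence_loss.mean()
    return (loss, outputs) if return_outputs else loss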
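To see how the loss responds to the policy's preference, the sketch below evaluates it on made-up summed log-probabilities for a single preference pair, in plain Python. `dpo_loss` and `log_sigmoid` are hypothetical helpers, not TRL's implementation, and the numbers are illustrative only.

```python
import math


def log_sigmoid(z):
    # Stable log(sigmoid(z)) for both positive and negative z
    return -math.log1p(math.exp(-z)) if z >= 0 else z - math.log1p(math.exp(z))


def dpo_loss(chosen_logp, rejected_logp, ref_chosen_logp, ref_rejected_logp, beta):
    chosen_score = chosen_logp - ref_chosen_logp        # implicit reward of chosen
    rejected_score = rejected_logp - ref_rejected_logp  # implicit reward of rejected
    # beta multiplies the whole margin
    return -log_sigmoid(beta * (chosen_score - rejected_score))


# Before training the policy equals the reference, so both scores are 0
print(dpo_loss(-12.0, -15.0, -12.0, -15.0, beta=0.1))  # log(2) ≈ 0.6931

# The policy now favors the chosen answer more than the reference does
print(dpo_loss(-10.0, -17.0, -12.0, -15.0, beta=0.1))  # smaller loss
```

Writing `beta * chosen_score - rejected_score` instead would scale only the chosen term, so the rejected side would enter the loss at a different strength than the chosen side.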

Next steps

  • For more real-world examples, see how GRPOTrainer and DPOTrainer extend Trainer in TRL, or how Axolotl builds custom trainers on top of it.
  • Check the Callbacks guide if you only need to act at a training event, such as logging metrics at the end of a training step.