| |
|
| | from transformers import PreTrainedModel, PretrainedConfig |
| | from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2MLP |
| | from transformers.generation import GenerationMixin |
| | from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions |
| | import torch |
| | import torch.nn as nn |
| |
|
class TRMConfig(PretrainedConfig):
    """Configuration for TinyRecursiveModel.

    A GPT-2-style config with a small number of *physical* transformer
    layers that are reused (looped over) ``n_loops`` times, giving an
    effective depth of ``n_loops`` with only ``n_physical_layers`` sets
    of weights.
    """

    model_type = "recursive_gpt"

    def __init__(
        self,
        vocab_size=50257,
        n_positions=1024,
        n_embd=512,
        n_physical_layers=3,
        n_loops=8,
        n_head=8,
        activation_function="gelu_new",
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        attn_pdrop=0.1,
        layer_norm_epsilon=1e-5,
        scale_attn_weights=True,
        scale_attn_by_inverse_layer_idx=False,
        reorder_and_upcast_attn=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.n_positions = n_positions
        self.n_embd = n_embd
        # Number of distinct weight sets; looped over n_loops times.
        self.n_physical_layers = n_physical_layers
        self.n_loops = n_loops
        self.n_head = n_head
        self.activation_function = activation_function
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.layer_norm_epsilon = layer_norm_epsilon
        self.scale_attn_weights = scale_attn_weights
        self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
        self.reorder_and_upcast_attn = reorder_and_upcast_attn

        # GPT-2 compatibility aliases: GPT2Attention / GPT2MLP read these
        # generic names instead of the n_* ones.
        self.hidden_size = n_embd
        self.num_attention_heads = n_head
        self.num_hidden_layers = n_physical_layers
        # Required by GPT2Attention.__init__, which sizes its causal-mask
        # bias buffer from config.max_position_embeddings; without this
        # alias, building the model raises AttributeError.
        self.max_position_embeddings = n_positions
        self.n_inner = None
        self.is_encoder_decoder = False
| |
|
class TinyRecursiveModel(PreTrainedModel, GenerationMixin):
    """GPT-2-style causal LM with weight-tied depth recursion.

    A small pool of ``n_physical_layers`` transformer blocks is cycled
    through ``n_loops`` times (loop ``i`` uses physical block
    ``i % n_physical_layers``), so the effective depth exceeds the
    number of weight sets.
    """

    config_class = TRMConfig
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Token + learned absolute position embeddings.
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)

        # The physical (materialized) blocks; each is a standard pre-LN
        # GPT-2 block assembled from HF's GPT2Attention / GPT2MLP.
        self.physical_blocks = nn.ModuleList([
            nn.ModuleDict({
                "ln_1": nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon),
                "attn": GPT2Attention(config, layer_idx=i),
                "ln_2": nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon),
                "mlp": GPT2MLP(4 * config.n_embd, config),
            }) for i in range(config.n_physical_layers)
        ])

        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # post_init() ties lm_head.weight to wte.weight (as declared in
        # _tied_weights_keys) via the embedding accessors below, and runs
        # HF weight initialization.
        self.post_init()

    # ---- embedding accessors -------------------------------------------
    # Required so PreTrainedModel.tie_weights()/resize_token_embeddings()
    # can find the embeddings; without these, the _tied_weights_keys
    # declaration above is inert and lm_head is never tied.
    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        self.wte = new_embeddings

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Run the recursive stack.

        Args:
            input_ids: ``(batch, seq)`` token ids.
            attention_mask: optional ``(batch, seq)`` 0/1 padding mask.
            labels: optional ``(batch, seq)`` targets; shifted internally
                for next-token prediction.

        Returns:
            ``CausalLMOutputWithCrossAttentions`` with ``logits`` of shape
            ``(batch, seq, vocab_size)`` and ``loss`` when labels are given.
        """
        if input_ids is None:
            # Preserved behavior: no inputs_embeds path is supported.
            return None

        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        token_embeds = self.wte(input_ids)
        pos_ids = torch.arange(0, seq_len, dtype=torch.long, device=device)
        pos_embeds = self.wpe(pos_ids)
        hidden_states = self.drop(token_embeds + pos_embeds)

        if attention_mask is not None:
            # GPT2Attention *adds* the mask to raw attention scores, so a
            # 0/1 padding mask must be converted to an additive mask:
            # 0.0 at kept positions, large-negative at padded ones.
            # Passing the raw mask through (as before) silently corrupted
            # the scores instead of masking padding.
            attention_mask = attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
            attention_mask = (1.0 - attention_mask) * torch.finfo(hidden_states.dtype).min

        # Depth recursion: reuse physical block (loop % n_physical_layers).
        for loop in range(self.config.n_loops):
            block = self.physical_blocks[loop % self.config.n_physical_layers]

            # Pre-LN attention sublayer with residual connection.
            ln_output = block["ln_1"](hidden_states)
            attn_output = block["attn"](ln_output, attention_mask=attention_mask)[0]
            hidden_states = hidden_states + attn_output

            # Pre-LN MLP sublayer with residual connection.
            ln_output = block["ln_2"](hidden_states)
            mlp_output = block["mlp"](ln_output)
            hidden_states = hidden_states + mlp_output

        hidden_states = self.ln_f(hidden_states)
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Standard causal-LM shift: predict token t+1 from position t.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            hidden_states=hidden_states,
            attentions=None,
            cross_attentions=None,
        )

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        """No KV cache: feed the full sequence every step. Forward the
        padding mask so batched generation with padding stays correct
        (previously it was dropped)."""
        model_inputs = {"input_ids": input_ids}
        if attention_mask is not None:
            model_inputs["attention_mask"] = attention_mask
        return model_inputs

    def _reorder_cache(self, past, beam_idx):
        # No cache is maintained, so there is nothing to reorder.
        return past
| |
|