import torch
from transformers import PretrainedConfig, PreTrainedModel

from diffusionLM.model.diffusionLM import LLaDAModel

class DiffusionConfig(PretrainedConfig):
    """Configuration class for the Diffusion-LLM model.

    Defaults mirror a GPT-2-sized transformer (50257-token vocabulary,
    12 layers, 12 heads, hidden size 768), plus fields specific to the
    discrete-diffusion process.
    """

    model_type = "diffusionLM"

    def __init__(
        self,
        vocab_size: int = 50257,
        hidden_size: int = 768,
        num_hidden_layers: int = 12,
        num_attention_heads: int = 12,
        intermediate_size: int = 3072,
        hidden_dropout_prob: float = 0.1,
        attention_probs_dropout_prob: float = 0.1,
        max_position_embeddings: int = 1024,
        initializer_range: float = 0.02,
        layer_norm_eps: float = 1e-12,
        pad_token_id: int = 0,
        mask_token_id: int = 50256,
        eos_token_id: int = 50256,
        num_timesteps: int = 100,
        time_embed_dim: int = 128,
        **kwargs,
    ):
        # pad/eos token ids are handled by PretrainedConfig.__init__.
        super().__init__(pad_token_id=pad_token_id, eos_token_id=eos_token_id, **kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        # Token id used for masked positions during the diffusion process.
        self.mask_token_id = mask_token_id
        # Diffusion-specific settings: number of noise levels and the
        # dimensionality of the timestep embedding.
        self.num_timesteps = num_timesteps
        self.time_embed_dim = time_embed_dim

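# Usage sketch: because DiffusionConfig subclasses PretrainedConfig, it
# serializes through the standard Hugging Face API (the path below is
# illustrative):
#
#   config = DiffusionConfig(num_timesteps=200)
#   config.save_pretrained("./diffusionLM-checkpoint")  # writes config.json
#   config = DiffusionConfig.from_pretrained("./diffusionLM-checkpoint")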


class DiffusionLLM(PreTrainedModel):
    """Main Diffusion-LLM model class.

    A thin Hugging Face ``PreTrainedModel`` wrapper around ``LLaDAModel``
    that provides the standard config/save/load and generation entry points.
    """

    config_class = DiffusionConfig
    base_model_prefix = "diffusionLM"

    def __init__(self, config: DiffusionConfig):
        super().__init__(config)
        self.model = LLaDAModel(config)
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        timesteps=None,
        labels=None,
        return_dict=True,
    ):
        # return_dict is accepted for interface compatibility with the
        # Hugging Face API; the wrapped LLaDAModel is assumed to always
        # return a dict-like output.
        return self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            timesteps=timesteps,
            labels=labels,
        )

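    # Training sketch (assumes LLaDAModel returns an output with a "loss"
    # entry when labels are provided): sample one diffusion timestep per
    # sequence and let the model compute the denoising loss, e.g.
    #
    #   t = torch.randint(0, model.config.num_timesteps, (input_ids.size(0),))
    #   loss = model(input_ids=input_ids, timesteps=t, labels=input_ids)["loss"]
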
    def generate(
        self,
        prompt=None,
        max_length=100,
        num_inference_steps=50,
        temperature=1.0,
        strategy='random',
        top_p=0.9,
        top_k=50,
        num_beams=5,
        return_scores=False,
        use_streaming=False,
        callback_fn=None,
    ):
        """Unified generation interface.

        Dispatches to the streaming path when ``use_streaming`` is set;
        otherwise runs the wrapped model's batch generation.
        """
        if use_streaming:
            # Note: return_scores is only supported by the non-streaming path.
            return self.generate_stream(
                prompt=prompt,
                max_length=max_length,
                num_inference_steps=num_inference_steps,
                temperature=temperature,
                strategy=strategy,
                top_p=top_p,
                top_k=top_k,
                num_beams=num_beams,
                callback_fn=callback_fn,
            )
        return self.model.generate(
            prompt=prompt,
            max_length=max_length,
            num_inference_steps=num_inference_steps,
            temperature=temperature,
            strategy=strategy,
            top_p=top_p,
            top_k=top_k,
            num_beams=num_beams,
            return_scores=return_scores,
        )

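    # Usage sketch: the same entry point covers both modes (prompt_ids is
    # an illustrative tensor of token ids):
    #
    #   ids = model.generate(prompt=prompt_ids, num_inference_steps=25)
    #   for partial in model.generate(prompt=prompt_ids, use_streaming=True):
    #       ...  # assumes generate_stream yields intermediate denoising states
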
    def generate_stream(self, **kwargs):
        """Streaming generation wrapper; delegates to the wrapped model."""
        return self.model.generate_stream(**kwargs)

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """Prepare inputs for compatibility with the HF generation utilities."""
        return {
            "input_ids": input_ids,
            "attention_mask": kwargs.get("attention_mask"),
            "timesteps": kwargs.get("timesteps"),
        }

    @staticmethod
    def _reorder_cache(past, beam_idx):
        """Reorder cache for beam-search compatibility.

        The diffusion sampler keeps no autoregressive key/value cache, so
        ``past`` is returned unchanged.
        """
        return past
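

if __name__ == "__main__":
    # Minimal smoke test (a sketch; the exact LLaDAModel output fields may
    # differ). Builds a deliberately tiny model and runs one forward pass.
    # hidden_size must remain divisible by num_attention_heads.
    config = DiffusionConfig(
        vocab_size=1000,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=128,
        max_position_embeddings=128,
        pad_token_id=0,
        mask_token_id=999,
        eos_token_id=999,
    )
    model = DiffusionLLM(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    timesteps = torch.randint(0, config.num_timesteps, (2,))
    outputs = model(input_ids=input_ids, timesteps=timesteps)
    print(type(outputs))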