| """ |
| Simple example usage of CosmicFish model (local model) |
| """ |
| import torch |
| from transformers import GPT2Tokenizer |
| from modeling_cosmicfish import CosmicFish, CosmicConfig |
| from safetensors.torch import load_file |
| import json |
|
|
| def load_cosmicfish(model_dir): |
| """Load CosmicFish model and tokenizer""" |
| |
| with open(f"{model_dir}/config.json", "r") as f: |
| config_dict = json.load(f) |
|
|
| |
    # Rebuild the config from the saved values; dropout is disabled for inference.
    config = CosmicConfig(
        vocab_size=config_dict["vocab_size"],
        block_size=config_dict["block_size"],
        n_layer=config_dict["n_layer"],
        n_head=config_dict["n_head"],
        n_embd=config_dict["n_embd"],
        bias=config_dict["bias"],
        dropout=0.0,
        use_rotary=config_dict["use_rotary"],
        use_swiglu=config_dict["use_swiglu"],
        use_gqa=config_dict["use_gqa"],
        n_query_groups=config_dict["n_query_groups"],
        use_qk_norm=config_dict.get("use_qk_norm", False),
    )
    # Instantiate the model and load the weights from the safetensors file.
    model = CosmicFish(config)
    state_dict = load_file(f"{model_dir}/model.safetensors")
    # The checkpoint may omit lm_head.weight when it is tied to the token
    # embedding; restore the tie explicitly so load_state_dict succeeds.
    if 'lm_head.weight' not in state_dict and 'transformer.wte.weight' in state_dict:
        state_dict['lm_head.weight'] = state_dict['transformer.wte.weight']
    model.load_state_dict(state_dict)
    model.eval()

    # The model is paired with the stock GPT-2 tokenizer.
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    return model, tokenizer
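
# Optional (a minimal sketch, not part of the original script): run generation on
# a GPU. This assumes CUDA is available and that the input ids built in
# simple_generate are moved to the same device, e.g. inputs = inputs.to(device):
#
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   model = model.to(device)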
|
|
def simple_generate(model, tokenizer, prompt, max_tokens=50, temperature=0.7):
    """Generate text from a prompt."""
    inputs = tokenizer.encode(prompt, return_tensors="pt")

    # Sample autoregressively; no gradients are needed for inference.
    with torch.no_grad():
        outputs = model.generate(
            inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_k=40,
        )

    return tokenizer.decode(outputs[0], skip_special_tokens=True)
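
# Note: top_k=40 with a moderate temperature gives fairly conservative sampling.
# As a rough guide (a general assumption about sampling, not something this
# script states), lower temperatures such as 0.2 behave close to greedy decoding
# and tend to suit factual prompts, while higher values give more varied text:
#
#   response = simple_generate(model, tokenizer, "Define ML", max_tokens=30, temperature=0.2)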
|
| if __name__ == "__main__": |
| |
| print("Loading CosmicFish...") |
| model, tokenizer = load_cosmicfish("./") |
| print(f"Model loaded! ({model.get_num_params()/1e6:.1f}M parameters)") |
|
|
| |
| prompts = [ |
| "What is climate change?", |
| "Write a poem", |
| "Define ML" |
| ] |
|
|
| |
| for prompt in prompts: |
| print(f"\nPrompt: {prompt}") |
| response = simple_generate(model, tokenizer, prompt, max_tokens=30) |
| print(f"Response: {response}") |
|
|