| """ |
| Example usage of CosmicFish model (using safetensors) |
| """ |
| import torch |
| from transformers import GPT2Tokenizer |
| from modeling_cosmicfish import CosmicFish, CosmicConfig |
| from safetensors.torch import load_file |
| import json |
|
|
def load_cosmicfish(model_dir):
    """Load CosmicFish model and tokenizer"""
    # Read the training configuration saved alongside the checkpoint
    with open(f"{model_dir}/config.json", "r") as f:
        config_dict = json.load(f)

    # Rebuild the model config (dropout disabled for inference)
    config = CosmicConfig(
        vocab_size=config_dict["vocab_size"],
        block_size=config_dict["block_size"],
        n_layer=config_dict["n_layer"],
        n_head=config_dict["n_head"],
        n_embd=config_dict["n_embd"],
        bias=config_dict["bias"],
        dropout=0.0,
        use_rotary=config_dict["use_rotary"],
        use_swiglu=config_dict["use_swiglu"],
        use_gqa=config_dict["use_gqa"],
        n_query_groups=config_dict["n_query_groups"],
        use_qk_norm=config_dict["use_qk_norm"],
    )

    # Instantiate the model and load the safetensors checkpoint
    model = CosmicFish(config)
    state_dict = load_file(f"{model_dir}/model.safetensors")

    # safetensors does not store tied tensors twice, so restore the
    # wte/lm_head weight tie manually before loading
    if "lm_head.weight" not in state_dict and "transformer.wte.weight" in state_dict:
        print("Weight sharing detected: tying lm_head.weight to transformer.wte.weight")
        state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"]

    model.load_state_dict(state_dict)
    model.eval()

    # CosmicFish uses the standard GPT-2 BPE tokenizer
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    return model, tokenizer
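

# Example usage: a minimal sketch. "path/to/cosmicfish" is a placeholder
# directory, and the generate() call assumes CosmicFish exposes a
# nanoGPT-style generate(idx, max_new_tokens, temperature, top_k) sampling
# method; adjust the call if the model's actual API differs.
if __name__ == "__main__":
    model, tokenizer = load_cosmicfish("path/to/cosmicfish")

    prompt = "Once upon a time"
    input_ids = torch.tensor([tokenizer.encode(prompt)], dtype=torch.long)

    with torch.no_grad():
        # Assumed sampling interface (not confirmed by the source)
        output_ids = model.generate(input_ids, max_new_tokens=50, temperature=0.8, top_k=40)

    print(tokenizer.decode(output_ids[0].tolist()))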