| """ |
| Sample from a trained model |
| """ |
| import os |
| import pickle |
| from contextlib import nullcontext |
| import torch |
| import tiktoken |
| from model import GPTConfig, GPT |
| from tqdm import tqdm |
| import random |
| import numpy as np |
| from transformers import AutoTokenizer |
| from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode |
| import argparse |
| import itertools |
| import random |
|
|
|
|
# ---------------------------------------------------------------------------
# Command-line interface.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description="Sample from a trained model")
parser.add_argument("--init_from", type=str, default="resume",
                    help="'resume' to load --ckpt_path, or a gpt2* model name")
parser.add_argument("--out_path", type=str, required=True,
                    help="file the generated samples are appended to")
parser.add_argument("--num_samples", type=int, required=False, default=100000,
                    help="number of samples to draw (sampling / top_k only)")
parser.add_argument("--max_new_tokens", type=int, required=True,
                    help="number of tokens generated in each sample")
# Validation moved into argparse via `choices` — the original used a bare
# `assert`, which is silently stripped when Python runs with -O.
parser.add_argument("--strategy", type=str, required=False, default='top_k',
                    choices=['greedy_search', 'sampling', 'top_k', 'beam_search'],
                    help="should be in ['greedy_search', 'sampling', 'top_k', 'beam_search']")
parser.add_argument("--beam_size", type=int, required=False, default=3,
                    help="beam size for beam search")
parser.add_argument("--temperature", type=float, required=False, default=1.0,
                    help="1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions")
parser.add_argument("--top_k", type=int, required=False, default=20,
                    help="retain only the top_k most likely tokens, clamp others to have 0 probability")
parser.add_argument("--ckpt_path", type=str, required=True,
                    help="path to a checkpoint/model")
parser.add_argument("--tokenizer_path", type=str, required=True,
                    help="path to a tokenizer directory")
parser.add_argument("--start", type=str, required=False, default="<|endoftext|>",
                    help="prompt used to seed generation")
parser.add_argument("--repetition_penalty", type=float, required=False, default=1.0)
parser.add_argument("--shuffle_token", action='store_true',
                    help="Enable shuffling of tokens before decoding")
# BUG FIX: the original declared --fasta with action='store_true' AND
# default=True, so the flag could never be turned off.  FASTA output stays
# enabled by default for backward compatibility; --no_fasta disables it.
parser.add_argument("--fasta", dest="fasta", action='store_true',
                    help="Enable writing output in FASTA format (default: enabled)")
parser.add_argument("--no_fasta", dest="fasta", action='store_false',
                    help="Disable FASTA output")
parser.set_defaults(fasta=True)


args = parser.parse_args()

# Unpack into the module-level names used throughout the script.
init_from = args.init_from
out_path = args.out_path
num_samples = args.num_samples
max_new_tokens = args.max_new_tokens
strategy = args.strategy  # already validated by `choices` above
beam_size = args.beam_size
temperature = args.temperature
top_k = args.top_k
ckpt_path = args.ckpt_path
tokenizer_path = args.tokenizer_path
start = args.start
repetition_penalty = args.repetition_penalty
fasta = args.fasta
|
|
|
|
| |
| seed = random.randint(1,6666) |
| device = 'cuda' |
| dtype = 'float32' |
| |
| compile = False |
|
|
| |
| |
| tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) |
| |
|
|
| torch.manual_seed(seed) |
| torch.cuda.manual_seed(seed) |
| torch.backends.cuda.matmul.allow_tf32 = True |
| torch.backends.cudnn.allow_tf32 = True |
| device_type = 'cuda' if 'cuda' in device else 'cpu' |
| ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] |
| ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype) |
|
|
| |
# ---------------------------------------------------------------------------
# Model loading.
# ---------------------------------------------------------------------------
if init_from == 'resume':
    # Restore a model trained by this project from a checkpoint file.
    checkpoint = torch.load(ckpt_path, map_location=device)
    gptconf = GPTConfig(**checkpoint['model_args'])
    model = GPT(gptconf)
    state_dict = checkpoint['model']
    # torch.compile() wraps the model and prefixes every parameter name with
    # '_orig_mod.'; strip the prefix so the weights load into a plain model.
    unwanted_prefix = '_orig_mod.'
    for k in list(state_dict.keys()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
elif init_from.startswith('gpt2'):
    # Initialize from a pretrained OpenAI GPT-2 checkpoint (e.g. 'gpt2-xl').
    model = GPT.from_pretrained(init_from, dict(dropout=0.0))
else:
    # BUG FIX: previously an unrecognized value fell through, leaving `model`
    # undefined and crashing below with a confusing NameError.
    raise ValueError(f"unsupported --init_from value: {init_from!r}")


model.eval()
model.to(device)
if compile:
    model = torch.compile(model)  # requires PyTorch >= 2.0
|
|
| |
# Shorthand aliases for the tokenizer's round-trip functions.
load_meta = False
encode = tokenizer.encode
decode = tokenizer.decode


# Derive the FASTA output path from out_path by swapping its extension.
if fasta:
    fasta_out_path = os.path.splitext(out_path)[0] + ".fasta"
else:
    fasta_out_path = None
|
|
if strategy in ["sampling", "top_k"]:
    # Build the prompt tensor once up front; every sample starts from it.
    start_ids = encode("".join(start))
    x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]

    # Optional second output stream in FASTA format.
    fasta_ctx = open(fasta_out_path, 'a') if fasta else nullcontext()
    with open(out_path, 'a') as f, fasta_ctx as fasta_f:
        with torch.no_grad(), ctx:
            for k in tqdm(range(num_samples), desc="Generating samples"):
                token_sequence = model.generate(
                    x,
                    max_new_tokens,
                    strategy=strategy,
                    temperature=temperature,
                    top_k=top_k,
                    repetition_penalty=repetition_penalty,
                )[0].tolist()

                # Optionally randomize token order before decoding.
                # NOTE(review): this shuffles the prompt token(s) into the
                # output as well — confirm that is intended.
                if args.shuffle_token:
                    random.shuffle(token_sequence)

                y = decode(token_sequence).replace(' ', '')
                f.write(y)
                f.flush()  # flush per sample so long runs can be inspected live

                if fasta:
                    fasta_entry = f">sample_{k}\n{y.replace(' ', '')}\n"
                    fasta_f.write(fasta_entry.strip() + '\n')
                    fasta_f.flush()
|
|
|
|
| elif strategy in ["beam_search", "greedy_search"]: |
| with open(out_path, 'a') as f: |
| with open(fasta_out_path, 'a') if fasta else nullcontext() as fasta_f: |
| with torch.no_grad(): |
| with ctx: |
| start = '<|endoftext|>' |
| start_ids = encode(start) |
| x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]) |
|
|
| token_sequence = model.generate(x, max_new_tokens, strategy=strategy, temperature=temperature, top_k=top_k, repetition_penalty=repetition_penalty, beam_size=beam_size)[0].tolist() |
|
|
| y = decode(token_sequence).replace(' ', '') |
| f.write(y) |
| f.flush() |
|
|
|
|
| if fasta: |
| fasta_entry = f">sample_{k}\n{y.replace(' ', '')}\n" |
| fasta_f.write(fasta_entry.strip() + '\n') |
| fasta_f.flush() |