import gc
import glob
import os

import torch
import torch.nn as nn
import torchvision
from tqdm import tqdm, trange

from builtin_architecture import make_model, make_model_custom
from dataset import fromDataset, get_dataloader, TextCorpusDataset
from logger import init_logger, flush
from trainingmanager import TrainingManager
|
|
|
|
def train_model(
    experiment_directory,
    trainset,
    testset,
    epochs,
    additional_epochs,
    model_params=None,
    schedule=False,
    **kwargs,
):
    """Train a model on *trainset*, validating on *testset*.

    Args:
        experiment_directory: Output directory for checkpoints and
            tensorboard logs.
        trainset / testset: Datasets wrapped via ``get_dataloader``.
        epochs: Number of epochs for the main training run.
        additional_epochs: If > 0, continue training for this many extra
            epochs with a second ``TrainingManager`` (plain ``train()``,
            no curriculum) after the main run finishes.
        model_params: kwargs for ``make_model_custom``; when empty/None,
            the default ``make_model()`` architecture is used.
        schedule: When True, run curriculum training
            (``train_curriculum(**kwargs)``); otherwise plain ``train()``.
        **kwargs: Forwarded to ``train_curriculum`` when ``schedule`` is set.
    """
    # macOS only: keep the machine awake for as long as this process lives.
    # NOTE(review): shells out; harmless no-op failure on non-macOS hosts.
    os.system(f"caffeinate -is -w {os.getpid()} &")

    # Prefer Apple-silicon GPU when available.
    device = "mps" if torch.backends.mps.is_available() else "cpu"

    dataloader = get_dataloader(trainset)
    testloader = get_dataloader(testset)

    # `None` and `{}` both mean "use the default architecture"; a truthiness
    # test covers both (the original compared `model_params == {}` after a
    # separate None-normalization step).
    if model_params:
        net = make_model_custom(**model_params)
    else:
        net = make_model()
    net.to(device)

    trainer = TrainingManager(
        net=net,
        dir=experiment_directory,
        dataloader=dataloader,
        device=device,
        trainstep_checkin_interval=100,
        epochs=epochs,
        val_dataloader=testloader,
    )

    # Pull exactly one batch before creating the logger — presumably to force
    # any lazy dataset/dataloader setup first (TODO confirm). If the loader is
    # empty, the logger is never initialized (matches original behavior).
    for _batch, _attn_mask in dataloader:
        init_logger(
            net,
            dir=os.path.join(experiment_directory, "tensorboard"),
        )
        break

    if schedule:
        trainer.train_curriculum(**kwargs)
    else:
        trainer.train()

    if additional_epochs > 0:
        # Second manager with a larger epoch target; presumably resumes from
        # the checkpoints written to the same directory — TODO confirm.
        print(f"Running additional {additional_epochs} epochs")
        additional_trainer = TrainingManager(
            net=net,
            dir=experiment_directory,
            dataloader=dataloader,
            device=device,
            trainstep_checkin_interval=100,
            epochs=epochs + additional_epochs,
            val_dataloader=testloader,
        )
        additional_trainer.train()

    flush()

    # Project-local cleanup hook.
    os.system("bash safe_cleanup.sh")
|
|
|
|
def run_experiment(experiment_directory, epochs, additional_epochs, trainset, testset, del_runs, **kwargs):
    """Run one curriculum-training experiment, optionally pruning checkpoints.

    Args:
        experiment_directory: Output directory for this experiment.
        epochs / additional_epochs: Forwarded to ``train_model``.
        trainset / testset: Forwarded to ``train_model``.
        del_runs: When True, delete every ``*.pt`` checkpoint under
            ``<experiment_directory>/ckpt`` after training completes.
        **kwargs: Curriculum options forwarded to ``train_model``.
    """
    train_model(experiment_directory, trainset, testset, epochs, additional_epochs, schedule=True, **kwargs)
    if del_runs:
        # Remove checkpoint files directly instead of shelling out with
        # `os.system(f"rm -r ...")` — the shell string broke on paths with
        # spaces/metacharacters and silently ignored failures.
        for ckpt_path in glob.glob(os.path.join(experiment_directory, "ckpt", "*.pt")):
            os.remove(ckpt_path)
|
|
|
|
if __name__ == "__main__":
    # Safety gate: flip to True by hand to enable checkpoint deletion.
    # With False the confirmation branch below is intentionally skipped.
    del_runs = False
    if del_runs:
        # Require an explicit "y" before anything destructive runs.
        del_runs = (
            del_runs and input("Confirm that this will delete checkpoints: ") == "y"
        )
        if not del_runs:
            print("Exiting")
            exit()

    parent_directory = "runs/code-decoder-v31-mega-licensed-1"

    Curriculum = TrainingManager.get_curriculum_enum()

    # (experiment name, curriculum-scheduler kwargs) — every scheduling
    # strategy is run both with and without loss-based ordering.
    experiments = [
        (
            "curriculum-noloss",
            {"curriculum_type": Curriculum.CURRICULUM, "loss_based": False},
        ),
        (
            "curriculum-loss",
            {"curriculum_type": Curriculum.CURRICULUM, "loss_based": True},
        ),
        ("noop", {"curriculum_type": Curriculum.NOOP, "loss_based": False}),
        (
            "anticurriculum",
            {"curriculum_type": Curriculum.ANTICURRICULUM, "loss_based": False},
        ),
        (
            "anticurriculum-loss",
            {"curriculum_type": Curriculum.ANTICURRICULUM, "loss_based": True},
        ),
        ("hybrid", {"curriculum_type": Curriculum.HYBRID, "loss_based": False}),
        ("hybrid-loss", {"curriculum_type": Curriculum.HYBRID, "loss_based": True}),
        ("sequential", {"curriculum_type": Curriculum.SEQUENTIAL, "loss_based": False}),
        (
            "sequential-loss",
            {"curriculum_type": Curriculum.SEQUENTIAL, "loss_based": True},
        ),
    ]

    EPOCHS = 10
    ADDITIONAL_EPOCHS = 20
    trainset, testset = fromDataset(
        TextCorpusDataset(
            root_dir=os.path.expanduser(
                "~/torch_datasets/github-python/mega_licensed_corpus"
            ),
            vocab_size=33819,
            IS_CODE=True,
            IS_CUSTOM=True,
            max_length=256,
            sliding_window=False,
            stride=10,
            get_rarity_score=True,
            get_entropy_score=False,
        )
    )

    for experiment_name, params in experiments:
        experiment_directory = os.path.join(parent_directory, experiment_name)

        print(f"Running experiment: {experiment_name}")
        print(f"Params: {params}")

        run_experiment(
            experiment_directory,
            EPOCHS,
            ADDITIONAL_EPOCHS,
            trainset,
            testset,
            del_runs,
            **params,
        )

        # Best-effort memory cleanup between experiments.
        # NOTE(review): the original also looped over gc.get_objects() and
        # `del obj` behind a bare `except:` — that only unbinds the loop
        # variable and frees nothing, so it was dropped. collect() plus the
        # MPS cache flush is the effective part (the private
        # torch._C._mps_emptyCache() call was redundant with the public API).
        gc.collect()
        if torch.backends.mps.is_available():
            torch.mps.empty_cache()
|
|
|
|