| import argparse |
| import logging |
| import os |
| import pathlib |
| from typing import List, NoReturn |
| import lightning.pytorch as pl |
| from lightning.pytorch.strategies import DDPStrategy |
| from torch.utils.tensorboard import SummaryWriter |
| from data.datamodules import * |
| from utils import create_logging, parse_yaml |
| from models.resunet import * |
| from losses import get_loss_function |
| from models.audiosep import AudioSep, get_model_class |
| from data.waveform_mixers import SegmentMixer |
| from models.clap_encoder import CLAP_Encoder |
| from callbacks.base import CheckpointEveryNSteps |
| from optimizers.lr_schedulers import get_lr_lambda |
|
|
|
|
| def get_dirs( |
| workspace: str, |
| filename: str, |
| config_yaml: str, |
| devices_num: int |
| ) -> List[str]: |
| r"""Get directories and paths. |
| |
| Args: |
| workspace (str): directory of workspace |
| filename (str): filename of current .py file. |
| config_yaml (str): config yaml path |
| devices_num (int): 0 for cpu and 8 for training with 8 GPUs |
| |
| Returns: |
| checkpoints_dir (str): directory to save checkpoints |
| logs_dir (str), directory to save logs |
| tf_logs_dir (str), directory to save TensorBoard logs |
| statistics_path (str), directory to save statistics |
| """ |
| |
| os.makedirs(workspace, exist_ok=True) |
|
|
| yaml_name = pathlib.Path(config_yaml).stem |
|
|
| |
| checkpoints_dir = os.path.join( |
| workspace, |
| "checkpoints", |
| filename, |
| "{},devices={}".format(yaml_name, devices_num), |
| ) |
| os.makedirs(checkpoints_dir, exist_ok=True) |
|
|
| |
| logs_dir = os.path.join( |
| workspace, |
| "logs", |
| filename, |
| "{},devices={}".format(yaml_name, devices_num), |
| ) |
| os.makedirs(logs_dir, exist_ok=True) |
|
|
| |
| create_logging(logs_dir, filemode="w") |
| logging.info(args) |
|
|
| tf_logs_dir = os.path.join( |
| workspace, |
| "tf_logs", |
| filename, |
| "{},devices={}".format(yaml_name, devices_num), |
| ) |
|
|
| |
| statistics_path = os.path.join( |
| workspace, |
| "statistics", |
| filename, |
| "{},devices={}".format(yaml_name, devices_num), |
| "statistics.pkl", |
| ) |
| os.makedirs(os.path.dirname(statistics_path), exist_ok=True) |
|
|
| return checkpoints_dir, logs_dir, tf_logs_dir, statistics_path |
|
|
| |
| def get_data_module( |
| config_yaml: str, |
| num_workers: int, |
| batch_size: int, |
| ) -> DataModule: |
| r"""Create data_module. Mini-batch data can be obtained by: |
| |
| code-block:: python |
| |
| data_module.setup() |
| |
| for batch_data_dict in data_module.train_dataloader(): |
| print(batch_data_dict.keys()) |
| break |
| |
| Args: |
| workspace: str |
| config_yaml: str |
| num_workers: int, e.g., 0 for non-parallel and 8 for using cpu cores |
| for preparing data in parallel |
| distributed: bool |
| |
| Returns: |
| data_module: DataModule |
| """ |
|
|
| |
| configs = parse_yaml(config_yaml) |
| sampling_rate = configs['data']['sampling_rate'] |
| segment_seconds = configs['data']['segment_seconds'] |
| |
| |
| datafiles = configs['data']['datafiles'] |
| |
| |
| dataset = AudioTextDataset( |
| datafiles=datafiles, |
| sampling_rate=sampling_rate, |
| max_clip_len=segment_seconds, |
| ) |
| |
| |
| |
| data_module = DataModule( |
| train_dataset=dataset, |
| num_workers=num_workers, |
| batch_size=batch_size |
| ) |
|
|
| return data_module |
|
|
|
|
| def train(args) -> NoReturn: |
| r"""Train, evaluate, and save checkpoints. |
| |
| Args: |
| workspace: str, directory of workspace |
| gpus: int, number of GPUs to train |
| config_yaml: str |
| """ |
|
|
| |
| workspace = args.workspace |
| config_yaml = args.config_yaml |
| filename = args.filename |
|
|
| devices_num = torch.cuda.device_count() |
| |
| configs = parse_yaml(config_yaml) |
|
|
| |
| max_mix_num = configs['data']['max_mix_num'] |
| sampling_rate = configs['data']['sampling_rate'] |
| lower_db = configs['data']['loudness_norm']['lower_db'] |
| higher_db = configs['data']['loudness_norm']['higher_db'] |
|
|
| |
| query_net = configs['model']['query_net'] |
| model_type = configs['model']['model_type'] |
| input_channels = configs['model']['input_channels'] |
| output_channels = configs['model']['output_channels'] |
| condition_size = configs['model']['condition_size'] |
| use_text_ratio = configs['model']['use_text_ratio'] |
| |
| |
| num_nodes = configs['train']['num_nodes'] |
| batch_size = configs['train']['batch_size_per_device'] |
| sync_batchnorm = configs['train']['sync_batchnorm'] |
| num_workers = configs['train']['num_workers'] |
| loss_type = configs['train']['loss_type'] |
| optimizer_type = configs["train"]["optimizer"]["optimizer_type"] |
| learning_rate = float(configs['train']["optimizer"]['learning_rate']) |
| lr_lambda_type = configs['train']["optimizer"]['lr_lambda_type'] |
| warm_up_steps = configs['train']["optimizer"]['warm_up_steps'] |
| reduce_lr_steps = configs['train']["optimizer"]['reduce_lr_steps'] |
| save_step_frequency = configs['train']['save_step_frequency'] |
| resume_checkpoint_path = args.resume_checkpoint_path |
| if resume_checkpoint_path == "": |
| resume_checkpoint_path = None |
| else: |
| logging.info(f'Finetuning AudioSep with checkpoint [{resume_checkpoint_path}]') |
|
|
| |
| checkpoints_dir, logs_dir, tf_logs_dir, statistics_path = get_dirs( |
| workspace, filename, config_yaml, devices_num, |
| ) |
|
|
| logging.info(configs) |
|
|
| |
| data_module = get_data_module( |
| config_yaml=config_yaml, |
| batch_size=batch_size, |
| num_workers=num_workers, |
| ) |
| |
| |
| Model = get_model_class(model_type=model_type) |
|
|
| ss_model = Model( |
| input_channels=input_channels, |
| output_channels=output_channels, |
| condition_size=condition_size, |
| ) |
|
|
| |
| loss_function = get_loss_function(loss_type) |
|
|
| segment_mixer = SegmentMixer( |
| max_mix_num=max_mix_num, |
| lower_db=lower_db, |
| higher_db=higher_db |
| ) |
|
|
| |
| if query_net == 'CLAP': |
| query_encoder = CLAP_Encoder() |
| else: |
| raise NotImplementedError |
|
|
| lr_lambda_func = get_lr_lambda( |
| lr_lambda_type=lr_lambda_type, |
| warm_up_steps=warm_up_steps, |
| reduce_lr_steps=reduce_lr_steps, |
| ) |
|
|
| |
| pl_model = AudioSep( |
| ss_model=ss_model, |
| waveform_mixer=segment_mixer, |
| query_encoder=query_encoder, |
| loss_function=loss_function, |
| optimizer_type=optimizer_type, |
| learning_rate=learning_rate, |
| lr_lambda_func=lr_lambda_func, |
| use_text_ratio=use_text_ratio |
| ) |
|
|
| checkpoint_every_n_steps = CheckpointEveryNSteps( |
| checkpoints_dir=checkpoints_dir, |
| save_step_frequency=save_step_frequency, |
| ) |
|
|
| summary_writer = SummaryWriter(log_dir=tf_logs_dir) |
|
|
| callbacks = [checkpoint_every_n_steps] |
|
|
| trainer = pl.Trainer( |
| accelerator='auto', |
| devices='auto', |
| strategy='ddp_find_unused_parameters_true', |
| num_nodes=num_nodes, |
| precision="32-true", |
| logger=None, |
| callbacks=callbacks, |
| fast_dev_run=False, |
| max_epochs=-1, |
| log_every_n_steps=50, |
| use_distributed_sampler=True, |
| sync_batchnorm=sync_batchnorm, |
| num_sanity_val_steps=2, |
| enable_checkpointing=False, |
| enable_progress_bar=True, |
| enable_model_summary=True, |
| ) |
|
|
| |
| trainer.fit( |
| model=pl_model, |
| train_dataloaders=None, |
| val_dataloaders=None, |
| datamodule=data_module, |
| ckpt_path=resume_checkpoint_path, |
| ) |
|
|
|
|
| if __name__ == "__main__": |
|
|
| parser = argparse.ArgumentParser() |
| parser.add_argument( |
| "--workspace", type=str, required=True, help="Directory of workspace." |
| ) |
| parser.add_argument( |
| "--config_yaml", |
| type=str, |
| required=True, |
| help="Path of config file for training.", |
| ) |
|
|
| parser.add_argument( |
| "--resume_checkpoint_path", |
| type=str, |
| required=True, |
| default='', |
| help="Path of pretrained checkpoint for finetuning.", |
| ) |
|
|
| args = parser.parse_args() |
| args.filename = pathlib.Path(__file__).stem |
|
|
| train(args) |