import timeit
import numpy as np
import os
import os.path as osp
import shutil
import copy
import torch
import torch.nn as nn
import torch.distributed as dist
from .cfg_holder import cfg_unique_holder as cfguh
from . import sync

# Print to console only on local rank 0 to avoid duplicated output under DDP.
print_console_local_rank0_only = True
def print_log(*console_info):
    """Print a space-joined message (local rank 0 only when configured) and
    append it to the train/eval log file if one is set in the config."""
    local_rank = sync.get_rank('local')
    if print_console_local_rank0_only and (local_rank != 0):
        return
    console_info = ' '.join(str(i) for i in console_info)
    print(console_info)

    # Only local rank 0 writes to the log file.
    if local_rank != 0:
        return

    log_file = None
    try:
        log_file = cfguh().cfg.train.log_file
    except Exception:
        try:
            log_file = cfguh().cfg.eval.log_file
        except Exception:
            return
    if log_file is not None:
        with open(log_file, 'a') as f:
            f.write(console_info + '\n')
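# Illustrative usage (assumes the global cfg has already been populated through
# cfg_unique_holder; the arguments below are made up):
#
#   print_log('Iter:100', 'Loss:0.2500')
#
# prints "Iter:100 Loss:0.2500" on local rank 0 and appends the same line to
# cfg.train.log_file (or cfg.eval.log_file) when that attribute is set.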
class distributed_log_manager(object):
    """Accumulates weighted running values (e.g. losses) across iterations,
    averages them over DDP ranks on demand, and optionally mirrors them to
    tensorboardX."""
    def __init__(self):
        self.sum = {}
        self.cnt = {}
        self.time_check = timeit.default_timer()

        cfgt = cfguh().cfg.train
        use_tensorboard = getattr(cfgt, 'log_tensorboard', False)

        self.ddp = sync.is_ddp()
        self.rank = sync.get_rank('local')
        self.world_size = sync.get_world_size('local')

        # Only local rank 0 owns the tensorboard writer.
        self.tb = None
        if use_tensorboard and (self.rank == 0):
            import tensorboardX
            monitoring_dir = osp.join(cfguh().cfg.train.log_dir, 'tensorboard')
            self.tb = tensorboardX.SummaryWriter(monitoring_dir)
    def accumulate(self, n, **data):
        """Add n samples' worth of each named value to the running sums."""
        if n < 0:
            raise ValueError('n must be non-negative')
        for itemn, di in data.items():
            if itemn in self.sum:
                self.sum[itemn] += di * n
                self.cnt[itemn] += n
            else:
                self.sum[itemn] = di * n
                self.cnt[itemn] = n
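    # Illustrative arithmetic: accumulate(4, Loss=0.5) followed by
    # accumulate(2, Loss=0.2) gives sum['Loss'] = 2.4 and cnt['Loss'] = 6, so
    # the reported mean is 0.4, i.e. a sample-count-weighted average per item.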
    def get_mean_value_dict(self):
        """Return the mean of each accumulated item; under DDP the means are
        additionally averaged across ranks (every rank must call this)."""
        value_gather = [
            self.sum[itemn] / self.cnt[itemn]
            for itemn in sorted(self.sum.keys())]

        # Move to the local device so all_reduce can operate on it under DDP.
        value_gather_tensor = torch.FloatTensor(value_gather).to(self.rank)
        if self.ddp:
            dist.all_reduce(value_gather_tensor, op=dist.ReduceOp.SUM)
            value_gather_tensor /= self.world_size

        mean = {}
        for idx, itemn in enumerate(sorted(self.sum.keys())):
            mean[itemn] = value_gather_tensor[idx].item()
        return mean
    def tensorboard_log(self, step, data, mode='train', **extra):
        """Write scalars to tensorboard: 'loss*' keys go under loss/, the total
        'Loss' gets its own tag, everything else goes under other/ (train mode)
        or eval/ (eval mode)."""
        if self.tb is None:
            return
        if mode == 'train':
            self.tb.add_scalar('other/epochn', extra['epochn'], step)
            if 'lr' in extra:
                self.tb.add_scalar('other/lr', extra['lr'], step)
            for itemn, di in data.items():
                if itemn.startswith('loss'):
                    self.tb.add_scalar('loss/' + itemn, di, step)
                elif itemn == 'Loss':
                    self.tb.add_scalar('Loss', di, step)
                else:
                    self.tb.add_scalar('other/' + itemn, di, step)
        elif mode == 'eval':
            if isinstance(data, dict):
                for itemn, di in data.items():
                    self.tb.add_scalar('eval/' + itemn, di, step)
            else:
                self.tb.add_scalar('eval', data, step)
    def train_summary(self, itern, epochn, samplen, lr, tbstep=None):
        """Build a one-line training summary string and push the current means
        to tensorboard."""
        console_info = [
            'Iter:{}'.format(itern),
            'Epoch:{}'.format(epochn),
            'Sample:{}'.format(samplen)]

        if lr is not None:
            console_info += ['LR:{:.4E}'.format(lr)]

        mean = self.get_mean_value_dict()

        tbstep = itern if tbstep is None else tbstep
        self.tensorboard_log(
            tbstep, mean, mode='train',
            itern=itern, epochn=epochn, lr=lr)

        loss = mean.pop('Loss')
        mean_info = ['Loss:{:.4f}'.format(loss)] + [
            '{}:{:.4f}'.format(itemn, mean[itemn])
            for itemn in sorted(mean.keys())
            if itemn.startswith('loss')]
        console_info += mean_info
        console_info.append('Time:{:.2f}s'.format(
            timeit.default_timer() - self.time_check))
        return ' , '.join(console_info)
    def clear(self):
        """Reset accumulators and the timing checkpoint."""
        self.sum = {}
        self.cnt = {}
        self.time_check = timeit.default_timer()

    def tensorboard_close(self):
        if self.tb is not None:
            self.tb.close()
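# Minimal usage sketch inside a training loop (illustrative only; assumes the
# cfg holder is populated and, under DDP, that the process group is already
# initialized; `loader`, `train_step`, `batch_size`, `epochn`, `samplen` and
# `lr` are hypothetical names). Note that train_summary performs an all_reduce,
# so every rank must call it:
#
#   logm = distributed_log_manager()
#   for itern, batch in enumerate(loader):
#       loss = train_step(batch)
#       logm.accumulate(batch_size, Loss=loss)
#       if itern % 100 == 0:
#           print_log(logm.train_summary(itern, epochn, samplen, lr))
#           logm.clear()
#   logm.tensorboard_close()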
def torch_to_numpy(*argv):
    """Recursively convert tensors, possibly nested inside lists, tuples or
    dicts, to numpy arrays; non-tensor leaves are returned unchanged."""
    if len(argv) > 1:
        data = list(argv)
    else:
        data = argv[0]

    if isinstance(data, torch.Tensor):
        return data.to('cpu').detach().numpy()
    elif isinstance(data, (list, tuple)):
        return [torch_to_numpy(di) for di in data]
    elif isinstance(data, dict):
        return {ni: torch_to_numpy(di) for ni, di in data.items()}
    else:
        return data
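# Example (illustrative): convert a nested prediction structure in one call.
#
#   out = torch_to_numpy({'pred': torch.zeros(2, 3),
#                         'meta': [torch.ones(4), 'sample_id']})
#   # out['pred'] and out['meta'][0] are numpy arrays; 'sample_id' passes through.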