| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import numpy as np |
| import torch |
|
|
|
|
def dice(x, y):
    """Compute the Dice similarity coefficient between two binary masks.

    Args:
        x: predicted binary mask (array-like of 0/1 values).
        y: ground-truth binary mask, broadcastable against ``x``.

    Returns:
        ``2 * |x * y| / (sum(x) + sum(y))`` as a float, or ``0.0`` when
        ``y`` is empty (all zeros), avoiding a potential division by zero.
    """
    # np.sum already reduces over every axis; nesting three calls is redundant.
    y_sum = np.sum(y)
    if y_sum == 0:
        # Empty ground truth is scored 0.0 (matches the original behavior).
        return 0.0
    intersect = np.sum(x * y)
    x_sum = np.sum(x)
    return 2 * intersect / (x_sum + y_sum)
|
|
|
|
class AverageMeter(object):
    """Tracks the latest value and running (weighted) average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0    # most recent value passed to update()
        self.avg = 0    # running weighted average of all values seen
        self.sum = 0    # weighted sum of values
        self.count = 0  # total weight (typically the number of samples)

    def update(self, val, n=1):
        """Record a new observation.

        Args:
            val: the new value (e.g. a batch-averaged metric).
            n: weight of the observation, typically the batch size.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard against count == 0 (e.g. update() called with n=0) with a
        # plain conditional instead of np.where, which returned a 0-d
        # ndarray even for scalar inputs.
        self.avg = self.sum / self.count if self.count > 0 else self.sum
|
|
|
|
def distributed_all_gather(
    tensor_list, valid_batch_size=None, out_numpy=False, world_size=None, no_barrier=False, is_valid=None
):
    """All-gather each tensor in ``tensor_list`` across distributed ranks.

    Args:
        tensor_list: list of tensors to gather; ``all_gather`` requires each
            tensor to have the same shape on every rank.
        valid_batch_size: if given, keep only the first
            ``min(valid_batch_size, world_size)`` gathered entries per tensor.
        out_numpy: if True, gathered tensors are returned as CPU numpy arrays.
        world_size: number of participating ranks; defaults to
            ``torch.distributed.get_world_size()``.
        no_barrier: if True, skip the synchronizing barrier before gathering.
        is_valid: this rank's validity flag; gathered entries from ranks whose
            flag is falsy are dropped. Note the ``elif``: when
            ``valid_batch_size`` is also given, ``is_valid`` is NOT converted
            to a tensor here.

    Returns:
        A list with one element per input tensor; each element is the list of
        gathered tensors (or numpy arrays) surviving the selected filter.
    """
    if world_size is None:
        world_size = torch.distributed.get_world_size()
    if valid_batch_size is not None:
        # Cannot keep more entries than there are ranks.
        valid_batch_size = min(valid_batch_size, world_size)
    elif is_valid is not None:
        # Convert this rank's flag to a bool tensor so it can be all-gathered.
        is_valid = torch.tensor(bool(is_valid), dtype=torch.bool, device=tensor_list[0].device)
    if not no_barrier:
        # Synchronize ranks so no rank races ahead into the gathers.
        torch.distributed.barrier()
    tensor_list_out = []
    with torch.no_grad():
        # NOTE(review): if BOTH valid_batch_size and is_valid are passed,
        # is_valid is still a plain bool here (the elif above skipped the
        # conversion) and torch.zeros_like below would fail — presumably the
        # two options are meant to be mutually exclusive; confirm with callers.
        if is_valid is not None:
            # Gather every rank's validity flag.
            is_valid_list = [torch.zeros_like(is_valid) for _ in range(world_size)]
            torch.distributed.all_gather(is_valid_list, is_valid)
            is_valid = [x.item() for x in is_valid_list]
        for tensor in tensor_list:
            # all_gather fills one preallocated slot per rank.
            gather_list = [torch.zeros_like(tensor) for _ in range(world_size)]
            torch.distributed.all_gather(gather_list, tensor)
            if valid_batch_size is not None:
                gather_list = gather_list[:valid_batch_size]
            elif is_valid is not None:
                # Keep only entries from ranks that reported themselves valid
                # (v is a 0-d bool tensor; its truthiness is its value).
                gather_list = [g for g, v in zip(gather_list, is_valid_list) if v]
            if out_numpy:
                gather_list = [t.cpu().numpy() for t in gather_list]
            tensor_list_out.append(gather_list)
    return tensor_list_out
|
|