| import numpy as np |
| import torch |
| import copy |
| from collections import Counter |
| from torch.utils.data import DataLoader |
|
|
| def random_update(datasets, buffer): |
|
|
| images = np.array(datasets.images + buffer.images) |
| labels = np.array(datasets.labels + buffer.labels) |
| perm = np.random.permutation(len(labels)) |
|
|
| images, labels = images[perm[:buffer.buffer_size]], labels[perm[:buffer.buffer_size]] |
|
|
| buffer.images = images.tolist() |
| buffer.labels = labels.tolist() |
|
|
| def herding_update(datasets, buffer, feature_extractor, device): |
|
|
| print("Using Herding Update Strategy") |
|
|
| per_classes = buffer.buffer_size // buffer.total_classes |
|
|
| selected_images, selected_labels = [], [] |
| images = np.array(datasets.images + buffer.images) |
| labels = np.array(datasets.labels + buffer.labels) |
|
|
| for cls in range(buffer.total_classes): |
| cls_images_idx = np.where(labels == cls) |
| cls_images, cls_labels = images[cls_images_idx], labels[cls_images_idx] |
|
|
| cls_selected_images, cls_selected_labels = construct_examplar(copy.copy(datasets), cls_images, cls_labels, feature_extractor, per_classes, device) |
| selected_images.extend(cls_selected_images) |
| selected_labels.extend(cls_selected_labels) |
|
|
| label_counter = Counter(buffer.labels) |
| print("\nBuffer composition per class:") |
| for cls in sorted(label_counter.keys()): |
| print(f" Class {cls:3d} : {label_counter[cls]} samples") |
|
|
| buffer.images, buffer.labels = selected_images, selected_labels |
|
|
| def construct_examplar(datasets, images, labels, feature_extractor, per_classes, device): |
| if len(images) <= per_classes: |
| print(labels[0], len(images), per_classes) |
| return images, labels |
| |
| datasets.images, datasets.labels = images, labels |
| dataloader = DataLoader(datasets, shuffle = False, batch_size = 32, drop_last = False) |
|
|
| with torch.no_grad(): |
| features = [] |
| for data in dataloader: |
| imgs = data['image'].to(device) |
| features.append(feature_extractor(imgs)['features'].cpu().numpy().tolist()) |
|
|
| features = np.concatenate(features) |
| selected_images, selected_labels = [], [] |
| selected_features = [] |
| class_mean = np.mean(features, axis=0) |
|
|
| for k in range(1, per_classes+1): |
| if len(selected_features) == 0: |
| S = np.zeros_like(features[0]) |
| else: |
| S = np.mean(np.array(selected_features), axis=0) |
|
|
|
|
| mu_p = (S + features) / k |
| i = np.argmin(np.sqrt(np.sum((class_mean - mu_p) ** 2, axis=1))) |
|
|
| selected_images.append(images[i]) |
| selected_labels.append(labels[i]) |
| selected_features.append(features[i]) |
|
|
| features = np.delete(features, i, axis=0) |
| images = np.delete(images, i) |
| labels = np.delete(labels, i) |
|
|
| return selected_images, selected_labels |
|
|
|
|
|
|
| |
| |
|
|