| import numpy as np |
| import random |
| import string |
| from torch.utils.data import Dataset, Subset |
|
|
class DummyData(Dataset):
    """Synthetic dataset producing a random array and a random string.

    Each sample is generated on access and is not tied to the index, so
    repeated reads of the same index return different data unless the
    ``numpy`` and ``random`` generators are re-seeded.

    Args:
        length: Number of samples reported by ``len()``.
        size: Iterable of ints; unpacked into ``np.random.randn`` to give
            the shape of the ``"jpg"`` array.
    """

    def __init__(self, length, size):
        self.length = length
        self.size = size

    def __len__(self):
        """Return the configured number of samples."""
        return self.length

    def __getitem__(self, i):
        """Return one freshly generated sample.

        Args:
            i: Sample index. Integer indices must lie in
                ``[-length, length)``; the value is otherwise unused.

        Returns:
            A dict with ``"jpg"`` (standard-normal ``ndarray`` of shape
            ``size``) and ``"txt"`` (10 random lowercase ASCII letters).

        Raises:
            IndexError: If an integer index is out of range.
        """
        # Dataset has no __iter__, so ``for s in ds`` / ``list(ds)`` fall back
        # to the __getitem__ protocol, which only stops on IndexError.
        # Without this check plain iteration never terminates.
        if isinstance(i, (int, np.integer)) and not -self.length <= i < self.length:
            raise IndexError(
                f"index {i} is out of range for dataset of length {self.length}"
            )
        x = np.random.randn(*self.size)
        # random.choice (not random.choices) keeps the seeded output sequence
        # identical to the previous implementation.
        y = "".join(random.choice(string.ascii_lowercase) for _ in range(10))
        return {"jpg": x, "txt": y}
|
|
|
|
class DummyDataWithEmbeddings(Dataset):
    """Synthetic dataset producing a random array and a random embedding.

    Each sample is generated on access and is not tied to the index, so
    repeated reads of the same index return different data unless the
    ``numpy`` generator is re-seeded.

    Args:
        length: Number of samples reported by ``len()``.
        size: Iterable of ints; unpacked into ``np.random.randn`` to give
            the shape of the ``"jpg"`` array.
        emb_size: Iterable of ints; unpacked into ``np.random.randn`` to
            give the shape of the ``"txt"`` array.
    """

    def __init__(self, length, size, emb_size):
        self.length = length
        self.size = size
        self.emb_size = emb_size

    def __len__(self):
        """Return the configured number of samples."""
        return self.length

    def __getitem__(self, i):
        """Return one freshly generated sample.

        Args:
            i: Sample index. Integer indices must lie in
                ``[-length, length)``; the value is otherwise unused.

        Returns:
            A dict with ``"jpg"`` (standard-normal ``ndarray`` of shape
            ``size``, NumPy's default float64) and ``"txt"``
            (standard-normal ``ndarray`` of shape ``emb_size`` cast to
            float32).

        Raises:
            IndexError: If an integer index is out of range.
        """
        # Dataset has no __iter__, so ``for s in ds`` / ``list(ds)`` fall back
        # to the __getitem__ protocol, which only stops on IndexError.
        # Without this check plain iteration never terminates.
        if isinstance(i, (int, np.integer)) and not -self.length <= i < self.length:
            raise IndexError(
                f"index {i} is out of range for dataset of length {self.length}"
            )
        # Draw order (x first, then y) is kept so seeded output is unchanged.
        x = np.random.randn(*self.size)
        y = np.random.randn(*self.emb_size).astype(np.float32)
        return {"jpg": x, "txt": y}
|
|
|
|