| |
| |
| |
| |
| |
| |
|
|
| import torch.nn as nn |
|
|
| from repcodec.layers.vq_module import ResidualVQ |
|
|
|
|
class Quantizer(nn.Module):
    """Thin wrapper that routes tensors through a ``ResidualVQ`` codebook.

    Every method that takes a tensor swaps dimensions 1 and 2 before handing
    it to the codebook, so inputs must have at least three dimensions.
    Presumably callers pass channel-first tensors and the codebook expects
    channel-last ones -- TODO confirm against ``ResidualVQ``.
    """

    def __init__(
        self,
        code_dim: int,
        codebook_num: int,
        codebook_size: int,
    ):
        """Create the underlying residual vector quantizer.

        Args:
            code_dim: Forwarded to ``ResidualVQ`` as ``dim``.
            codebook_num: Forwarded to ``ResidualVQ`` as ``num_quantizers``.
            codebook_size: Forwarded to ``ResidualVQ`` as ``codebook_size``.
        """
        super().__init__()
        vq_kwargs = {
            "dim": code_dim,
            "num_quantizers": codebook_num,
            "codebook_size": codebook_size,
        }
        self.codebook = ResidualVQ(**vq_kwargs)

    def initial(self):
        """Delegate to the codebook's own ``initial()`` routine."""
        self.codebook.initial()

    def forward(self, z):
        """Quantize ``z`` and return the quantizer's training statistics.

        Args:
            z: Tensor whose dimensions 1 and 2 are swapped before it is
                passed to the codebook.

        Returns:
            A 3-tuple ``(zq, vqloss, perplexity)``: the first codebook
            output with dimensions 1 and 2 swapped back, followed by the
            second and third codebook outputs unchanged.
        """
        swapped = z.transpose(2, 1)
        quantized, vq_loss, ppl = self.codebook(swapped)
        # Restore the caller's dimension order.
        return quantized.transpose(2, 1), vq_loss, ppl

    def inference(self, z):
        """Quantize ``z`` via ``forward_index`` and return codes too.

        Args:
            z: Tensor whose dimensions 1 and 2 are swapped before it is
                passed to the codebook.

        Returns:
            A 2-tuple ``(zq, indices)``: the quantized tensor with
            dimensions 1 and 2 swapped back, and the indices as returned
            by ``forward_index``.
        """
        swapped = z.transpose(2, 1)
        quantized, codes = self.codebook.forward_index(swapped)
        # Restore the caller's dimension order.
        return quantized.transpose(2, 1), codes

    def encode(self, z):
        """Quantize ``z`` via ``forward_index`` with ``flatten_idx=True``.

        Args:
            z: Tensor whose dimensions 1 and 2 are swapped before it is
                passed to the codebook.

        Returns:
            A 2-tuple ``(zq, indices)`` exactly as returned by
            ``forward_index``.
        """
        swapped = z.transpose(2, 1)
        # NOTE(review): unlike ``inference``, the quantized tensor is NOT
        # transposed back here; it is returned in the codebook's own layout.
        # Confirm this is intended before relying on its dimension order.
        return self.codebook.forward_index(swapped, flatten_idx=True)

    def decode(self, indices):
        """Return the codebook's ``lookup`` result for ``indices``."""
        return self.codebook.lookup(indices)
|
|