def forward(self, input):
    """Encode ``input``, look its tokens up in the codebook, and decode them.

    The input is passed through ``self.encode``; ``img_toks`` (the third
    element of the third value that call returns) is handed to
    ``self.quantize.get_codebook_entry`` together with a shape built from
    the batch size of ``input`` and the last three axes of the quantized
    latent. The resulting entries are passed to ``self.decode``.

    Returns:
        A 3-tuple ``(pixels, img_toks, quant)``: the output of
        ``self.decode``, the tokens from the encoder, and the quantized
        latent returned by ``self.encode``.
    """
    # The second return value of encode() is not used by this method.
    quant, _diff, (_, _, img_toks) = self.encode(input)

    batch_size = input.shape[0]
    latent_shape = quant.shape
    # NOTE(review): "height" comes from the last axis and "width" from the
    # second-to-last. If quant is laid out (N, C, H, W) these names are
    # swapped -- harmless for square latents, but confirm against the shape
    # order get_codebook_entry expects before relying on non-square inputs.
    height = latent_shape[-1]
    width = latent_shape[-2]
    n_channel = latent_shape[-3]

    entry_shape = (batch_size, n_channel, height, width)
    codebook_entry = self.quantize.get_codebook_entry(img_toks, entry_shape)
    return self.decode(codebook_entry), img_toks, quant