File size: 427 Bytes
14ce5a9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 |
def forward(self, input):
    """Round-trip ``input`` through the quantized latent space and decode it.

    ``self.encode(input)`` is unpacked as ``(quant, diff, info)``, where
    ``info`` is a 3-item sequence whose last item is taken as the token
    indices ``img_toks``. Those indices are turned back into codebook vectors
    with ``self.quantize.get_codebook_entry`` and fed to ``self.decode``.

    Args:
        input: Passed unchanged to ``self.encode``; apart from that only
            ``input.shape[0]`` is read (used as the batch size).

    Returns:
        Tuple ``(pixels, img_toks, quant)``: the result of ``self.decode``,
        the token indices from ``self.encode``, and the encoder's ``quant``.
    """
    # The second encoder output (``diff``) is not used by this method.
    quant, _, (_, _, img_toks) = self.encode(input)

    # Trailing dims of ``quant``, named by position. Presumably quant is
    # (batch, channels, H, W) so these are C, H, W -- TODO confirm.
    dim_c = quant.shape[-3]
    dim_h = quant.shape[-2]
    dim_w = quant.shape[-1]

    # NOTE(review): the shape handed to get_codebook_entry is
    # (batch, shape[-3], shape[-1], shape[-2]); the last two entries are
    # swapped relative to quant's own layout. If self.quantize follows the
    # taming-transformers / diffusers VectorQuantizer convention, it expects
    # (batch, height, width, channel) instead. Confirm against the quantizer
    # before relying on this for non-square latents.
    lookup_shape = (input.shape[0], dim_c, dim_w, dim_h)
    codebook_entry = self.quantize.get_codebook_entry(img_toks, lookup_shape)
    return self.decode(codebook_entry), img_toks, quant
|