VTBench / src /vqvaes /janus_pro /janus_pro.py
huaweilin's picture
update
14ce5a9
raw
history blame contribute delete
427 Bytes
def forward(self, input):
    """Encode ``input``, rebuild the latents from their token ids, and decode.

    The quantized latents are not decoded directly. Instead, the token ids
    returned by ``self.encode`` are looked up again through
    ``self.quantize.get_codebook_entry`` and that result is decoded.

    Args:
        input: Batch to reconstruct. Only ``input.shape[0]`` (the batch
            size) is read here; the value itself is passed to
            ``self.encode``.

    Returns:
        A tuple ``(pixels, img_toks, quant)``: the output of
        ``self.decode``, the token ids taken from the third element
        returned by ``self.encode``, and the quantized latents returned
        by ``self.encode``.
    """
    # The second value returned by encode is not needed for reconstruction.
    quant, _, (_, _, img_toks) = self.encode(input)

    batch_size = input.shape[0]
    # NOTE(review): assumes `quant` is channel-first (..., C, H, W) -- confirm
    # against `encode`. Height is the second-to-last axis and width the last;
    # reading them the other way round transposes the latent grid whenever
    # the latents are not square.
    n_channel, height, width = quant.shape[-3:]

    codebook_entry = self.quantize.get_codebook_entry(
        img_toks, (batch_size, n_channel, height, width)
    )
    pixels = self.decode(codebook_entry)
    return pixels, img_toks, quant