File size: 427 Bytes
14ce5a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def forward(self, input):
    """Reconstruct ``input`` by tokenizing it and decoding the tokens back.

    ``self.encode(input)`` must return a 3-tuple ``(quant, diff, info)``.
    ``info`` must itself unpack into exactly three items, the last of which
    is the token tensor. Those tokens are mapped back to codebook vectors via
    ``self.quantize.get_codebook_entry`` and then run through ``self.decode``.

    Args:
        input: the batch to reconstruct. Here only ``input.shape[0]`` is read
            directly (used as the batch size); the whole object is forwarded
            to ``self.encode``.

    Returns:
        A tuple ``(pixels, img_toks, quant)``:

        - the output of ``self.decode``;
        - the token tensor produced by the encoder;
        - the ``quant`` value produced by the encoder.

        The encoder's ``diff`` output is discarded.
    """
    quant, _diff, (_, _, token_ids) = self.encode(input)

    batch = input.shape[0]
    latent_dims = quant.shape
    # NOTE(review): "height" comes from the LAST axis of quant and "width"
    # from the second-to-last. Under the common (B, C, H, W) layout those
    # names are swapped, so the shape below ends up ordered (B, C, W, H).
    # This only matters for non-square latents. Confirm the axis order that
    # get_codebook_entry expects before relying on non-square inputs.
    height = latent_dims[-1]
    width = latent_dims[-2]
    n_channel = latent_dims[-3]
    lookup_shape = (batch, n_channel, height, width)

    codebook_entry = self.quantize.get_codebook_entry(token_ids, lookup_shape)
    return self.decode(codebook_entry), token_ids, quant