Spaces:
Runtime error
Runtime error
File size: 4,828 Bytes
5dae26f 222b8b3 5dae26f e38f582 01ebc72 5dae26f 0fdb67e 5dae26f 7403d98 5dae26f 1365247 7403d98 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import os
import PIL.Image
import torch
from huggingface_hub import login
from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
import jax
import jax.numpy as jnp
import numpy as np
import functools
import spaces
# Authenticate against the Hugging Face Hub at import time, using the token
# supplied through the HF_TOKEN environment variable.
# NOTE(review): the lookup yields None when HF_TOKEN is unset -- confirm how
# login() is expected to behave in that case for this deployment.
hf_token = os.environ.get("HF_TOKEN")
login(
    token=hf_token,
    add_to_git_credential=True,
)
class PaliGemmaModel:
    """Wrapper around a PaliGemma checkpoint and its processor.

    On construction the checkpoint named by ``model_id`` is loaded, switched
    to eval mode and moved to CUDA when available (CPU otherwise).
    """

    def __init__(self):
        self.model_id = "google/paligemma-3b-mix-448"
        target = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(target)
        loaded = PaliGemmaForConditionalGeneration.from_pretrained(self.model_id)
        self.model = loaded.eval().to(self.device)
        self.processor = PaliGemmaProcessor.from_pretrained(self.model_id)

    # NOTE(review): presumably requests a GPU for the duration of the call on
    # Hugging Face Spaces -- confirm against the `spaces` package.
    @spaces.GPU
    def infer(self, image: PIL.Image.Image, text: str, max_new_tokens: int) -> str:
        """Generate text for ``image`` conditioned on the prompt ``text``.

        Args:
            image: Input image handed to the processor.
            text: Prompt handed to the processor.
            max_new_tokens: Upper bound on generated tokens, forwarded to
                ``generate``.

        Returns:
            The first decoded sequence with its leading ``len(text)``
            characters dropped and leading newlines stripped.
        """
        encoded = self.processor(text=text, images=image, return_tensors="pt")

        # Every processor output has to live on the same device as the model.
        on_device = {}
        for key, value in encoded.items():
            on_device[key] = value.to(self.device)

        with torch.inference_mode():
            output_ids = self.model.generate(
                **on_device,
                max_new_tokens=max_new_tokens,
                do_sample=False,  # sampling disabled
            )

        decoded = self.processor.batch_decode(output_ids, skip_special_tokens=True)
        first_sequence = decoded[0]
        # Drop the first len(text) characters (presumably the echoed prompt).
        completion = first_sequence[len(text):]
        return completion.lstrip("\n")
class VAEModel:
    """Decoder half of a VQ-VAE that turns codebook indices into masks.

    The weights are read from a NumPy archive holding a PyTorch-layout
    checkpoint and are rearranged into a Flax parameter tree once, at
    construction time.
    """

    def __init__(self, model_path: str):
        """Load and convert the checkpoint stored at ``model_path``."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.params = self._get_params(model_path)

    def _get_params(self, checkpoint_path):
        """Converts PyTorch checkpoint to Flax params.

        Args:
            checkpoint_path: Path (or file object) accepted by ``np.load``
                that yields a mapping of array names to arrays.

        Returns:
            A nested dict keyed by the Flax auto-generated module names
            (``Conv_0``, ``ResBlock_0``, ...), plus ``'_embeddings'`` holding
            the codebook.
        """
        # Materialise every array while the archive is open, then close it.
        with np.load(checkpoint_path) as archive:
            checkpoint = dict(archive)

        def transp(kernel):
            # Move the two leading axes to the end, reversed: (a, b, h, w)
            # becomes (h, w, b, a).
            return np.transpose(kernel, (2, 3, 1, 0))

        def conv(name):
            return {
                'bias': checkpoint[name + '.bias'],
                'kernel': transp(checkpoint[name + '.weight']),
            }

        def resblock(name):
            # Only entries 0, 2 and 4 of the sequential block carry weights.
            return {
                'Conv_0': conv(name + '.0'),
                'Conv_1': conv(name + '.2'),
                'Conv_2': conv(name + '.4'),
            }

        return {
            '_embeddings': checkpoint['_vq_vae._embedding'],
            'Conv_0': conv('decoder.0'),
            'ResBlock_0': resblock('decoder.2.net'),
            'ResBlock_1': resblock('decoder.3.net'),
            'ConvTranspose_0': conv('decoder.4'),
            'ConvTranspose_1': conv('decoder.6'),
            'ConvTranspose_2': conv('decoder.8'),
            'ConvTranspose_3': conv('decoder.10'),
            'Conv_1': conv('decoder.12'),
        }

    def reconstruct_masks(self, codebook_indices):
        """Decode ``codebook_indices`` of shape ``(batch, 16)`` into masks.

        Returns:
            The decoder output for the looked-up codebook vectors.

        Raises:
            ValueError: If ``codebook_indices`` is not ``(batch, 16)``.
        """
        quantized = self._quantized_values_from_codebook_indices(codebook_indices)
        # _decoder() already returns the jitted ``Decoder().apply``; it is the
        # callable itself, so it must be invoked directly rather than through
        # a further ``.apply`` attribute (which the jitted wrapper lacks).
        decode = self._decoder()
        return decode({'params': self.params}, quantized)

    def _quantized_values_from_codebook_indices(self, codebook_indices):
        """Look up codebook vectors and arrange them on a 4x4 grid.

        Args:
            codebook_indices: Integer array of shape ``(batch, 16)``.

        Returns:
            Array of shape ``(batch, 4, 4, embedding_dim)``.

        Raises:
            ValueError: If the input is not 2-D with exactly 16 tokens per
                row.
        """
        batch_size, num_tokens = codebook_indices.shape
        # Explicit raise instead of ``assert`` so the check survives ``-O``.
        if num_tokens != 16:
            raise ValueError(
                f"expected 16 tokens per mask, got shape {codebook_indices.shape}"
            )
        _, embedding_dim = self.params['_embeddings'].shape
        encodings = jnp.take(
            self.params['_embeddings'], codebook_indices.reshape((-1)), axis=0
        )
        return encodings.reshape((batch_size, 4, 4, embedding_dim))

    def _decoder(self):
        """Return the jitted decoder function, building it on first use.

        The returned callable has the signature of ``Decoder().apply``: it
        takes a variables dict (``{'params': ...}``) followed by the quantized
        input. The result is memoised on the instance so repeated calls reuse
        the same jitted function; a per-instance attribute is used instead of
        ``functools.cache`` so that the cache does not keep instances alive.
        """
        cached = getattr(self, '_jitted_decoder', None)
        if cached is not None:
            return cached

        # No module-level import binds ``nn`` in this file; the API used below
        # (``nn.compact``, ``nn.ConvTranspose(transpose_kernel=...)``) is
        # flax.linen, so bring it into scope here.
        import flax.linen as nn

        class ResBlock(nn.Module):
            """Three-convolution residual block."""
            features: int

            @nn.compact
            def __call__(self, x):
                original_x = x
                x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
                x = nn.relu(x)
                x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
                x = nn.relu(x)
                x = nn.Conv(features=self.features, kernel_size=(1, 1), padding=0)(x)
                return x + original_x

        class Decoder(nn.Module):
            """Upscales quantized vectors to mask."""

            @nn.compact
            def __call__(self, x):
                num_res_blocks = 2
                dim = 128
                num_upsample_layers = 4

                x = nn.Conv(features=dim, kernel_size=(1, 1), padding=0)(x)
                x = nn.relu(x)

                for _ in range(num_res_blocks):
                    x = ResBlock(features=dim)(x)

                # Each stride-2 transposed convolution is followed by a
                # halving of the channel count used for the next layer.
                for _ in range(num_upsample_layers):
                    x = nn.ConvTranspose(
                        features=dim,
                        kernel_size=(4, 4),
                        strides=(2, 2),
                        padding=2,
                        transpose_kernel=True,
                    )(x)
                    x = nn.relu(x)
                    dim //= 2

                x = nn.Conv(features=1, kernel_size=(1, 1), padding=0)(x)
                return x

        jitted = jax.jit(Decoder().apply, backend='cpu')
        self._jitted_decoder = jitted
        return jitted
|