dwb2023 commited on
Commit
839e917
·
verified ·
1 Parent(s): 01ebc72

Delete inference.py

Browse files
Files changed (1) hide show
  1. inference.py +0 -130
inference.py DELETED
@@ -1,130 +0,0 @@
1
- import os
2
- import PIL.Image
3
- import torch
4
- from huggingface_hub import login
5
- from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor
6
- import jax
7
- import jax.numpy as jnp
8
- import numpy as np
9
- import functools
10
- import spaces
11
-
12
- hf_token = os.getenv("HF_TOKEN")
13
- login(token=hf_token, add_to_git_credential=True)
14
-
15
- class PaliGemmaModel:
16
- def __init__(self):
17
- self.model_id = "google/paligemma-3b-mix-448"
18
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
- self.model = PaliGemmaForConditionalGeneration.from_pretrained(self.model_id).eval().to(self.device)
20
- self.processor = PaliGemmaProcessor.from_pretrained(self.model_id)
21
-
22
- @spaces.GPU
23
- def infer(self, image: PIL.Image.Image, text: str, max_new_tokens: int) -> str:
24
- inputs = self.processor(text=text, images=image, return_tensors="pt")
25
- inputs = {k: v.to(self.device) for k, v in inputs.items()} # Move inputs to the correct device
26
- with torch.inference_mode():
27
- generated_ids = self.model.generate(
28
- **inputs,
29
- max_new_tokens=max_new_tokens,
30
- do_sample=False
31
- )
32
- result = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
33
- return result[0][len(text):].lstrip("\n")
34
-
35
- class VAEModel:
36
- def __init__(self, model_path: str):
37
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
38
- self.params = self._get_params(model_path)
39
-
40
- def _get_params(self, checkpoint_path):
41
- """Converts PyTorch checkpoint to Flax params."""
42
- checkpoint = dict(np.load(checkpoint_path))
43
-
44
- def transp(kernel):
45
- return np.transpose(kernel, (2, 3, 1, 0))
46
-
47
- def conv(name):
48
- return {
49
- 'bias': checkpoint[name + '.bias'],
50
- 'kernel': transp(checkpoint[name + '.weight']),
51
- }
52
-
53
- def resblock(name):
54
- return {
55
- 'Conv_0': conv(name + '.0'),
56
- 'Conv_1': conv(name + '.2'),
57
- 'Conv_2': conv(name + '.4'),
58
- }
59
-
60
- return {
61
- '_embeddings': checkpoint['_vq_vae._embedding'],
62
- 'Conv_0': conv('decoder.0'),
63
- 'ResBlock_0': resblock('decoder.2.net'),
64
- 'ResBlock_1': resblock('decoder.3.net'),
65
- 'ConvTranspose_0': conv('decoder.4'),
66
- 'ConvTranspose_1': conv('decoder.6'),
67
- 'ConvTranspose_2': conv('decoder.8'),
68
- 'ConvTranspose_3': conv('decoder.10'),
69
- 'Conv_1': conv('decoder.12'),
70
- }
71
-
72
- def reconstruct_masks(self, codebook_indices):
73
- quantized = self._quantized_values_from_codebook_indices(codebook_indices)
74
- return self._decoder().apply({'params': self.params}, quantized)
75
-
76
- def _quantized_values_from_codebook_indices(self, codebook_indices):
77
- batch_size, num_tokens = codebook_indices.shape
78
- assert num_tokens == 16, codebook_indices.shape
79
- unused_num_embeddings, embedding_dim = self.params['_embeddings'].shape
80
-
81
- encodings = jnp.take(self.params['_embeddings'], codebook_indices.reshape((-1)), axis=0)
82
- encodings = encodings.reshape((batch_size, 4, 4, embedding_dim))
83
- return encodings
84
-
85
- @functools.cache
86
- def _decoder(self):
87
- class ResBlock(nn.Module):
88
- features: int
89
-
90
- @nn.compact
91
- def __call__(self, x):
92
- original_x = x
93
- x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
94
- x = nn.relu(x)
95
- x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
96
- x = nn.relu(x)
97
- x = nn.Conv(features=self.features, kernel_size=(1, 1), padding=0)(x)
98
- return x + original_x
99
-
100
- class Decoder(nn.Module):
101
- """Upscales quantized vectors to mask."""
102
-
103
- @nn.compact
104
- def __call__(self, x):
105
- num_res_blocks = 2
106
- dim = 128
107
- num_upsample_layers = 4
108
-
109
- x = nn.Conv(features=dim, kernel_size=(1, 1), padding=0)(x)
110
- x = nn.relu(x)
111
-
112
- for _ in range(num_res_blocks):
113
- x = ResBlock(features=dim)(x)
114
-
115
- for _ in range(num_upsample_layers):
116
- x = nn.ConvTranspose(
117
- features=dim,
118
- kernel_size=(4, 4),
119
- strides=(2, 2),
120
- padding=2,
121
- transpose_kernel=True,
122
- )(x)
123
- x = nn.relu(x)
124
- dim //= 2
125
-
126
- x = nn.Conv(features=1, kernel_size=(1, 1), padding=0)(x)
127
-
128
- return x
129
-
130
- return jax.jit(Decoder().apply, backend='cpu')