Spaces:
Runtime error
Runtime error
Commit
·
3872237
1
Parent(s):
e1777e5
reqs
Browse files
app.py
CHANGED
|
@@ -134,7 +134,7 @@ pali = PaliGemmaForConditionalGeneration.from_pretrained('google/paligemma-3b-pt
|
|
| 134 |
processor = AutoProcessor.from_pretrained('google/paligemma-3b-pt-224')
|
| 135 |
|
| 136 |
|
| 137 |
-
|
| 138 |
def to_wanted_embs(image_outputs, input_ids, attention_mask, cache_position=None):
|
| 139 |
inputs_embeds = pali.get_input_embeddings()(input_ids)
|
| 140 |
selected_image_feature = image_outputs.to(dtype).to(device)
|
|
@@ -148,7 +148,7 @@ def to_wanted_embs(image_outputs, input_ids, attention_mask, cache_position=None
|
|
| 148 |
return inputs_embeds
|
| 149 |
|
| 150 |
|
| 151 |
-
|
| 152 |
def generate_pali(user_emb):
|
| 153 |
prompt = 'caption en'
|
| 154 |
model_inputs = processor(text=prompt, images=torch.zeros(1, 3, 224, 224), return_tensors="pt")
|
|
@@ -540,8 +540,7 @@ def encode_space(x):
|
|
| 540 |
im_emb, _ = pipe.encode_image(
|
| 541 |
image, DEVICE, 1, output_hidden_state
|
| 542 |
)
|
| 543 |
-
|
| 544 |
-
|
| 545 |
im = torchvision.transforms.ToTensor()(x).unsqueeze(0)
|
| 546 |
im = torch.nn.functional.interpolate(im, (224, 224))
|
| 547 |
im = (im - .5) * 2
|
|
|
|
| 134 |
processor = AutoProcessor.from_pretrained('google/paligemma-3b-pt-224')
|
| 135 |
|
| 136 |
|
| 137 |
+
@spaces.GPU()
|
| 138 |
def to_wanted_embs(image_outputs, input_ids, attention_mask, cache_position=None):
|
| 139 |
inputs_embeds = pali.get_input_embeddings()(input_ids)
|
| 140 |
selected_image_feature = image_outputs.to(dtype).to(device)
|
|
|
|
| 148 |
return inputs_embeds
|
| 149 |
|
| 150 |
|
| 151 |
+
@spaces.GPU()
|
| 152 |
def generate_pali(user_emb):
|
| 153 |
prompt = 'caption en'
|
| 154 |
model_inputs = processor(text=prompt, images=torch.zeros(1, 3, 224, 224), return_tensors="pt")
|
|
|
|
| 540 |
im_emb, _ = pipe.encode_image(
|
| 541 |
image, DEVICE, 1, output_hidden_state
|
| 542 |
)
|
| 543 |
+
|
|
|
|
| 544 |
im = torchvision.transforms.ToTensor()(x).unsqueeze(0)
|
| 545 |
im = torch.nn.functional.interpolate(im, (224, 224))
|
| 546 |
im = (im - .5) * 2
|