Li committed "update"
Commit: d3fbc73
Parent(s): f407227
app.py CHANGED

@@ -92,6 +92,9 @@ def generate(
     all_ids = set(range(flamingo.lang_encoder.lm_head.out_features))
     bad_words_ids = list(all_ids - set(loc_token_ids))
     bad_words_ids = [[b] for b in bad_words_ids]
+    loc_word_ids = list(set(loc_token_ids))
+    loc_word_ids = [[b] for b in loc_word_ids]
+
     min_loc_token_id = min(loc_token_ids)
     max_loc_token_id = max(loc_token_ids)
     image_ori = image
@@ -103,9 +106,11 @@ def generate(
     if idx == 1:
         prompt = [f"<|#image#|>{tokenizer.pad_token*vis_embed_size}<|#endofimage#|><|#obj#|>{text.rstrip('.')}<|#loc#|>"]
         bad_words_ids = None
+        max_generation_length = 5
     else:
         prompt = [f"<|#image#|>{tokenizer.pad_token * vis_embed_size}<|#endofimage#|>{text.rstrip('.')}"]
-        bad_words_ids =
+        bad_words_ids = loc_word_ids
+        max_generation_length = 100
     encodings = tokenizer(
         prompt,
         padding="longest",
@@ -122,7 +127,7 @@ def generate(
         model=flamingo,
         batch_images=batch_images,
         attention_mask=attention_mask,
-        max_generation_length=
+        max_generation_length=max_generation_length,
         min_generation_length=4,
         num_beams=1,
         length_penalty=1.0,
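The core trick in this commit is token masking via bad_words_ids: in the grounding branch (idx == 1) every non-location token is banned so the model can only emit location tokens, while in the free-text branch the location tokens themselves are banned and a longer generation budget is used. The sketch below shows the same masking idea with the generic Hugging Face transformers generate() API; it is not the Space's own get_outputs / OpenFlamingo code, and the model choice (gpt2), the prompt, and the loc_token_ids range are illustrative assumptions.

# Minimal sketch of constraining decoding to (or away from) a token subset
# using transformers' bad_words_ids. All concrete ids/model are assumptions.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Stand-in for the location-bin token ids used by the Space (assumption).
loc_token_ids = list(range(50, 60))

# Grounding step (idx == 1 in the diff): ban every NON-location token,
# so only location tokens can be generated.
all_ids = set(range(model.get_output_embeddings().out_features))
bad_words_ids = [[i] for i in all_ids - set(loc_token_ids)]

# Free-text step (else branch in the diff): ban the location tokens instead.
loc_word_ids = [[i] for i in set(loc_token_ids)]

inputs = tokenizer("Describe the object:", return_tensors="pt")
out = model.generate(
    **inputs,
    bad_words_ids=bad_words_ids,  # swap in loc_word_ids for the free-text step
    max_new_tokens=5,             # mirrors max_generation_length = 5
    min_new_tokens=4,             # mirrors min_generation_length = 4
    num_beams=1,
)
# Print only the newly generated tokens, which are drawn from loc_token_ids.
print(tokenizer.decode(out[0][inputs["input_ids"].shape[1]:]))

Passing a list of single-id entries to bad_words_ids sets those tokens' logits to -inf at every step, which is why the diff builds [[b] for b in ...] rather than a flat list.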