Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -182,25 +182,35 @@ def preprocess_image(input_image: Image.Image) -> torch.Tensor:
|
|
| 182 |
def generate_caption(text_model, tokenizer, image_features, prompt_str: str, max_new_tokens: int = 300) -> str:
|
| 183 |
prompt = tokenizer.encode(prompt_str, return_tensors='pt', padding=False, truncation=False, add_special_tokens=False)
|
| 184 |
prompt_embeds = text_model.model.embed_tokens(prompt.to('cuda'))
|
| 185 |
-
embedded_bos = text_model.model.embed_tokens(torch.tensor([[tokenizer.bos_token_id]], device=text_model.device, dtype=torch.int64))
|
| 186 |
-
eot_embed = image_adapter.get_eot_embedding().unsqueeze(0).to(dtype=text_model.dtype)
|
| 187 |
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
|
| 195 |
input_ids = torch.cat([
|
| 196 |
-
|
| 197 |
torch.zeros((1, image_features.shape[1]), dtype=torch.long),
|
| 198 |
-
|
| 199 |
-
torch.tensor([[tokenizer.convert_tokens_to_ids("<|eot_id|>")]], dtype=torch.long),
|
| 200 |
], dim=1).to('cuda')
|
| 201 |
attention_mask = torch.ones_like(input_ids)
|
| 202 |
|
| 203 |
-
generate_ids = text_model.generate(input_ids, inputs_embeds=
|
| 204 |
|
| 205 |
generate_ids = generate_ids[:, input_ids.shape[1]:]
|
| 206 |
if generate_ids[0][-1] == tokenizer.eos_token_id or generate_ids[0][-1] == tokenizer.convert_tokens_to_ids("<|eot_id|>"):
|
|
@@ -476,9 +486,9 @@ with gr.Blocks(theme="Hev832/Applio", css=css, fill_width=True, fill_height=True
|
|
| 476 |
)
|
| 477 |
|
| 478 |
with gr.Row():
|
| 479 |
-
username = gr.Textbox(label="Username", placeholder="Enter your username")
|
| 480 |
with gr.Row():
|
| 481 |
-
password = gr.Textbox(label="Password", type="password", placeholder="Enter your password")
|
| 482 |
with gr.Row():
|
| 483 |
login_button = gr.Button("Login", size="sm")
|
| 484 |
login_message = gr.Markdown(visible=False)
|
|
@@ -558,29 +568,29 @@ with gr.Blocks(theme="Hev832/Applio", css=css, fill_width=True, fill_height=True
|
|
| 558 |
value="long",
|
| 559 |
)
|
| 560 |
|
| 561 |
-
|
| 562 |
-
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
|
| 570 |
-
|
| 571 |
-
|
| 572 |
-
|
| 573 |
-
|
| 574 |
-
|
| 575 |
-
|
| 576 |
-
|
| 577 |
-
|
| 578 |
-
|
| 579 |
-
|
| 580 |
-
|
| 581 |
-
|
| 582 |
-
|
| 583 |
-
|
| 584 |
|
| 585 |
name_input = gr.Textbox(label="Person/Character Name (if applicable)")
|
| 586 |
gr.Markdown("**Note:** Name input is only used if an Extra Option is selected that requires it.")
|
|
|
|
| 182 |
def generate_caption(text_model, tokenizer, image_features, prompt_str: str, max_new_tokens: int = 300) -> str:
|
| 183 |
prompt = tokenizer.encode(prompt_str, return_tensors='pt', padding=False, truncation=False, add_special_tokens=False)
|
| 184 |
prompt_embeds = text_model.model.embed_tokens(prompt.to('cuda'))
|
|
|
|
|
|
|
| 185 |
|
| 186 |
+
convo = [
|
| 187 |
+
{"role": "system", "content": "You are a helpful image captioner."},
|
| 188 |
+
{"role": "user", "content": prompt_str},
|
| 189 |
+
]
|
| 190 |
+
convo_string = tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
|
| 191 |
+
convo_tokens = tokenizer.encode(convo_string, return_tensors="pt", add_special_tokens=False, truncation=False)
|
| 192 |
+
convo_tokens = convo_tokens.squeeze(0)
|
| 193 |
+
|
| 194 |
+
eot_id_indices = (convo_tokens == tokenizer.convert_tokens_to_ids("<|eot_id|>")).nonzero(as_tuple=True)[0].tolist()
|
| 195 |
+
assert len(eot_id_indices) == 2, f"Expected 2 <|eot_id|> tokens, got {len(eot_id_indices)}"
|
| 196 |
+
preamble_len = eot_id_indices[1] - prompt.shape[1]
|
| 197 |
+
|
| 198 |
+
convo_embeds = text_model.model.embed_tokens(convo_tokens.unsqueeze(0).to('cuda'))
|
| 199 |
+
|
| 200 |
+
input_embeds = torch.cat([
|
| 201 |
+
convo_embeds[:, :preamble_len],
|
| 202 |
+
image_features.to(dtype=convo_embeds.dtype),
|
| 203 |
+
convo_embeds[:, preamble_len:],
|
| 204 |
+
], dim=1).to('cuda')
|
| 205 |
|
| 206 |
input_ids = torch.cat([
|
| 207 |
+
convo_tokens[:preamble_len].unsqueeze(0),
|
| 208 |
torch.zeros((1, image_features.shape[1]), dtype=torch.long),
|
| 209 |
+
convo_tokens[preamble_len:].unsqueeze(0),
|
|
|
|
| 210 |
], dim=1).to('cuda')
|
| 211 |
attention_mask = torch.ones_like(input_ids)
|
| 212 |
|
| 213 |
+
generate_ids = text_model.generate(input_ids, inputs_embeds=input_embeds, attention_mask=attention_mask, max_new_tokens=max_new_tokens, do_sample=True, suppress_tokens=None)
|
| 214 |
|
| 215 |
generate_ids = generate_ids[:, input_ids.shape[1]:]
|
| 216 |
if generate_ids[0][-1] == tokenizer.eos_token_id or generate_ids[0][-1] == tokenizer.convert_tokens_to_ids("<|eot_id|>"):
|
|
|
|
| 486 |
)
|
| 487 |
|
| 488 |
with gr.Row():
|
| 489 |
+
username = gr.Textbox(label="Username", placeholder="Enter your username", value="ugd")
|
| 490 |
with gr.Row():
|
| 491 |
+
password = gr.Textbox(label="Password", type="password", placeholder="Enter your password", value="ugd!")
|
| 492 |
with gr.Row():
|
| 493 |
login_button = gr.Button("Login", size="sm")
|
| 494 |
login_message = gr.Markdown(visible=False)
|
|
|
|
| 568 |
value="long",
|
| 569 |
)
|
| 570 |
|
| 571 |
+
with gr.Accordion("Extra Options", open=True):
|
| 572 |
+
extra_options = gr.CheckboxGroup(
|
| 573 |
+
choices=[
|
| 574 |
+
"If there is a person/character in the image you must refer to them as {name}.",
|
| 575 |
+
"Do NOT include information about people/characters that cannot be changed (like ethnicity, gender, etc), but do still include changeable attributes (like hair style).",
|
| 576 |
+
"Include information about lighting.",
|
| 577 |
+
"Include information about camera angle.",
|
| 578 |
+
"Include information about whether there is a watermark or not.",
|
| 579 |
+
"Include information about whether there are JPEG artifacts or not.",
|
| 580 |
+
"If it is a photo you MUST include information about what camera was likely used and details such as aperture, shutter speed, ISO, etc.",
|
| 581 |
+
"Do NOT include anything sexual; keep it PG.",
|
| 582 |
+
"Do NOT mention the image's resolution.",
|
| 583 |
+
"You MUST include information about the subjective aesthetic quality of the image from low to very high.",
|
| 584 |
+
"Include information on the image's composition style, such as leading lines, rule of thirds, or symmetry.",
|
| 585 |
+
"Do NOT mention any text that is in the image.",
|
| 586 |
+
"Specify the depth of field and whether the background is in focus or blurred.",
|
| 587 |
+
"If applicable, mention the likely use of artificial or natural lighting sources.",
|
| 588 |
+
"Do NOT use any ambiguous language.",
|
| 589 |
+
"Include whether the image is sfw, suggestive, or nsfw.",
|
| 590 |
+
"ONLY describe the most important elements of the image."
|
| 591 |
+
],
|
| 592 |
+
label="Select Extra Options"
|
| 593 |
+
)
|
| 594 |
|
| 595 |
name_input = gr.Textbox(label="Person/Character Name (if applicable)")
|
| 596 |
gr.Markdown("**Note:** Name input is only used if an Extra Option is selected that requires it.")
|