Update app.py
app.py CHANGED
```diff
@@ -156,40 +156,30 @@ clip_model.to("cuda")
 
 # Tokenizer
 print("Loading tokenizer")
-tokenizer = AutoTokenizer.from_pretrained(
+tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_PATH / "text_model", use_fast=True)
 assert isinstance(tokenizer, PreTrainedTokenizer) or isinstance(tokenizer, PreTrainedTokenizerFast), f"Tokenizer is of type {type(tokenizer)}"
 
 # LLM
 print("Loading LLM")
-
-
-    text_model = AutoModelForCausalLM.from_pretrained(CHECKPOINT_PATH / "text_model", device_map=0, torch_dtype=torch.bfloat16)
-else:
-    text_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto", torch_dtype=torch.bfloat16)
-
+print("Loading VLM's custom text model")
+text_model = AutoModelForCausalLM.from_pretrained(CHECKPOINT_PATH / "text_model", device_map=0, torch_dtype=torch.bfloat16)
 text_model.eval()
 
 # Image Adapter
 print("Loading image adapter")
 image_adapter = ImageAdapter(clip_model.config.hidden_size, text_model.config.hidden_size, False, False, 38, False)
-image_adapter.load_state_dict(torch.load(CHECKPOINT_PATH / "image_adapter.pt", map_location="cpu"
+image_adapter.load_state_dict(torch.load(CHECKPOINT_PATH / "image_adapter.pt", map_location="cpu"))
 image_adapter.eval()
 image_adapter.to("cuda")
 
 
 def preprocess_image(input_image: Image.Image) -> torch.Tensor:
-    """
-    Preprocess the input image for the CLIP model.
-    """
     image = input_image.resize((384, 384), Image.LANCZOS)
     pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
     pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
     return pixel_values.to('cuda')
 
 def generate_caption(text_model, tokenizer, image_features, prompt_str: str, max_new_tokens: int = 300) -> str:
-    """
-    Generate a caption based on the image features and prompt.
-    """
     prompt = tokenizer.encode(prompt_str, return_tensors='pt', padding=False, truncation=False, add_special_tokens=False)
     prompt_embeds = text_model.model.embed_tokens(prompt.to('cuda'))
     embedded_bos = text_model.model.embed_tokens(torch.tensor([[tokenizer.bos_token_id]], device=text_model.device, dtype=torch.int64))
```
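For context on the unchanged `preprocess_image` body: `pil_to_tensor` returns uint8 values in [0, 255], dividing by 255 rescales to [0, 1], and normalizing with mean 0.5 and std 0.5 maps that range to [-1, 1], i.e. x -> (x - 0.5) / 0.5. A minimal standalone sketch of the same arithmetic (the solid-gray test image is made up):

```python
import torchvision.transforms.functional as TVF
from PIL import Image

# Mid-gray dummy image; (128/255 - 0.5) / 0.5 is roughly 0.0039.
img = Image.new("RGB", (512, 512), color=(128, 128, 128))
pixel_values = TVF.pil_to_tensor(img.resize((384, 384), Image.LANCZOS)).unsqueeze(0) / 255.0
pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
print(pixel_values.min().item(), pixel_values.max().item())  # both ~0.0039
```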
```diff
@@ -216,7 +206,7 @@ def generate_caption(text_model, tokenizer, image_features, prompt_str: str, max
     if generate_ids[0][-1] == tokenizer.eos_token_id or generate_ids[0][-1] == tokenizer.convert_tokens_to_ids("<|eot_id|>"):
         generate_ids = generate_ids[:, :-1]
 
-    return tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]
+    return tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]
 
 @spaces.GPU()
 @torch.no_grad()
```
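The condition above strips a single trailing stop token (the tokenizer's EOS or the Llama-3-style `<|eot_id|>`) so it never shows up in the decoded caption, since `batch_decode` is called with `skip_special_tokens=False`. A standalone sketch of the same trimming on a dummy tensor (the token ids are placeholders, not taken from the app's tokenizer):

```python
import torch

EOS_ID, EOT_ID = 2, 128009  # placeholder ids; real values come from the tokenizer

generate_ids = torch.tensor([[5, 6, 7, EOT_ID]])
# Drop the final token if it is a stop token before decoding.
if generate_ids[0][-1] in (EOS_ID, EOT_ID):
    generate_ids = generate_ids[:, :-1]
print(generate_ids)  # tensor([[5, 6, 7]])
```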
```diff
@@ -266,18 +256,33 @@ def stream_chat(input_image: Image.Image, caption_type: str, caption_length: str
     # For debugging
     print(f"Prompt: {prompt_str}")
 
+    # Preprocess image
     pixel_values = preprocess_image(input_image)
 
+    # Embed image
     with torch.amp.autocast_mode.autocast('cuda', enabled=True):
         vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
         image_features = vision_outputs.hidden_states
         embedded_images = image_adapter(image_features)
         embedded_images = embedded_images.to('cuda')
 
-    #
-
-
-
+    # Build the conversation
+    convo = [
+        {
+            "role": "system",
+            "content": "You are a helpful image captioner.",
+        },
+        {
+            "role": "user",
+            "content": prompt_str,
+        },
+    ]
+
+    # Format the conversation
+    convo_string = tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
+    assert isinstance(convo_string, str)
+
+    # Generate caption
     caption = generate_caption(text_model, tokenizer, embedded_images, prompt_str)
 
     return prompt_str, caption.strip()
```
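Note that the new code renders `convo_string` but still passes the raw `prompt_str` to `generate_caption`. For reference, `apply_chat_template(..., tokenize=False, add_generation_prompt=True)` turns the message list into the model's prompt text; a sketch below (the model id is a placeholder, and the commented output is only roughly what a Llama-3-style template produces):

```python
from transformers import AutoTokenizer

# Placeholder id: any tokenizer that ships a chat template behaves similarly.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")

convo = [
    {"role": "system", "content": "You are a helpful image captioner."},
    {"role": "user", "content": "Write a descriptive caption for this image."},
]
convo_string = tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
print(convo_string)
# Roughly: <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a
# helpful image captioner.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n
# Write a descriptive caption for this image.<|eot_id|>
# <|start_header_id|>assistant<|end_header_id|>\n\n
```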
```diff
@@ -437,8 +442,8 @@ def login(username, password):
 # Gradio interface
 with gr.Blocks(theme="Hev832/Applio", css=css, fill_width=True, fill_height=True) as demo:
     with gr.Tab("Welcome"):
-        with gr.Row():
-            with gr.Column(scale=2):
+        with gr.Row(elem_classes="welcome-tab"):
+            with gr.Column(scale=2, elem_classes="welcome-content"):
                 gr.Markdown(
                     """
                     <img src="https://cdn-uploads.huggingface.co/production/uploads/64740cf7485a7c8e1bd51ac9/LVZnwLV43UUvKu3HORqSs.webp" alt="UDG" width="250" class="centered-image">
```
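The `elem_classes` values added here give the welcome layout CSS hooks that the `css` string passed to `gr.Blocks` can target. A sketch of how such rules might look (the selectors' bodies are illustrative; the real `css` variable defined elsewhere in app.py may differ):

```python
import gradio as gr

# Hypothetical rules for the classes added in this commit.
css = """
.welcome-tab { padding: 1rem; }
.welcome-content { max-width: 800px; margin: 0 auto; }
"""

with gr.Blocks(css=css) as demo:
    with gr.Row(elem_classes="welcome-tab"):
        with gr.Column(scale=2, elem_classes="welcome-content"):
            gr.Markdown("Welcome!")
```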
```diff
@@ -471,9 +476,9 @@ with gr.Blocks(theme="Hev832/Applio", css=css, fill_width=True, fill_height=True
         )
 
         with gr.Row():
-            username = gr.Textbox(label="Username", placeholder="Enter your username"
+            username = gr.Textbox(label="Username", placeholder="Enter your username")
         with gr.Row():
-            password = gr.Textbox(label="Password", type="password", placeholder="Enter your password"
+            password = gr.Textbox(label="Password", type="password", placeholder="Enter your password")
         with gr.Row():
             login_button = gr.Button("Login", size="sm")
         login_message = gr.Markdown(visible=False)
```
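These fields feed the `login()` handler defined earlier in app.py, which drives the four outputs `[caption_captain_tab, username, password, login_message]` wired up at the end of this diff. A hypothetical handler with that output shape (the credential check is a placeholder, not the app's real logic):

```python
import gradio as gr

def login(username, password):
    # Placeholder check; the real app validates against its own credential store.
    if username == "demo" and password == "demo":
        # Show the captioning tab, clear the form, hide any error message.
        return gr.update(visible=True), gr.update(value=""), gr.update(value=""), gr.update(visible=False)
    return gr.update(visible=False), gr.update(), gr.update(value=""), gr.update(
        value="Invalid username or password", visible=True
    )
```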
```diff
@@ -485,7 +490,7 @@ with gr.Blocks(theme="Hev832/Applio", css=css, fill_width=True, fill_height=True
                     # How to Use Caption Captain
 
                     <img src="https://cdn-uploads.huggingface.co/production/uploads/64740cf7485a7c8e1bd51ac9/Ce_Z478iOXljvpZ_Fr_Y7.png" alt="Captain" width="100" style="max-width: 100%; height: auto;">
-
+
                     Hello, artist! Let's create amazing captions for your pictures. Here's a comprehensive guide:
 
                     1. **Upload Your Image**: Choose a picture you want to caption and upload it.
```
```diff
@@ -553,34 +558,35 @@ with gr.Blocks(theme="Hev832/Applio", css=css, fill_width=True, fill_height=True
                     value="long",
                 )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                with gr.Accordion("Extra Options", open=True):
+                    extra_options = gr.CheckboxGroup(
+                        choices=[
+                            "If there is a person/character in the image you must refer to them as {name}.",
+                            "Do NOT include information about people/characters that cannot be changed (like ethnicity, gender, etc), but do still include changeable attributes (like hair style).",
+                            "Include information about lighting.",
+                            "Include information about camera angle.",
+                            "Include information about whether there is a watermark or not.",
+                            "Include information about whether there are JPEG artifacts or not.",
+                            "If it is a photo you MUST include information about what camera was likely used and details such as aperture, shutter speed, ISO, etc.",
+                            "Do NOT include anything sexual; keep it PG.",
+                            "Do NOT mention the image's resolution.",
+                            "You MUST include information about the subjective aesthetic quality of the image from low to very high.",
+                            "Include information on the image's composition style, such as leading lines, rule of thirds, or symmetry.",
+                            "Do NOT mention any text that is in the image.",
+                            "Specify the depth of field and whether the background is in focus or blurred.",
+                            "If applicable, mention the likely use of artificial or natural lighting sources.",
+                            "Do NOT use any ambiguous language.",
+                            "Include whether the image is sfw, suggestive, or nsfw.",
+                            "ONLY describe the most important elements of the image."
+                        ],
+                        label="Select Extra Options"
+                    )
 
                 name_input = gr.Textbox(label="Person/Character Name (if applicable)")
                 gr.Markdown("**Note:** Name input is only used if an Extra Option is selected that requires it.")
 
                 custom_prompt = gr.Textbox(label="Custom Prompt (optional, will override all other settings)")
-                gr.Markdown("**Note:**
+                gr.Markdown("**Note:** Alpha Two is not a general instruction follower and will not follow prompts outside its training data well. Use this feature with caution.")
 
             with gr.Column():
                 error_message = gr.Markdown(visible=False)
```
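Several of these options embed a literal `{name}` placeholder that the app fills from `name_input` when assembling the final prompt. A hypothetical sketch of that substitution (the helper name and wiring are assumptions, not the app's actual code):

```python
# Hypothetical helper mirroring how selected options could fold into the prompt.
def build_prompt(base_prompt: str, selected_options: list[str], name: str) -> str:
    prompt = base_prompt + " " + " ".join(selected_options)
    return prompt.format(name=name)  # fills the literal {name} placeholder

print(build_prompt(
    "Write a long descriptive caption for this image.",
    ["If there is a person/character in the image you must refer to them as {name}."],
    "Alice",
))
```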
```diff
@@ -646,6 +652,7 @@ with gr.Blocks(theme="Hev832/Applio", css=css, fill_width=True, fill_height=True
         outputs=[caption_captain_tab, username, password, login_message]
     )
 
+    # Add this new event listener for the password field
     password.submit(
         login,
         inputs=[username, password],
```