Severian committed on
Commit 348afd0 · verified · 1 Parent(s): 36176e7

Update app.py

Files changed (1)
  1. app.py +54 -47
app.py CHANGED
@@ -156,40 +156,30 @@ clip_model.to("cuda")
 
 # Tokenizer
 print("Loading tokenizer")
-tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, use_fast=False)
+tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_PATH / "text_model", use_fast=True)
 assert isinstance(tokenizer, PreTrainedTokenizer) or isinstance(tokenizer, PreTrainedTokenizerFast), f"Tokenizer is of type {type(tokenizer)}"
 
 # LLM
 print("Loading LLM")
-if (CHECKPOINT_PATH / "text_model").exists:
-    print("Loading VLM's custom text model")
-    text_model = AutoModelForCausalLM.from_pretrained(CHECKPOINT_PATH / "text_model", device_map=0, torch_dtype=torch.bfloat16)
-else:
-    text_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto", torch_dtype=torch.bfloat16)
-
+print("Loading VLM's custom text model")
+text_model = AutoModelForCausalLM.from_pretrained(CHECKPOINT_PATH / "text_model", device_map=0, torch_dtype=torch.bfloat16)
 text_model.eval()
 
 # Image Adapter
 print("Loading image adapter")
 image_adapter = ImageAdapter(clip_model.config.hidden_size, text_model.config.hidden_size, False, False, 38, False)
-image_adapter.load_state_dict(torch.load(CHECKPOINT_PATH / "image_adapter.pt", map_location="cpu", weights_only=True))
+image_adapter.load_state_dict(torch.load(CHECKPOINT_PATH / "image_adapter.pt", map_location="cpu"))
 image_adapter.eval()
 image_adapter.to("cuda")
 
 
 def preprocess_image(input_image: Image.Image) -> torch.Tensor:
-    """
-    Preprocess the input image for the CLIP model.
-    """
     image = input_image.resize((384, 384), Image.LANCZOS)
     pixel_values = TVF.pil_to_tensor(image).unsqueeze(0) / 255.0
     pixel_values = TVF.normalize(pixel_values, [0.5], [0.5])
     return pixel_values.to('cuda')
 
 def generate_caption(text_model, tokenizer, image_features, prompt_str: str, max_new_tokens: int = 300) -> str:
-    """
-    Generate a caption based on the image features and prompt.
-    """
     prompt = tokenizer.encode(prompt_str, return_tensors='pt', padding=False, truncation=False, add_special_tokens=False)
     prompt_embeds = text_model.model.embed_tokens(prompt.to('cuda'))
     embedded_bos = text_model.model.embed_tokens(torch.tensor([[tokenizer.bos_token_id]], device=text_model.device, dtype=torch.int64))
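Review note on the hunk above: the old guard `(CHECKPOINT_PATH / "text_model").exists` referenced `Path.exists` without calling it, so the condition was always truthy and the `MODEL_PATH` fallback was dead code; the commit resolves this by always loading the checkpoint's text model. The hunk also drops `weights_only=True` from `torch.load`, which re-enables full pickle deserialization of `image_adapter.pt`. A minimal sketch of the guard the old code likely intended (illustration only, not part of the commit; `load_text_model` is a hypothetical helper):

from pathlib import Path

import torch
from transformers import AutoModelForCausalLM

def load_text_model(checkpoint_path: Path, model_path: str):
    # Path.exists is a method; the old code omitted the call parentheses.
    custom_dir = checkpoint_path / "text_model"
    if custom_dir.exists():
        print("Loading VLM's custom text model")
        return AutoModelForCausalLM.from_pretrained(
            custom_dir, device_map=0, torch_dtype=torch.bfloat16
        )
    return AutoModelForCausalLM.from_pretrained(
        model_path, device_map="auto", torch_dtype=torch.bfloat16
    )

Keeping weights_only=True on the torch.load call restricts the file to tensor data and is the safer choice for checkpoints from untrusted sources.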
@@ -216,7 +206,7 @@ def generate_caption(text_model, tokenizer, image_features, prompt_str: str, max
     if generate_ids[0][-1] == tokenizer.eos_token_id or generate_ids[0][-1] == tokenizer.convert_tokens_to_ids("<|eot_id|>"):
         generate_ids = generate_ids[:, :-1]
 
-    return tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0].strip()
+    return tokenizer.batch_decode(generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)[0]
 
 @spaces.GPU()
 @torch.no_grad()
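Review note: only the trailing `.strip()` changes here, and the caller in `stream_chat` (next hunk) already strips the caption before returning it, so the displayed output is unchanged. For context, the surrounding logic trims one trailing stop token before decoding, since `skip_special_tokens=False` would otherwise leave it in the caption text. Restated as a sketch (assuming a Llama-3-style tokenizer where `<|eot_id|>` is the end-of-turn token):

# Equivalent trimming logic, restated for clarity (sketch, not commit code).
stop_ids = {tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")}
if generate_ids[0][-1].item() in stop_ids:
    generate_ids = generate_ids[:, :-1]  # drop the single trailing stop token
caption = tokenizer.batch_decode(
    generate_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False
)[0]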
@@ -266,18 +256,33 @@ def stream_chat(input_image: Image.Image, caption_type: str, caption_length: str
     # For debugging
     print(f"Prompt: {prompt_str}")
 
+    # Preprocess image
     pixel_values = preprocess_image(input_image)
 
+    # Embed image
     with torch.amp.autocast_mode.autocast('cuda', enabled=True):
         vision_outputs = clip_model(pixel_values=pixel_values, output_hidden_states=True)
         image_features = vision_outputs.hidden_states
         embedded_images = image_adapter(image_features)
         embedded_images = embedded_images.to('cuda')
 
-    # Load the model from MODEL_PATH
-    text_model = AutoModelForCausalLM.from_pretrained(MODEL_PATH, device_map="auto", torch_dtype=torch.bfloat16)
-    text_model.eval()
-
+    # Build the conversation
+    convo = [
+        {
+            "role": "system",
+            "content": "You are a helpful image captioner.",
+        },
+        {
+            "role": "user",
+            "content": prompt_str,
+        },
+    ]
+
+    # Format the conversation
+    convo_string = tokenizer.apply_chat_template(convo, tokenize=False, add_generation_prompt=True)
+    assert isinstance(convo_string, str)
+
+    # Generate caption
     caption = generate_caption(text_model, tokenizer, embedded_images, prompt_str)
 
     return prompt_str, caption.strip()
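Review note: this hunk removes the per-request reload of the base model from `MODEL_PATH` (the module-level `text_model` is reused instead, avoiding a full reload on every call), and it builds `convo_string` via `apply_chat_template`, but the formatted string is never consumed: `generate_caption` still receives the raw `prompt_str`. If the template is meant to take effect, the wiring might look like the sketch below (an assumption, not in the commit; note that `generate_caption` prepends BOS itself and encodes with `add_special_tokens=False`, so the template's special tokens would need reconciling with that logic):

# Sketch (assumption): feed the templated conversation to generate_caption
# instead of the raw prompt string.
convo_string = tokenizer.apply_chat_template(
    convo, tokenize=False, add_generation_prompt=True
)
caption = generate_caption(text_model, tokenizer, embedded_images, convo_string)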
@@ -437,8 +442,8 @@ def login(username, password):
 # Gradio interface
 with gr.Blocks(theme="Hev832/Applio", css=css, fill_width=True, fill_height=True) as demo:
     with gr.Tab("Welcome"):
-        with gr.Row():
-            with gr.Column(scale=2):
+        with gr.Row(elem_classes="welcome-tab"):
+            with gr.Column(scale=2, elem_classes="welcome-content"):
                 gr.Markdown(
                     """
                     <img src="https://cdn-uploads.huggingface.co/production/uploads/64740cf7485a7c8e1bd51ac9/LVZnwLV43UUvKu3HORqSs.webp" alt="UDG" width="250" class="centered-image">
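Review note: the new `elem_classes` hooks only matter if the `css` string passed to `gr.Blocks` targets them. A hypothetical pairing (the class names come from the diff; the rules themselves are invented for illustration, since the actual `css` string is not shown here):

# Hypothetical CSS targeting the new elem_classes hooks.
css = """
.welcome-tab { padding: 1rem; }
.welcome-content { max-width: 800px; margin: 0 auto; }
"""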
@@ -471,9 +476,9 @@ with gr.Blocks(theme="Hev832/Applio", css=css, fill_width=True, fill_height=True
             )
 
             with gr.Row():
-                username = gr.Textbox(label="Username", placeholder="Enter your username", value="ugd")
+                username = gr.Textbox(label="Username", placeholder="Enter your username")
             with gr.Row():
-                password = gr.Textbox(label="Password", type="password", placeholder="Enter your password", value="ugd!")
+                password = gr.Textbox(label="Password", type="password", placeholder="Enter your password")
             with gr.Row():
                 login_button = gr.Button("Login", size="sm")
             login_message = gr.Markdown(visible=False)
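Review note: this hunk removes the hard-coded default credentials ("ugd" / "ugd!") that were pre-filled into the login fields. The `login()` implementation itself is not shown in this diff; a minimal sketch of a check it might perform, with expected values read from environment variables rather than source code (`VALID_USERNAME` and `VALID_PASSWORD` are hypothetical names):

import hmac
import os

def check_credentials(username: str, password: str) -> bool:
    # Constant-time comparison avoids leaking match length via timing.
    expected_user = os.environ.get("VALID_USERNAME", "")
    expected_pass = os.environ.get("VALID_PASSWORD", "")
    return hmac.compare_digest(username.encode(), expected_user.encode()) and \
        hmac.compare_digest(password.encode(), expected_pass.encode())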
@@ -485,7 +490,7 @@ with gr.Blocks(theme="Hev832/Applio", css=css, fill_width=True, fill_height=True
             # How to Use Caption Captain
 
             <img src="https://cdn-uploads.huggingface.co/production/uploads/64740cf7485a7c8e1bd51ac9/Ce_Z478iOXljvpZ_Fr_Y7.png" alt="Captain" width="100" style="max-width: 100%; height: auto;">
-
+
             Hello, artist! Let's create amazing captions for your pictures. Here's a comprehensive guide:
 
             1. **Upload Your Image**: Choose a picture you want to caption and upload it.
@@ -553,34 +558,35 @@ with gr.Blocks(theme="Hev832/Applio", css=css, fill_width=True, fill_height=True
                     value="long",
                 )
 
-                with gr.Accordion("Extra Options", open=True):
-                    extra_options = gr.CheckboxGroup(
-                        choices=[
-                            "If there is a person/character in the image you must refer to them as {name}.",
-                            "Do NOT include information about people/characters that cannot be changed (like ethnicity, gender, etc), but do still include changeable attributes (like hair style).",
-                            "Include information about lighting.",
-                            "Include information about camera angle.",
-                            "Include information about whether there is a watermark or not.",
-                            "Include information about whether there are JPEG artifacts or not.",
-                            "If it is a photo you MUST include information about what camera was likely used and details such as aperture, shutter speed, ISO, etc.",
-                            "Do NOT include anything sexual; keep it PG.",
-                            "Do NOT mention the image's resolution.",
-                            "You MUST include information about the subjective aesthetic quality of the image from low to very high.",
-                            "Include information on the image's composition style, such as leading lines, rule of thirds, or symmetry.",
-                            "Do NOT mention any text that is in the image.",
-                            "Specify the depth of field and whether the background is in focus or blurred.",
-                            "If applicable, mention the likely use of artificial or natural lighting sources.",
-                            "Do NOT use any ambiguous language.",
-                            "Include whether the image is sfw, suggestive, or nsfw.",
-                            "ONLY describe the most important elements of the image."
-                        ],
-                    )
+                with gr.Accordion("Extra Options", open=True):
+                    extra_options = gr.CheckboxGroup(
+                        choices=[
+                            "If there is a person/character in the image you must refer to them as {name}.",
+                            "Do NOT include information about people/characters that cannot be changed (like ethnicity, gender, etc), but do still include changeable attributes (like hair style).",
+                            "Include information about lighting.",
+                            "Include information about camera angle.",
+                            "Include information about whether there is a watermark or not.",
+                            "Include information about whether there are JPEG artifacts or not.",
+                            "If it is a photo you MUST include information about what camera was likely used and details such as aperture, shutter speed, ISO, etc.",
+                            "Do NOT include anything sexual; keep it PG.",
+                            "Do NOT mention the image's resolution.",
+                            "You MUST include information about the subjective aesthetic quality of the image from low to very high.",
+                            "Include information on the image's composition style, such as leading lines, rule of thirds, or symmetry.",
+                            "Do NOT mention any text that is in the image.",
+                            "Specify the depth of field and whether the background is in focus or blurred.",
+                            "If applicable, mention the likely use of artificial or natural lighting sources.",
+                            "Do NOT use any ambiguous language.",
+                            "Include whether the image is sfw, suggestive, or nsfw.",
+                            "ONLY describe the most important elements of the image."
+                        ],
+                        label="Select Extra Options"
+                    )
 
                 name_input = gr.Textbox(label="Person/Character Name (if applicable)")
                 gr.Markdown("**Note:** Name input is only used if an Extra Option is selected that requires it.")
 
                 custom_prompt = gr.Textbox(label="Custom Prompt (optional, will override all other settings)")
-                gr.Markdown("**Note:** Caption Captain is not great at general instruction following and will not follow prompts outside its training data well. Use this feature with caution.")
+                gr.Markdown("**Note:** Alpha Two is not a general instruction follower and will not follow prompts outside its training data well. Use this feature with caution.")
 
             with gr.Column():
                 error_message = gr.Markdown(visible=False)
@@ -646,6 +652,7 @@ with gr.Blocks(theme="Hev832/Applio", css=css, fill_width=True, fill_height=True
         outputs=[caption_captain_tab, username, password, login_message]
     )
 
+    # Add this new event listener for the password field
    password.submit(
        login,
        inputs=[username, password],
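Review note: the new `password.submit(...)` binding makes pressing Enter in the password field trigger the same `login` handler as the Login button. A sketch of the resulting wiring inside the `gr.Blocks` context, assuming the submit call's outputs mirror the click binding's (the diff is cut off before the submit call's outputs list):

# Both events route to one handler so keyboard and mouse login behave
# identically (sketch; outputs for password.submit assumed).
login_button.click(
    login,
    inputs=[username, password],
    outputs=[caption_captain_tab, username, password, login_message],
)
password.submit(
    login,
    inputs=[username, password],
    outputs=[caption_captain_tab, username, password, login_message],
)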
 