caption-match / app.py
iamrobotbear's picture
Bring up to date with working github copy
ea1e3a1
raw
history blame
3.25 kB
import gradio as gr
import torch
from PIL import Image
import pandas as pd
from lavis.models import load_model_and_preprocess
from lavis.processors import load_processor
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoProcessor
# Select the compute device once. Always build a torch.device (the original
# mixed a torch.device with the plain string "cpu") so every `.to(device)`
# call in this file receives the same type.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load model and preprocessors for Image-Text Matching (LAVIS)
model_itm, vis_processors, text_processors = load_model_and_preprocess(
    "blip2_image_text_matching", "pretrain", device=device, is_eval=True
)

# Load processor and model for Image Captioning (TextCaps)
git_processor_large_textcaps = AutoProcessor.from_pretrained("microsoft/git-large-r-textcaps")
# Place the captioning model on the same device as the ITM model: tensors that
# are moved to `device` before generation must match the model's weights,
# otherwise GPU runs fail with a device-mismatch error.
git_model_large_textcaps = AutoModelForCausalLM.from_pretrained(
    "microsoft/git-large-r-textcaps"
).to(device)
# Candidate statements; each one is appended to the generated caption and
# scored against the uploaded image by the image-text matching model.
statements = [
    "cartoon, figurine, or toy",
    "appears to be for children",
    "includes children",
    "is sexual",
    (
        "depicts a child or portrays objects, images, or cartoon figures "
        "that primarily appeal to persons below the legal purchase age"
    ),
    "uses the name of or depicts Santa Claus",
    'promotes alcohol use as a "rite of passage" to adulthood',
]
# Function to compute ITM scores for the combined text input (caption + statement)
def compute_itm_score(image, combined_text):
    """Return the image-text matching probability for one image/text pair.

    Args:
        image: Image as a NumPy-style array (must support ``.astype``).
        combined_text: Text to match against the image; passed to the ITM
            model unchanged.

    Returns:
        The softmax score at index 1 of the ITM head output as a Python
        float; the caller reports it as the match probability.
    """
    # Let Pillow infer the mode from the array shape and normalise to RGB
    # afterwards. This avoids the deprecated `mode` argument of
    # Image.fromarray and also accepts grayscale / RGBA arrays.
    pil_image = Image.fromarray(image.astype("uint8")).convert("RGB")
    img = vis_processors["eval"](pil_image).unsqueeze(0).to(device)

    # Inference only: no autograd graph is needed for a single forward pass.
    with torch.no_grad():
        # Pass the combined_text string directly to model_itm
        itm_output = model_itm({"image": img, "text_input": combined_text}, match_head="itm")
        itm_scores = torch.nn.functional.softmax(itm_output, dim=1)

    return itm_scores[:, 1].item()
def generate_caption(processor, model, image):
    """Generate a caption for ``image`` with the given processor/model pair.

    Args:
        processor: Callable that turns the image into model inputs and
            provides ``batch_decode`` for the generated ids.
        model: Generative model exposing ``generate`` and ``device``.
        image: Image to caption, in any form the processor accepts.

    Returns:
        The first decoded caption, with special tokens stripped.
    """
    # Move the inputs to wherever this model's weights live instead of to the
    # module-level device, so the tensors always match the model that was
    # passed in.
    inputs = processor(images=image, return_tensors="pt").to(model.device)
    generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=50)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
# Main function to perform image captioning and image-text matching
def process_images_and_statements(image):
    """Caption the image and score the caption against every statement.

    Args:
        image: The uploaded image, or ``None`` when nothing was uploaded.

    Returns:
        One line per predefined statement, joined with newlines, giving the
        match probability of the image and "<caption> <statement>". When no
        image was supplied, a short prompt asking for one.
    """
    # Submitting the form without an image yields None; answer with a prompt
    # instead of failing inside the captioning processor.
    if image is None:
        return "Please upload an image."

    # Generate image caption for the uploaded image using git-large-r-textcaps
    caption = generate_caption(git_processor_large_textcaps, git_model_large_textcaps, image)

    results = []
    for statement in statements:
        # The ITM model scores the caption and the statement as one text.
        combined_text = caption + " " + statement
        itm_score = compute_itm_score(image, combined_text)
        results.append(
            f'The image and "{combined_text}" are matched with a probability of {itm_score:.3%}'
        )

    # Combine the results and return them
    return "\n".join(results)
# Gradio interface
# Use the top-level component classes: the gr.inputs / gr.outputs namespaces
# are deprecated in Gradio 3 and removed in Gradio 4.
image_input = gr.Image()
output = gr.Textbox(label="Results")
iface = gr.Interface(
    fn=process_images_and_statements,
    inputs=image_input,
    outputs=output,
    title="Image Captioning and Image-Text Matching",
)
iface.launch()