caption-match / app.py
import gradio as gr
import torch
from PIL import Image
from lavis.models import load_model_and_preprocess
from transformers import AutoProcessor, AutoModelForCausalLM
# Load model and preprocessors for Image-Text Matching (LAVIS)
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
model_itm, vis_processors, text_processors = load_model_and_preprocess("blip2_image_text_matching", "pretrain", device=device, is_eval=True)
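# is_eval=True loads the ITM model in evaluation (inference) mode; the returned
# "eval" processors are used below to normalize images and statements.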
# Load processor and model for image captioning (GIT, fine-tuned on TextCaps)
processor_caption = AutoProcessor.from_pretrained("microsoft/git-large-r-textcaps")
model_caption = AutoModelForCausalLM.from_pretrained("microsoft/git-large-r-textcaps").to(device)
# List of statements for Image-Text Matching
statements = [
    "cartoon, figurine, or toy",
    "appears to be for children",
    "includes children",
    "is sexual",
    "depicts a child or portrays objects, images, or cartoon figures that primarily appeal to persons below the legal purchase age",
    "uses the name of or depicts Santa Claus",
    'promotes alcohol use as a "rite of passage" to adulthood',
]
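# Preprocess each statement once at startup, so per-request work is limited to the image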
txts = [text_processors["eval"](statement) for statement in statements]
# Function to compute Image-Text Matching (ITM) scores for all statements
def compute_itm_scores(image):
    # Gradio delivers the image as a NumPy array; convert it to a PIL image
    pil_image = Image.fromarray(image.astype("uint8"), "RGB")
    img = vis_processors["eval"](pil_image).unsqueeze(0).to(device)
    results = []
    for statement, txt in zip(statements, txts):
        # The ITM head returns two-class logits (no match, match);
        # softmax turns them into a match probability
        itm_output = model_itm({"image": img, "text_input": txt}, match_head="itm")
        itm_scores = torch.nn.functional.softmax(itm_output, dim=1)
        score = itm_scores[:, 1].item()
        results.append(f'The image and "{statement}" are matched with a probability of {score:.3%}')
    return "\n".join(results)
# Function to generate an image caption using GIT (TextCaps)
def generate_image_captions(image):
    # GIT conditions generation on pixel values produced by its processor
    pil_image = Image.fromarray(image.astype("uint8"), "RGB")
    pixel_values = processor_caption(images=pil_image, return_tensors="pt").pixel_values.to(device)
    generated_ids = model_caption.generate(pixel_values=pixel_values, max_length=50)
    return processor_caption.batch_decode(generated_ids, skip_special_tokens=True)[0]
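# Note: generation above is unconditional (image only). GIT also accepts a text
# prefix via input_ids alongside pixel_values if prompted captioning is wanted.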
# Main function to perform image captioning and image-text matching
def process_images_and_statements(image):
    # Generate an image caption using GIT (TextCaps)
    captions = generate_image_captions(image)
    # Compute ITM scores for the predefined statements using LAVIS
    itm_scores = compute_itm_scores(image)
    # Combine the caption and ITM scores into the output
    return "Image Captions:\n" + captions + "\n\nITM Scores:\n" + itm_scores
# Gradio interface (gr.inputs/gr.outputs are deprecated; use the components directly)
image_input = gr.Image()
output = gr.Textbox(label="Results")
iface = gr.Interface(
    fn=process_images_and_statements,
    inputs=image_input,
    outputs=output,
    title="Image Captioning and Image-Text Matching",
)
iface.launch()
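# A minimal client-side sketch of querying the running app (assumes a recent
# gradio_client, the default local port, and Gradio's default "/predict"
# api_name; the image path is hypothetical):
#
#   from gradio_client import Client, handle_file
#   client = Client("http://127.0.0.1:7860")
#   print(client.predict(handle_file("example.jpg"), api_name="/predict"))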