Spaces:
Runtime error
Runtime error
File size: 3,251 Bytes
c617ba2 56786fe c617ba2 ea1e3a1 c617ba2 6c00d80 ea1e3a1 c617ba2 d72dfa9 c617ba2 ea1e3a1 56786fe ea1e3a1 56786fe ea1e3a1 c617ba2 ea1e3a1 c617ba2 ea1e3a1 c617ba2 ea1e3a1 c617ba2 afa4c81 c617ba2 ca6ee07 ea1e3a1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
import gradio as gr
import torch
from PIL import Image
import pandas as pd
from lavis.models import load_model_and_preprocess
from lavis.processors import load_processor
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoProcessor
# Pick the compute device once; every model and input tensor below is placed on it.
# Always a torch.device (never a bare string) so the value has a single type.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load model and preprocessors for Image-Text Matching (LAVIS)
model_itm, vis_processors, text_processors = load_model_and_preprocess(
    "blip2_image_text_matching", "pretrain", device=device, is_eval=True
)

# Load processor and model for Image Captioning (TextCaps)
git_processor_large_textcaps = AutoProcessor.from_pretrained("microsoft/git-large-r-textcaps")
# from_pretrained loads the weights onto the CPU by default; move them to the
# selected device so the captioning model and its input tensors share a device.
git_model_large_textcaps = AutoModelForCausalLM.from_pretrained(
    "microsoft/git-large-r-textcaps"
).to(device)
# Statements that are each appended to the generated caption and scored
# against the uploaded image by the image-text matching model.
statements = [
    "cartoon, figurine, or toy",
    "appears to be for children",
    "includes children",
    "is sexual",
    (
        "depicts a child or portrays objects, images, or cartoon figures "
        "that primarily appeal to persons below the legal purchase age"
    ),
    "uses the name of or depicts Santa Claus",
    'promotes alcohol use as a "rite of passage" to adulthood',
]
# Function to compute ITM scores for the combined text input (caption + statement)
def compute_itm_score(image, combined_text):
    """Return the probability that ``combined_text`` matches ``image``.

    Args:
        image: NumPy array of image pixels; it is cast to ``uint8`` before
            being turned into a PIL image.
        combined_text: The text to match against the image. It is passed to
            the ITM model as a plain string.

    Returns:
        The softmax probability at index 1 of the ITM head output (reported
        to the user as the "matched" probability), as a Python float.
    """
    # Let Pillow infer the mode from the array shape and convert afterwards:
    # this avoids the deprecated ``mode`` argument of ``Image.fromarray`` and
    # also copes with grayscale or RGBA arrays.
    pil_image = Image.fromarray(image.astype("uint8")).convert("RGB")
    img = vis_processors["eval"](pil_image).unsqueeze(0).to(device)

    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        # Pass the combined_text string directly to model_itm
        itm_output = model_itm({"image": img, "text_input": combined_text}, match_head="itm")
        itm_scores = torch.nn.functional.softmax(itm_output, dim=1)

    return itm_scores[:, 1].item()
def generate_caption(processor, model, image):
    """Generate a caption for ``image`` with the given processor/model pair.

    Args:
        processor: Callable that turns ``images=...`` into model inputs
            exposing ``pixel_values`` and ``.to(device)``; it must also
            provide ``batch_decode``.
        model: Generative model exposing ``generate`` and a ``device``
            attribute.
        image: The image to caption, in any form the processor accepts.

    Returns:
        The first decoded caption, with special tokens stripped.
    """
    # Send the inputs to wherever the model actually lives rather than to a
    # module-level device, so the two can never end up on different devices.
    inputs = processor(images=image, return_tensors="pt").to(model.device)
    generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=50)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
# Main function to perform image captioning and image-text matching
def process_images_and_statements(image):
    """Caption ``image`` and score the caption against every statement.

    Args:
        image: The uploaded image as delivered by the Gradio image input, or
            ``None`` when the user submitted without uploading anything.

    Returns:
        A newline-separated report with one line per entry in ``statements``,
        each giving the match probability for "caption + statement"; or a
        short prompt message when no image was supplied.
    """
    # Gradio passes None when the form is submitted without an image; answer
    # with a readable message instead of crashing inside the captioning model.
    if image is None:
        return "Please upload an image."

    # Generate image caption for the uploaded image using git-large-r-textcaps
    caption = generate_caption(git_processor_large_textcaps, git_model_large_textcaps, image)

    results = []
    for statement in statements:
        # The ITM model is queried with the caption followed by the statement.
        combined_text = f"{caption} {statement}"
        itm_score = compute_itm_score(image, combined_text)
        results.append(
            f'The image and "{combined_text}" are matched with a probability of {itm_score:.3%}'
        )

    return "\n".join(results)
# Gradio interface
# The gr.inputs / gr.outputs namespaces were deprecated in Gradio 3 and removed
# in Gradio 4; the same components are available at the top level of the package.
image_input = gr.Image()
output = gr.Textbox(label="Results")
iface = gr.Interface(
    fn=process_images_and_statements,
    inputs=image_input,
    outputs=output,
    title="Image Captioning and Image-Text Matching",
)
iface.launch()