# vqa/app.py
import gradio as gr
from PIL import Image
import torch
from transformers import BlipProcessor, BlipForQuestionAnswering

# Load the BLIP VQA processor and model from the Hugging Face Hub
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")

# Run inference on CPU
device = torch.device("cpu")
model.to(device)
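# Optional alternative (not part of the original app): pick a GPU when one is
# available by replacing the device line above with:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")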
def answer_question(image: Image.Image, question: str) -> str:
    """Answer a natural-language question about an image using BLIP."""
    # Preprocess the image/question pair into model-ready tensors
    inputs = processor(image.convert("RGB"), question, return_tensors="pt").to(device)
    with torch.no_grad():
        output = model.generate(**inputs)
    # Decode the generated token IDs into a plain-text answer
    return processor.decode(output[0], skip_special_tokens=True).strip()
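# Quick local check (illustrative only; "example.jpg" is a hypothetical file):
# img = Image.open("example.jpg")
# print(answer_question(img, "What color is the car?"))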
# Gradio interface
demo = gr.Interface(
    fn=answer_question,
    inputs=[
        gr.Image(type="pil", label="Upload an Image"),
        gr.Textbox(label="Ask a Question About the Image"),
    ],
    outputs=gr.Textbox(label="Answer"),
    title="Visual Question Answering",
    description="Ask a question about an image",
)
if __name__ == "__main__":
    demo.launch()