|
import gradio as gr |
|
from PIL import Image |
|
import torch |
|
from transformers import BlipProcessor, BlipForQuestionAnswering |
|
|
|
# Load the BLIP VQA processor (image preprocessing + tokenizer) and model once
# at module import so every request reuses the same weights.
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")

model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")

# Prefer GPU when one is available; fall back to CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# eval() disables dropout/batch-norm training behavior for deterministic inference.
model.to(device)
model.eval()
|
def answer_question(image: Image.Image, question: str) -> str:
    """Answer a natural-language question about an image with BLIP VQA.

    Args:
        image: The uploaded image (Gradio passes None when nothing is uploaded).
        question: The user's free-text question about the image.

    Returns:
        The model's short answer, or a friendly prompt when input is missing.
    """
    # Gradio invokes the handler with None when no image was uploaded;
    # without this guard, image.convert() raises AttributeError.
    if image is None:
        return "Please upload an image."
    # An empty/whitespace question would still run the model but yields noise.
    if not question or not question.strip():
        return "Please ask a question about the image."

    # BLIP expects 3-channel RGB input; convert() also handles RGBA/grayscale.
    inputs = processor(image.convert("RGB"), question, return_tensors="pt").to(device)

    # no_grad avoids building an autograd graph during inference.
    with torch.no_grad():
        output = model.generate(**inputs)

    # Decode the first (only) generated sequence, dropping special tokens.
    return processor.decode(output[0], skip_special_tokens=True).strip()
|
|
|
# Assemble the Gradio UI: an image upload plus a free-text question box
# feeding answer_question, with a single text box for the model's answer.
_image_input = gr.Image(type="pil", label="Upload an Image")
_question_input = gr.Textbox(label="Ask a Question About the Image")
_answer_output = gr.Textbox(label="Answer")

demo = gr.Interface(
    fn=answer_question,
    inputs=[_image_input, _question_input],
    outputs=_answer_output,
    title="Visual Question Answering",
    description="Ask a question about an image",
)
|
# Launch the local Gradio web server only when run as a script,
# not when this module is imported.
if __name__ == "__main__":

    demo.launch()
|
|