# preston-cell: "Update app.py" (commit 9d274d8, verified; 3.54 kB)
import logging

import gradio as gr
import numpy as np
import torch
from datasets import load_dataset
from PIL import Image
from transformers import pipeline, AutoProcessor, AutoModelForCausalLM
# 1) IMAGE CAPTIONING MODEL
# BLIP (base checkpoint) turns the uploaded picture into a short text caption;
# callers read it from result[0]["generated_text"].
caption_model = pipeline(
    task="image-to-text",
    model="Salesforce/blip-image-captioning-base",
)
# 2) OCR MODEL (Florence-2)
# Use the first GPU in half precision when CUDA is present; otherwise fall
# back to full-precision inference on the CPU.
if torch.cuda.is_available():
    ocr_device, ocr_dtype = "cuda:0", torch.float16
else:
    ocr_device, ocr_dtype = "cpu", torch.float32

# NOTE(review): trust_remote_code=True executes Python code shipped with the
# checkpoint repository; consider pinning a `revision` for reproducibility.
ocr_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Florence-2-large",
    trust_remote_code=True,
    torch_dtype=ocr_dtype,
).to(ocr_device)
ocr_processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-large",
    trust_remote_code=True,
)
# 3) QUESTION-ANSWERING MODEL
# Extractive QA (an mDeBERTa-v3 SQuAD2 checkpoint, judging by its name): the
# answer is a span copied out of the supplied context, not newly generated text.
qa_model = pipeline(task="question-answering", model="timpal0l/mdeberta-v3-base-squad2")
# 4) TEXT-TO-SPEECH MODEL
tts_pipeline = pipeline(task="text-to-speech", model="microsoft/speecht5_tts")

# SpeechT5 is conditioned on a speaker x-vector (passed as `speaker_embeddings`
# at synthesis time). Row 7306 of the CMU ARCTIC x-vector set picks the voice;
# the extra leading dimension is added before it is handed to the pipeline.
embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker_embedding = torch.unsqueeze(torch.tensor(embeddings_dataset[7306]["xvector"]), 0)
def _run_ocr(image):
    """Return the text Florence-2 reads from *image* using the ``<OCR>`` task prompt."""
    inputs = ocr_processor(text="<OCR>", images=image, return_tensors="pt").to(
        ocr_device, ocr_dtype
    )
    generated_ids = ocr_model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        # Microsoft's Florence-2 model-card example caps generation at 1024 new
        # tokens. NOTE(review): the text decoder's learned position table is
        # believed to hold ~1024 entries (max_position_embeddings in the model
        # config; confirm), so the previous 4096 could overrun it on text-dense images.
        max_new_tokens=1024,
        num_beams=3,
        do_sample=False,
    )
    return ocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


def _derive_context(caption, extracted_text):
    """Ask the extractive QA model for a span of caption + OCR text to speak.

    The combined string acts as the QA "knowledge base". If the model returns
    a blank answer, *caption* is used instead so that the TTS step always
    receives non-empty text.
    """
    combined_context = f"Caption: {caption}\nExtracted Text: {extracted_text}"
    qa_result = qa_model(
        question="What is the context of this image?",
        context=combined_context,
    )
    answer = qa_result["answer"]
    return answer if answer.strip() else caption


def _synthesize(text):
    """Voice *text* with SpeechT5 and return ``(sampling_rate, samples)`` for gr.Audio."""
    speech_data = tts_pipeline(
        text,
        forward_params={"speaker_embeddings": speaker_embedding},
    )
    return speech_data["sampling_rate"], np.array(speech_data["audio"])


def process_image(image):
    """Caption, OCR, summarise and voice an uploaded image (Gradio callback).

    Args:
        image: The PIL image from the ``gr.Image(type='pil')`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        A 4-tuple ``(audio, caption, extracted_text, final_context)``, where
        ``audio`` is ``(sampling_rate, samples)``. On failure ``audio`` is
        ``None``, the caption slot holds an ``"Error: ..."`` message and the
        remaining two slots are empty strings.
    """
    if image is None:
        return None, "Error: please upload an image.", "", ""
    try:
        caption = caption_model(image)[0]["generated_text"]
        extracted_text = _run_ocr(image)
        final_context = _derive_context(caption, extracted_text)
        audio = _synthesize(final_context)
        return audio, caption, extracted_text, final_context
    except Exception as e:
        # UI boundary: keep the app responsive for the user, but record the
        # full traceback in the server log instead of discarding it.
        logging.getLogger(__name__).exception("process_image failed")
        return None, f"Error: {str(e)}", "", ""
# Gradio Interface: one image in; audio, caption, OCR text and QA-derived
# context out (the order of `outputs` matches process_image's return tuple).
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type='pil', label="Upload an Image"),
    outputs=[
        gr.Audio(label="Generated Audio"),
        gr.Textbox(label="Generated Caption"),
        gr.Textbox(label="Extracted Text (OCR)"),
        gr.Textbox(label="QA-derived Context"),
    ],
    title="Contextual Image QA with SpeechT5",
    description=(
        "1) Generate a caption via BLIP. "
        "2) Extract text using Florence-2. "
        "3) Use QA with mDeBERTa to find a 'context' from caption + text. "
        "4) Convert it to audio via SpeechT5."
    ),
)

if __name__ == "__main__":
    # Start the server only when this file is executed as a script
    # (presumably how the Space launches app.py), so that importing the
    # module does not block on launch().
    iface.launch()