import gradio as gr
from transformers import pipeline, AutoProcessor, AutoModelForCausalLM
import torch
import numpy as np
from datasets import load_dataset
from PIL import Image

# 1) IMAGE CAPTIONING MODEL
caption_model = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")

# 2) OCR MODEL (Florence-2)
ocr_device = "cuda:0" if torch.cuda.is_available() else "cpu"
ocr_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
ocr_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Florence-2-large",
    torch_dtype=ocr_dtype,
    trust_remote_code=True
).to(ocr_device)
ocr_processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)

# 3) QUESTION-ANSWERING MODEL
qa_model = pipeline(
    "question-answering",
    model="timpal0l/mdeberta-v3-base-squad2"
)

# 4) TEXT-TO-SPEECH MODEL
tts_pipeline = pipeline("text-to-speech", model="microsoft/speecht5_tts")

# Load speaker embedding
embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker_embedding = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
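# Note: index 7306 is the x-vector commonly used in the SpeechT5 examples
# (a US English female voice); any other row of the cmu-arctic-xvectors
# validation split can be substituted to change the speaker.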

def process_image(image):
    try:
        # 1) Generate caption from the image
        caption = caption_model(image)[0]['generated_text']

        # 2) Extract text from the image using Florence-2
        inputs = ocr_processor(text="<OCR>", images=image, return_tensors="pt").to(ocr_device, ocr_dtype)
        generated_ids = ocr_model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=4096,
            num_beams=3,
            do_sample=False
        )
        extracted_text = ocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
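        # Note: Florence-2's processor also provides a task-aware parser,
        # ocr_processor.post_process_generation(text, task="<OCR>", image_size=(w, h)),
        # which could replace the plain batch_decode above; the simpler decode is
        # kept here on the assumption that raw OCR text is all that is needed.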

        # 3) Use the QA model to derive a context from the caption + extracted text.
        # The combined string is treated as the knowledge base, and a question is asked about it.
        question = "What is the context of this image?"
        combined_context = f"Caption: {caption}\nExtracted Text: {extracted_text}"
        qa_result = qa_model(question=question, context=combined_context)

        # The QA model returns an answer span extracted from the combined context.
        # If it can't find a direct span, it may return an empty string or a short phrase.
        final_context = qa_result["answer"]
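        # Note: qa_result also includes a confidence "score" (plus "start"/"end"
        # character offsets). A possible extension, not in the original app, is to
        # fall back to the caption when confidence is low, e.g.:
        #     if qa_result["score"] < 0.1:
        #         final_context = caption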

        # 4) Convert the final context to speech
        speech_data = tts_pipeline(
            final_context,
            forward_params={"speaker_embeddings": speaker_embedding}
        )

        # Prepare audio data for Gradio
        audio = np.array(speech_data["audio"])
        rate = speech_data["sampling_rate"]

        # Return audio, caption, extracted text, and final context
        return (rate, audio), caption, extracted_text, final_context

    except Exception as e:
        return None, f"Error: {str(e)}", "", ""

# Gradio Interface
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type='pil', label="Upload an Image"),
    outputs=[
        gr.Audio(label="Generated Audio"),
        gr.Textbox(label="Generated Caption"),
        gr.Textbox(label="Extracted Text (OCR)"),
        gr.Textbox(label="QA-derived Context")
    ],
    title="Contextual Image QA with SpeechT5",
    description=(
        "1) Generate a caption via BLIP. "
        "2) Extract text using Florence-2. "
        "3) Use QA with mDeBERTa to find a 'context' from caption + text. "
        "4) Convert it to audio via SpeechT5."
    ),
)

iface.launch()
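# Assumed runtime dependencies (not pinned in this file): gradio, transformers,
# torch, numpy, datasets, pillow, plus sentencepiece for the SpeechT5 tokenizer;
# Florence-2's trust_remote_code path typically also needs einops and timm.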