import gradio as gr
from transformers import pipeline, AutoProcessor, AutoModelForCausalLM
import torch
import numpy as np
from datasets import load_dataset
from PIL import Image

# 1) IMAGE CAPTIONING MODEL
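# BLIP via the image-to-text pipeline returns a list of {"generated_text": ...} dicts per image.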
caption_model = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")

# 2) OCR MODEL (Florence-2)
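# Florence-2 is steered with task-prompt tokens; the "<OCR>" prompt used in process_image
# requests plain text extraction.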
ocr_device = "cuda:0" if torch.cuda.is_available() else "cpu"
ocr_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
ocr_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Florence-2-large",
    torch_dtype=ocr_dtype,
    trust_remote_code=True
).to(ocr_device)
ocr_processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)

# 3) QUESTION-ANSWERING MODEL
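# mDeBERTa-v3 fine-tuned on SQuAD2 is an extractive QA model: it selects an answer span
# from the supplied context rather than generating free text.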
qa_model = pipeline(
    "question-answering",
    model="timpal0l/mdeberta-v3-base-squad2"
)

# 4) TEXT-TO-SPEECH MODEL
tts_pipeline = pipeline("text-to-speech", model="microsoft/speecht5_tts")

# Load speaker embedding
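# SpeechT5 is a multi-speaker model and requires an x-vector speaker embedding at inference time.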
embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker_embedding = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
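# Index 7306 is the female US English voice used in the SpeechT5 examples; other rows of the dataset work as well.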

def process_image(image: Image.Image):
    try:
        # 1) Generate caption from the image
        caption = caption_model(image)[0]['generated_text']

        # 2) Extract text from the image using Florence-2
        inputs = ocr_processor(text="<OCR>", images=image, return_tensors="pt").to(ocr_device, ocr_dtype)
        generated_ids = ocr_model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=4096,
            num_beams=3,
            do_sample=False
        )
        extracted_text = ocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
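        # Note: the Florence-2 processor (loaded with trust_remote_code) also exposes a
        # post_process_generation() helper for task-aware parsing; plain batch_decode with
        # skip_special_tokens=True is used here for simplicity.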

        # 3) Use QA model to derive context from caption + extracted text
        # We treat the "context" string as the knowledge base and ask a question about it
        question = "What is the context of this image?"
        combined_context = f"Caption: {caption}\nExtracted Text: {extracted_text}"
        qa_result = qa_model(question=question, context=combined_context)

        # The QA model returns an extracted "answer" from the combined context
        # If the model can't find a direct span, it may return an empty string or a short phrase
        final_context = qa_result["answer"]

        # 4) Convert the final context to speech
        speech_data = tts_pipeline(
            final_context,
            forward_params={"speaker_embeddings": speaker_embedding}
        )

        # Prepare audio data for Gradio
        audio = np.array(speech_data["audio"])
        rate = speech_data["sampling_rate"]

        # Return audio, caption, extracted text, and final context
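        # (gr.Audio accepts a (sample_rate, numpy array) tuple as its value)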
        return (rate, audio), caption, extracted_text, final_context

    except Exception as e:
        return None, f"Error: {str(e)}", "", ""

# Gradio Interface
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type='pil', label="Upload an Image"),
    outputs=[
        gr.Audio(label="Generated Audio"),
        gr.Textbox(label="Generated Caption"),
        gr.Textbox(label="Extracted Text (OCR)"),
        gr.Textbox(label="QA-derived Context")
    ],
    title="Contextual Image QA with SpeechT5",
    description=(
        "1) Generate a caption via BLIP. "
        "2) Extract text using Florence-2. "
        "3) Use QA with mDeBERTa to find a 'context' from caption + text. "
        "4) Convert it to audio via SpeechT5."
    ),
)

iface.launch()
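
# When running outside a notebook or hosted Space, launch(share=True) can be used
# to get a temporary public URL:
# iface.launch(share=True)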