File size: 4,247 Bytes
f30b843
8c3caa4
 
 
 
 
 
 
 
e6ea13d
602e80d
 
608498c
e6ea13d
 
a483c36
 
 
 
 
ed4af8f
83cd235
a483c36
e6ea13d
 
a483c36
8c3caa4
a483c36
 
efa273d
 
a483c36
8c3caa4
 
a483c36
 
8c3caa4
 
 
 
 
 
 
 
e9b130c
f4f3543
a483c36
 
f4f3543
a483c36
 
 
f4f3543
a483c36
f4f3543
 
a483c36
f4f3543
 
 
 
a483c36
dbabbd4
e6ea13d
602e80d
629e04f
a483c36
ed4af8f
602e80d
a483c36
 
efa273d
 
 
 
 
 
 
 
 
a483c36
8c3caa4
f4f3543
8c3caa4
 
 
 
 
a483c36
8c3caa4
a483c36
 
8c3caa4
 
a483c36
8c3caa4
a483c36
e6ea13d
 
f7203e8
 
e9b130c
e6ea13d
 
602e80d
e6ea13d
602e80d
629e04f
e9b130c
26dbd13
e6ea13d
602e80d
26dbd13
602e80d
ed4af8f
602e80d
 
efa273d
e9b130c
e6ea13d
602e80d
6a2189c
a483c36
629e04f
5c86456
a483c36
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import gradio as gr
from transformers import (
    pipeline,
    AutoProcessor,
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    set_seed
)
from datasets import load_dataset
import torch
import numpy as np

# Seed the RNGs up front (generation below is configured with sampling).
set_seed(42)

# Device and dtype: probe CUDA once; half precision is only chosen on GPU.
_cuda_ok = torch.cuda.is_available()
device = "cuda" if _cuda_ok else "cpu"
dtype = torch.float16 if _cuda_ok else torch.float32

# Image captioning model (BLIP), loaded through the pipeline helper.
caption_model = pipeline(
    task="image-to-text",
    model="Salesforce/blip-image-captioning-base",
)

# SpeechT5 text-to-speech model, also loaded through the pipeline helper.
synthesiser = pipeline(
    task="text-to-speech",
    model="microsoft/speecht5_tts",
)

# OCR model (Florence-2): model weights plus the matching processor.
_FLORENCE_ID = "microsoft/Florence-2-large"
ocr_model = AutoModelForCausalLM.from_pretrained(
    _FLORENCE_ID,
    torch_dtype=dtype,
    trust_remote_code=True,
)
ocr_model = ocr_model.to(device)
ocr_processor = AutoProcessor.from_pretrained(_FLORENCE_ID, trust_remote_code=True)

# Doge-320M-Instruct for context generation: tokenizer, model, sampling setup.
_DOGE_ID = "SmallDoge/Doge-320M-Instruct"
doge_tokenizer = AutoTokenizer.from_pretrained(_DOGE_ID)
doge_model = AutoModelForCausalLM.from_pretrained(_DOGE_ID, trust_remote_code=True)
doge_model = doge_model.to(device)

_doge_sampling = {
    "max_new_tokens": 100,
    "use_cache": True,
    "do_sample": True,
    "temperature": 0.8,
    "top_p": 0.9,
    "repetition_penalty": 1.0,
}
doge_generation_config = GenerationConfig(**_doge_sampling)

# Load speaker embedding with fallback.
# SpeechT5 conditions its decoder on a speaker x-vector whose width must match
# the model's speaker-embedding dimension: 512 for microsoft/speecht5_tts, which
# is also the width of the CMU Arctic x-vectors loaded here. The previous code
# looked for >= 600 values, which never matched and always fell back to a
# (1, 600) zero vector that the model cannot consume.
SPEAKER_EMBEDDING_DIM = 512

embedding_data = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")

# Take the first x-vector of the expected width, shaped (1, dim) for the model.
speaker_embedding = None
for entry in embedding_data:
    vec = entry["xvector"]
    if len(vec) == SPEAKER_EMBEDDING_DIM:
        speaker_embedding = torch.tensor(vec, dtype=torch.float32).unsqueeze(0)
        break

# Fallback: use a zero vector if none found
if speaker_embedding is None:
    print(
        "⚠️ No suitable speaker embedding found. "
        f"Using default {SPEAKER_EMBEDDING_DIM}-dim zero vector."
    )
    speaker_embedding = torch.zeros(1, SPEAKER_EMBEDDING_DIM, dtype=torch.float32)

# Ensure correct shape. Raise explicitly rather than `assert`, which is
# stripped when Python runs with -O.
if speaker_embedding.shape != (1, SPEAKER_EMBEDDING_DIM):
    raise ValueError(
        f"Expected shape (1, {SPEAKER_EMBEDDING_DIM}), got {tuple(speaker_embedding.shape)}"
    )


def process_image(image):
    """Caption an image, OCR it, generate a short context, and speak the context.

    Args:
        image: The uploaded picture (the Gradio input is declared with
            ``type='pil'``), or ``None`` when nothing was uploaded.

    Returns:
        A 4-tuple ``(audio, caption, extracted_text, context)`` where ``audio``
        is ``(sampling_rate, samples)``. If anything fails, the tuple is
        ``(None, "Error: <message>", "", "")`` so the UI shows the problem in
        the caption box instead of crashing.
    """
    # Nothing to process; report it in the same shape as other failures.
    if image is None:
        return None, "Error: no image provided", "", ""

    try:
        # 1. Caption the image
        caption = caption_model(image)[0]['generated_text']

        # 2. OCR with Florence-2
        inputs = ocr_processor(text="<OCR>", images=image, return_tensors="pt").to(device, dtype)
        generated_ids = ocr_model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=4096,
            num_beams=3,
            do_sample=False
        )
        extracted_text = ocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        # 3. Prompt Doge model for context generation
        prompt = f"Determine the context of this image based on the caption and extracted text.\nCaption: {caption}\nExtracted text: {extracted_text}\nContext:"
        prompt = prompt[:600]  # Cap prompt length (original guard against a tensor mismatch error)
        conversation = [{"role": "user", "content": prompt}]
        doge_inputs = doge_tokenizer.apply_chat_template(
            conversation=conversation,
            tokenize=True,
            return_tensors="pt"
        ).to(device)

        doge_output = doge_model.generate(
            input_ids=doge_inputs,
            generation_config=doge_generation_config
        )
        # A causal LM returns the prompt tokens followed by the new ones. Decode
        # only the continuation so the context (and the speech made from it)
        # does not repeat the whole prompt.
        prompt_len = doge_inputs.shape[-1]
        new_tokens = doge_output[0][prompt_len:]
        context = doge_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

        # 4. Convert context to speech
        speech = synthesiser(
            context,
            forward_params={"speaker_embeddings": speaker_embedding}
        )

        audio = np.array(speech["audio"])
        rate = speech["sampling_rate"]

        return (rate, audio), caption, extracted_text, context

    except Exception as e:
        # Top-level UI boundary: surface the failure in the caption textbox.
        return None, f"Error: {str(e)}", "", ""


# Gradio Interface: one image in; audio plus three text fields out, in the
# same order as the tuple returned by process_image.
_image_input = gr.Image(type='pil', label="Upload an Image")

_result_outputs = [
    gr.Audio(label="Generated Audio"),
    gr.Textbox(label="Generated Caption"),
    gr.Textbox(label="Extracted Text (OCR)"),
    gr.Textbox(label="Generated Context"),
]

iface = gr.Interface(
    title="SeeSay Contextualizer with Doge-320M",
    description="Upload an image to caption it, extract text, generate context, and hear the result as speech.",
    fn=process_image,
    inputs=_image_input,
    outputs=_result_outputs,
)

# Start the app, requesting a shareable link.
iface.launch(share=True)