File size: 3,699 Bytes
f30b843
8c3caa4
 
ffbd81c
d0aa231
ffbd81c
8c3caa4
ffbd81c
8c3caa4
ffbd81c
602e80d
 
e6ea13d
ffbd81c
 
 
a483c36
ffbd81c
ed4af8f
83cd235
ffbd81c
e6ea13d
 
ffbd81c
 
 
 
 
a483c36
ffbd81c
d0aa231
ffbd81c
 
 
 
8c3caa4
 
 
 
 
 
 
e9b130c
ffbd81c
 
 
 
 
 
 
 
 
 
e6ea13d
602e80d
629e04f
ffbd81c
ed4af8f
602e80d
ffbd81c
 
 
 
 
 
 
 
d0aa231
ffbd81c
 
 
 
8c3caa4
 
 
 
 
a483c36
8c3caa4
ffbd81c
 
 
8c3caa4
ffbd81c
8c3caa4
ffbd81c
 
 
 
 
e6ea13d
 
602e80d
e6ea13d
602e80d
629e04f
e9b130c
26dbd13
ffbd81c
26dbd13
602e80d
ed4af8f
602e80d
 
efa273d
e9b130c
e6ea13d
602e80d
ffbd81c
 
629e04f
5c86456
ffbd81c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import gradio as gr
from transformers import (
    pipeline,
    AutoProcessor,
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    TextStreamer
)
from datasets import load_dataset
import torch
import numpy as np

# Use the first CUDA GPU in half precision when one is available;
# otherwise run on the CPU in full float32 precision.
if torch.cuda.is_available():
    device, torch_dtype = "cuda:0", torch.float16
else:
    device, torch_dtype = "cpu", torch.float32

# BLIP (base) produces the image caption; SpeechT5 turns the generated
# context into speech. Both are built with the high-level `pipeline` factory
# and no explicit device argument.
caption_model = pipeline(task="image-to-text", model="Salesforce/blip-image-captioning-base")
synthesiser = pipeline(task="text-to-speech", model="microsoft/speecht5_tts")

# Florence-2 (base) performs OCR. The model and its processor are both loaded
# with trust_remote_code=True; the model weights use `torch_dtype` and are
# moved onto `device`.
ocr_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Florence-2-base",
    trust_remote_code=True,
    torch_dtype=torch_dtype,
)
ocr_model = ocr_model.to(device)
ocr_processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-base",
    trust_remote_code=True,
)

# SmallDoge's Doge-320M-Instruct writes the "context" text that is later
# read aloud. The model is loaded with trust_remote_code=True and placed on
# `device` (default dtype).
doge_tokenizer = AutoTokenizer.from_pretrained("SmallDoge/Doge-320M-Instruct")
doge_model = AutoModelForCausalLM.from_pretrained(
    "SmallDoge/Doge-320M-Instruct",
    trust_remote_code=True,
)
doge_model = doge_model.to(device)

# Sampling setup: at most 100 new tokens, nucleus sampling (top_p=0.9) at
# temperature 0.8; repetition_penalty=1.0 applies no repetition penalty.
doge_config = GenerationConfig(
    do_sample=True,
    temperature=0.8,
    top_p=0.9,
    repetition_penalty=1.0,
    max_new_tokens=100,
    use_cache=True,
)

# Speaker embedding for SpeechT5.
# SpeechT5 conditions on an x-vector of size config.speaker_embedding_dim
# (512), and the cmu-arctic-xvectors entries are 512-dim. The previous
# ">= 600" search could therefore never match and raised at startup.
SPEAKER_EMBEDDING_DIM = 512

embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
embedding = None
# Take the first x-vector that is large enough, shaped (1, SPEAKER_EMBEDDING_DIM)
# so it can be passed directly as `speaker_embeddings`.
for entry in embeddings_dataset:
    vector = torch.tensor(entry["xvector"]).unsqueeze(0)
    if vector.shape[1] >= SPEAKER_EMBEDDING_DIM:
        embedding = vector[:, :SPEAKER_EMBEDDING_DIM]
        break
if embedding is None:
    raise ValueError(
        f"No suitable speaker embedding of at least {SPEAKER_EMBEDDING_DIM} dimensions found."
    )

def _caption_image(image):
    """Return the BLIP caption generated for *image*."""
    return caption_model(image)[0]["generated_text"]


def _extract_text(image):
    """Run Florence-2 with the ``<OCR>`` task prompt and return the decoded text."""
    ocr_inputs = ocr_processor(text="<OCR>", images=image, return_tensors="pt").to(device, torch_dtype)
    generated_ids = ocr_model.generate(
        input_ids=ocr_inputs["input_ids"],
        pixel_values=ocr_inputs["pixel_values"],
        max_new_tokens=1024,
        num_beams=3,
        do_sample=False,
    )
    return ocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


def _generate_context(caption, extracted_text):
    """Ask the Doge instruct model to describe the image's context.

    Returns only the model's newly generated reply, with surrounding
    whitespace stripped (possibly an empty string).
    """
    prompt = f"Determine the context of this image. Caption: {caption} Extracted text: {extracted_text}"
    conversation = [{"role": "user", "content": prompt}]
    doge_inputs = doge_tokenizer.apply_chat_template(
        conversation=conversation,
        tokenize=True,
        add_generation_prompt=True,  # open an assistant turn so the model replies
        return_tensors="pt",
    ).to(device)
    outputs = doge_model.generate(doge_inputs, generation_config=doge_config)
    # For a decoder-only model, generate() returns prompt + continuation.
    # Decode only the continuation so the prompt is not read aloud as "context".
    reply_ids = outputs[0][doge_inputs.shape[-1]:]
    return doge_tokenizer.decode(reply_ids, skip_special_tokens=True).strip()


def _synthesize_speech(text):
    """Speak *text* with SpeechT5 and return ``(sampling_rate, waveform)``."""
    speech = synthesiser(text, forward_params={"speaker_embeddings": embedding})
    return speech["sampling_rate"], np.array(speech["audio"])


def process_image(image):
    """Caption, OCR, contextualize and voice a single image.

    Args:
        image: PIL image from the Gradio input, or ``None`` when nothing was
            uploaded.

    Returns:
        A 4-tuple ``(audio, caption, extracted_text, context)`` matching the
        Gradio outputs. ``audio`` is ``(sampling_rate, waveform)``, or ``None``
        when no speech was produced. On failure the error message is returned
        in the caption slot and the remaining outputs are empty.
    """
    if image is None:
        return None, "Error: no image provided.", "", ""
    try:
        caption = _caption_image(image)
        extracted_text = _extract_text(image)
        context = _generate_context(caption, extracted_text)
        # Skip TTS when the language model produced no text to speak.
        audio = _synthesize_speech(context) if context else None
        return audio, caption, extracted_text, context
    except Exception as e:  # UI boundary: show the error instead of crashing the app
        return None, f"Error: {e}", "", ""

# Gradio UI: one PIL image in; the generated audio plus the three
# intermediate texts out, in the same order process_image returns them.
output_components = [
    gr.Audio(label="Generated Audio"),
    gr.Textbox(label="Generated Caption"),
    gr.Textbox(label="Extracted Text (OCR)"),
    gr.Textbox(label="Generated Context"),
]

iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", label="Upload an Image"),
    outputs=output_components,
    title="SeeSay Contextualizer with Doge & BLIP",
    description=(
        "Upload an image to generate a caption, extract text, "
        "determine context, and convert it to audio."
    ),
)

iface.launch()