File size: 3,685 Bytes
f30b843
1b524ac
 
29f7833
a0edfdb
 
 
 
 
 
 
 
 
e6ea13d
a0edfdb
 
 
 
a483c36
a0edfdb
 
83cd235
a0edfdb
29f7833
a0edfdb
29f7833
ffbd81c
d0aa231
a0edfdb
 
 
 
 
 
8c3caa4
68bf04e
8c3caa4
 
 
1b524ac
8c3caa4
e9b130c
a0edfdb
 
66d96fc
a0edfdb
 
 
29f7833
602e80d
629e04f
a0edfdb
68bf04e
602e80d
a0edfdb
 
 
 
 
29f7833
 
a0edfdb
d0aa231
a0edfdb
ffbd81c
a0edfdb
 
8c3caa4
a0edfdb
 
 
29f7833
1b524ac
a0edfdb
 
 
 
e6ea13d
 
602e80d
e6ea13d
602e80d
629e04f
e9b130c
26dbd13
a0edfdb
26dbd13
602e80d
68bf04e
602e80d
 
efa273d
e9b130c
68bf04e
602e80d
a0edfdb
 
629e04f
5c86456
68bf04e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import gradio as gr
import torch
import numpy as np
from PIL import Image
from transformers import (
    pipeline,
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    GenerationConfig,
    TextStreamer,
)
from datasets import load_dataset

# Pick the execution device once: CUDA with half precision when a GPU is
# present, otherwise CPU with float32.
_cuda_available = torch.cuda.is_available()
device = "cuda" if _cuda_available else "cpu"
dtype = torch.float16 if _cuda_available else torch.float32
print(f"Device set to use {device}")

# Hub identifiers, named once so each model and its processor/tokenizer
# cannot drift apart.
CAPTION_MODEL_ID = "Salesforce/blip-image-captioning-base"
OCR_MODEL_ID = "microsoft/Florence-2-base"
DOGE_MODEL_ID = "SmallDoge/Doge-320M-Instruct"
TTS_MODEL_ID = "microsoft/speecht5_tts"
SPEAKER_EMBEDDINGS_DATASET = "Matthijs/cmu-arctic-xvectors"
SPEAKER_EMBEDDING_INDEX = 7306

# Load image captioning model (BLIP)
caption_model = pipeline("image-to-text", model=CAPTION_MODEL_ID, device=device)

# Load OCR model (Florence-2-base). This is the only model loaded with
# `dtype`, so inputs fed to it must be cast to the same dtype.
ocr_model = AutoModelForCausalLM.from_pretrained(
    OCR_MODEL_ID, trust_remote_code=True, torch_dtype=dtype
).to(device)
ocr_processor = AutoProcessor.from_pretrained(OCR_MODEL_ID, trust_remote_code=True)

# Load SmallDoge model for context generation
doge_tokenizer = AutoTokenizer.from_pretrained(DOGE_MODEL_ID)
doge_model = AutoModelForCausalLM.from_pretrained(
    DOGE_MODEL_ID, trust_remote_code=True
).to(device)
# Sampling settings for the context text; max_new_tokens keeps the output
# short because it is later passed to text-to-speech.
doge_config = GenerationConfig(
    max_new_tokens=100,
    use_cache=True,
    do_sample=True,
    temperature=0.8,
    top_p=0.9,
    repetition_penalty=1.0
)

# Load SpeechT5 for TTS
synthesiser = pipeline("text-to-speech", model=TTS_MODEL_ID, device=device)

# SpeechT5 is conditioned on an x-vector speaker embedding. The CMU ARCTIC
# x-vectors are 512-dimensional (not 600); unsqueeze adds the batch axis.
embeddings_dataset = load_dataset(SPEAKER_EMBEDDINGS_DATASET, split="validation")
speaker_embedding = torch.tensor(
    embeddings_dataset[SPEAKER_EMBEDDING_INDEX]["xvector"]
).unsqueeze(0)  # Shape: [1, 512]

def process_image(image):
    """Caption an image, OCR it, infer its context, and speak that context.

    Args:
        image: PIL image from the Gradio ``Image`` input, or ``None`` when
            the user submits without uploading anything.

    Returns:
        A 4-tuple ``(audio, caption, extracted_text, context)`` matching the
        four Gradio outputs. ``audio`` is ``(sampling_rate, waveform)``.
        On any failure ``audio`` is ``None``, the caption slot carries an
        ``"Error: ..."`` message, and the remaining slots are empty strings.
    """
    # Gradio passes None when nothing was uploaded; report it directly
    # instead of surfacing an opaque error from inside a model call.
    if image is None:
        return None, "Error: no image provided", "", ""

    try:
        # Uploads may be RGBA, palette or grayscale; normalise to 3-channel
        # RGB before handing the image to the vision models.
        image = image.convert("RGB")

        # Caption generation
        caption = caption_model(image)[0]['generated_text']

        # OCR extraction. Cast the inputs to the model's dtype as well as its
        # device: the OCR model is float16 on CUDA and float32 pixel values
        # would otherwise raise a dtype mismatch. Integer tensors such as
        # input_ids are left uncast by BatchFeature.to().
        ocr_inputs = ocr_processor(
            text="<OCR>", images=image, return_tensors="pt"
        ).to(device, dtype)
        ocr_outputs = ocr_model.generate(
            input_ids=ocr_inputs["input_ids"],
            pixel_values=ocr_inputs["pixel_values"],
            max_new_tokens=1024,
            num_beams=3,
            do_sample=False,
        )
        extracted_text = ocr_processor.batch_decode(ocr_outputs, skip_special_tokens=True)[0]

        # Context generation using Doge
        prompt = f"Determine the context of this image based on the caption and extracted text.\nCaption: {caption}\nExtracted text: {extracted_text}\nContext:"
        conversation = [{"role": "user", "content": prompt}]
        # add_generation_prompt cues the model to start the assistant turn
        # rather than continuing the user message.
        doge_inputs = doge_tokenizer.apply_chat_template(
            conversation,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(device)
        doge_output = doge_model.generate(doge_inputs, generation_config=doge_config)
        # generate() returns prompt + completion. Keep only the newly
        # generated tokens so the prompt is neither displayed nor spoken,
        # which also keeps the TTS input short.
        prompt_length = doge_inputs.shape[-1]
        context = doge_tokenizer.decode(
            doge_output[0][prompt_length:], skip_special_tokens=True
        ).strip()

        # Convert context to speech
        speech = synthesiser(
            context,
            forward_params={"speaker_embeddings": speaker_embedding}
        )
        audio = np.array(speech["audio"])
        rate = speech["sampling_rate"]

        return (rate, audio), caption, extracted_text, context

    except Exception as e:
        # UI boundary: show the failure in the caption box instead of
        # crashing the Gradio worker.
        return None, f"Error: {str(e)}", "", ""

# Gradio UI: a single image input wired to process_image, whose four return
# values map positionally onto the four output components below.
image_input = gr.Image(type='pil', label="Upload an Image")
output_components = [
    gr.Audio(label="Generated Audio"),
    gr.Textbox(label="Generated Caption"),
    gr.Textbox(label="Extracted Text (OCR)"),
    gr.Textbox(label="Generated Context"),
]

iface = gr.Interface(
    fn=process_image,
    inputs=image_input,
    outputs=output_components,
    title="SeeSay Contextualizer with Doge + SpeechT5",
    description="Upload an image to generate a caption, extract OCR text, determine context with Doge-320M, and hear it with SpeechT5.",
)

# Start the web server; share=True requests a public share link.
iface.launch(share=True)