# Hugging Face Space — status: Running
import traceback

import gradio as gr
import numpy as np
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoProcessor,
    AutoTokenizer,
    GenerationConfig,
    pipeline,
    set_seed,
)
# Seed the Python / NumPy / PyTorch RNGs so sampled generations are repeatable.
set_seed(42)

# Pick the execution target: half precision on a CUDA GPU, full precision on CPU.
if torch.cuda.is_available():
    device, dtype = "cuda", torch.float16
else:
    device, dtype = "cpu", torch.float32
# --- Model loading -------------------------------------------------------------
# Image captioning: BLIP base through the high-level pipeline API.
caption_model = pipeline(task="image-to-text", model="Salesforce/blip-image-captioning-base")

# Text-to-speech: SpeechT5; the speaker embedding is passed in on every call.
synthesiser = pipeline(task="text-to-speech", model="microsoft/speecht5_tts")

# OCR: Florence-2 large ships custom modelling code, hence trust_remote_code.
# The weights are loaded in the module-level `dtype` and moved to `device`.
ocr_processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)
ocr_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Florence-2-large",
    trust_remote_code=True,
    torch_dtype=dtype,
).to(device)

# Context generation: the small Doge-320M instruct model.
doge_tokenizer = AutoTokenizer.from_pretrained("SmallDoge/Doge-320M-Instruct")
doge_model = AutoModelForCausalLM.from_pretrained(
    "SmallDoge/Doge-320M-Instruct",
    trust_remote_code=True,
).to(device)

# Nucleus sampling capped at 100 new tokens; repetition_penalty=1.0 means "no penalty".
doge_generation_config = GenerationConfig(
    do_sample=True,
    temperature=0.8,
    top_p=0.9,
    max_new_tokens=100,
    repetition_penalty=1.0,
    use_cache=True,
)
# --- Speaker embedding -----------------------------------------------------------
# SpeechT5 conditions synthesis on a speaker x-vector whose length must equal
# the model's speaker_embedding_dim (512 for microsoft/speecht5_tts); a vector
# of any other size makes the model's internal projection fail. The CMU ARCTIC
# x-vectors in this dataset are 512-dimensional.
SPEAKER_EMBEDDING_DIM = 512

speaker_embedding = None
embedding_data = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
for entry in embedding_data:
    vec = entry["xvector"]
    if len(vec) >= SPEAKER_EMBEDDING_DIM:
        # (1, dim): SpeechT5 expects a batch dimension on the speaker embedding.
        speaker_embedding = torch.tensor(
            vec[:SPEAKER_EMBEDDING_DIM], dtype=torch.float32
        ).unsqueeze(0)
        break

# Fallback: a zero vector of the correct size keeps TTS runnable when no
# usable x-vector was found.
if speaker_embedding is None:
    print(
        f"⚠️ No suitable speaker embedding found. "
        f"Using default {SPEAKER_EMBEDDING_DIM}-dim zero vector."
    )
    speaker_embedding = torch.zeros(1, SPEAKER_EMBEDDING_DIM, dtype=torch.float32)

# Explicit check rather than `assert`, which is stripped under `python -O`.
if tuple(speaker_embedding.shape) != (1, SPEAKER_EMBEDDING_DIM):
    raise ValueError(
        f"Expected shape (1, {SPEAKER_EMBEDDING_DIM}), "
        f"got {tuple(speaker_embedding.shape)}"
    )
# Length caps that keep each stage inside its model's limits.
# SpeechT5 (microsoft/speecht5_tts) embeds at most 600 text positions and its
# tokenizer is character-level, so longer text fails with a tensor-size
# mismatch; MAX_TTS_CHARS leaves headroom for special tokens.
# MAX_OCR_PROMPT_CHARS only bounds how much OCR text goes into the Doge prompt,
# so the "Context:" cue at the end of the prompt is never cut off.
MAX_OCR_PROMPT_CHARS = 450
MAX_TTS_CHARS = 500


def _run_ocr(image):
    """Return the text Florence-2 reads from *image* using the ``<OCR>`` task prompt."""
    inputs = ocr_processor(text="<OCR>", images=image, return_tensors="pt").to(device, dtype)
    generated_ids = ocr_model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=4096,
        num_beams=3,
        do_sample=False,
    )
    return ocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


def _generate_context(caption, extracted_text):
    """Ask Doge to describe the image's context; return only the newly generated text."""
    prompt = (
        "Determine the context of this image based on the caption and extracted text.\n"
        f"Caption: {caption}\n"
        f"Extracted text: {extracted_text[:MAX_OCR_PROMPT_CHARS]}\n"
        "Context:"
    )
    conversation = [{"role": "user", "content": prompt}]
    # add_generation_prompt appends the assistant-turn header so the instruct
    # model answers instead of continuing the user message.
    input_ids = doge_tokenizer.apply_chat_template(
        conversation=conversation,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(device)
    output_ids = doge_model.generate(
        input_ids=input_ids,
        generation_config=doge_generation_config,
    )
    # generate() returns prompt + continuation; keep only the continuation so
    # the prompt is neither displayed nor spoken.
    new_tokens = output_ids[0, input_ids.shape[-1]:]
    return doge_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


def _speak(text):
    """Synthesise *text* with SpeechT5; return ``(sampling_rate, waveform)`` for ``gr.Audio``."""
    speech = synthesiser(
        text[:MAX_TTS_CHARS],
        forward_params={"speaker_embeddings": speaker_embedding},
    )
    return speech["sampling_rate"], np.array(speech["audio"])


def process_image(image):
    """Caption, OCR, contextualise and voice a single image.

    Args:
        image: PIL image from the ``gr.Image(type='pil')`` input, or ``None``
            when the user submits without uploading anything.

    Returns:
        ``(audio, caption, extracted_text, context)`` matching the four Gradio
        outputs, where ``audio`` is ``(sampling_rate, waveform)``. If the
        context comes back empty, the caption is spoken instead. On failure
        returns ``(None, "Error: ...", "", "")`` so the message appears in the
        caption box; the traceback is printed to the server log.
    """
    if image is None:
        return None, "Error: please upload an image first.", "", ""
    try:
        caption = caption_model(image)[0]["generated_text"]
        extracted_text = _run_ocr(image)
        context = _generate_context(caption, extracted_text)
        audio = _speak(context or caption)
        return audio, caption, extracted_text, context
    except Exception as e:  # UI boundary: report to the user instead of crashing the app
        traceback.print_exc()
        return None, f"Error: {str(e)}", "", ""
# --- Gradio UI -------------------------------------------------------------------
# One PIL image in; four outputs back, in the same order as the tuple that
# process_image returns: audio, caption, OCR text, generated context.
iface = gr.Interface(
    title="SeeSay Contextualizer with Doge-320M",
    description="Upload an image to caption it, extract text, generate context, and hear the result as speech.",
    fn=process_image,
    inputs=gr.Image(type='pil', label="Upload an Image"),
    outputs=[
        gr.Audio(label="Generated Audio"),
        gr.Textbox(label="Generated Caption"),
        gr.Textbox(label="Extracted Text (OCR)"),
        gr.Textbox(label="Generated Context"),
    ],
)

# share=True requests a public share link when run outside Hugging Face Spaces.
iface.launch(share=True)