# Source: Hugging Face file view by preston-cell — commit f4f3543 ("Update app.py", verified), 4.25 kB.
import gradio as gr
from transformers import (
pipeline,
AutoProcessor,
AutoModelForCausalLM,
AutoTokenizer,
GenerationConfig,
set_seed
)
from datasets import load_dataset
import torch
import numpy as np
# Fix the transformers RNG seed so sampled generations are repeatable between runs.
set_seed(42)

# Pick the compute target once: GPU + half precision when CUDA is present,
# otherwise CPU + full precision.
_has_cuda = torch.cuda.is_available()
device = "cuda" if _has_cuda else "cpu"
dtype = torch.float16 if _has_cuda else torch.float32
# --- Model loading ------------------------------------------------------------
# Hub checkpoint identifiers, kept together so each stage's model is easy to find.
_CAPTION_CHECKPOINT = "Salesforce/blip-image-captioning-base"
_TTS_CHECKPOINT = "microsoft/speecht5_tts"
_OCR_CHECKPOINT = "microsoft/Florence-2-large"
_CONTEXT_CHECKPOINT = "SmallDoge/Doge-320M-Instruct"

# Image captioning (BLIP) and text-to-speech (SpeechT5) as ready-made pipelines.
# NOTE(review): no device= is passed here, so these presumably run on CPU even
# when CUDA is available — confirm whether that is intended.
caption_model = pipeline("image-to-text", model=_CAPTION_CHECKPOINT)
synthesiser = pipeline("text-to-speech", model=_TTS_CHECKPOINT)

# OCR model (Florence-2), loaded in the selected dtype and moved to the selected
# device. trust_remote_code=True lets the checkpoint's own modeling code run.
ocr_model = AutoModelForCausalLM.from_pretrained(
    _OCR_CHECKPOINT,
    torch_dtype=dtype,
    trust_remote_code=True,
).to(device)
ocr_processor = AutoProcessor.from_pretrained(_OCR_CHECKPOINT, trust_remote_code=True)

# Small instruction-tuned LM that turns caption + OCR text into a context description.
doge_tokenizer = AutoTokenizer.from_pretrained(_CONTEXT_CHECKPOINT)
doge_model = AutoModelForCausalLM.from_pretrained(
    _CONTEXT_CHECKPOINT,
    trust_remote_code=True,
).to(device)

# Nucleus sampling for the context model; repetition_penalty=1.0 applies no penalty.
doge_generation_config = GenerationConfig(
    max_new_tokens=100,
    use_cache=True,
    do_sample=True,
    temperature=0.8,
    top_p=0.9,
    repetition_penalty=1.0,
)
# --- Speaker embedding for SpeechT5 --------------------------------------------
# SpeechT5 conditions generation on a 512-dim speaker x-vector
# (SpeechT5Config.speaker_embedding_dim), and the CMU ARCTIC x-vectors are 512-dim.
# Requiring 600 dims meant no entry ever matched, and the 600-dim zero fallback
# then failed with a shape mismatch inside the TTS model on every request.
SPEAKER_EMBEDDING_DIM = 512


def _load_speaker_embedding(dim=SPEAKER_EMBEDDING_DIM):
    """Return the first CMU ARCTIC x-vector with at least *dim* values.

    The result is a (1, dim) float32 tensor. If no entry qualifies, a warning is
    printed and a (1, dim) zero vector is returned instead.
    """
    embedding_data = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
    for entry in embedding_data:
        vec = entry["xvector"]
        if len(vec) >= dim:
            return torch.tensor(vec[:dim], dtype=torch.float32).unsqueeze(0)
    print(f"⚠️ No suitable speaker embedding found. Using default {dim}-dim zero vector.")
    return torch.zeros(1, dim, dtype=torch.float32)


speaker_embedding = _load_speaker_embedding()

# Explicit check rather than `assert`, which is stripped under `python -O`.
if speaker_embedding.shape != (1, SPEAKER_EMBEDDING_DIM):
    raise ValueError(
        f"Expected shape (1, {SPEAKER_EMBEDDING_DIM}), got {tuple(speaker_embedding.shape)}"
    )
def _caption_image(image):
    """Return the caption produced by the image-captioning pipeline."""
    return caption_model(image)[0]["generated_text"]


def _extract_text(image):
    """Run Florence-2 with the "<OCR>" task prompt and return the decoded text."""
    inputs = ocr_processor(text="<OCR>", images=image, return_tensors="pt").to(device, dtype)
    generated_ids = ocr_model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=4096,
        num_beams=3,
        do_sample=False,
    )
    return ocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


def _generate_context(caption, extracted_text, max_ocr_chars=400):
    """Ask the Doge model for the image's context and return only its new text."""
    # Truncate the OCR text, not the assembled prompt: slicing the whole prompt
    # could cut off the instruction's trailing "Context:" cue when OCR text is long.
    prompt = (
        "Determine the context of this image based on the caption and extracted text.\n"
        f"Caption: {caption}\n"
        f"Extracted text: {extracted_text[:max_ocr_chars]}\n"
        "Context:"
    )
    conversation = [{"role": "user", "content": prompt}]
    doge_inputs = doge_tokenizer.apply_chat_template(
        conversation=conversation,
        tokenize=True,
        add_generation_prompt=True,  # open an assistant turn for the model to fill
        return_tensors="pt",
    ).to(device)
    doge_output = doge_model.generate(
        input_ids=doge_inputs,
        generation_config=doge_generation_config,
    )
    # generate() returns prompt tokens followed by the continuation; decode only
    # the continuation so the prompt is not echoed into the output and the speech.
    new_tokens = doge_output[0][doge_inputs.shape[-1]:]
    return doge_tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


def _synthesize_speech(text):
    """Speak *text* with SpeechT5; return the (sampling_rate, samples) pair Gradio expects."""
    speech = synthesiser(text, forward_params={"speaker_embeddings": speaker_embedding})
    return speech["sampling_rate"], np.array(speech["audio"])


def process_image(image):
    """Caption an image, OCR it, generate a context description, and speak it.

    Args:
        image: PIL image from the Gradio input (``None`` if nothing was uploaded).

    Returns:
        ``((rate, audio), caption, extracted_text, context)`` on success. On
        failure, returns ``(None, "Error: ...", "", "")`` so the UI displays the
        message in the caption box instead of crashing.
    """
    import traceback

    if image is None:
        return None, "Error: no image provided", "", ""
    try:
        caption = _caption_image(image)
        extracted_text = _extract_text(image)
        context = _generate_context(caption, extracted_text)
        # Avoid sending an empty string to the TTS pipeline; speak the caption instead.
        audio = _synthesize_speech(context or caption)
        return audio, caption, extracted_text, context
    except Exception as e:
        # UI boundary: keep the app alive, but log the traceback server-side.
        traceback.print_exc()
        return None, f"Error: {str(e)}", "", ""
# --- Gradio interface ------------------------------------------------------------
# One image input; the four outputs are in the same order as the tuple that
# process_image returns: (audio, caption, OCR text, generated context).
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type='pil', label="Upload an Image"),
    outputs=[
        gr.Audio(label="Generated Audio"),
        gr.Textbox(label="Generated Caption"),
        gr.Textbox(label="Extracted Text (OCR)"),
        gr.Textbox(label="Generated Context"),
    ],
    title="SeeSay Contextualizer with Doge-320M",
    description="Upload an image to caption it, extract text, generate context, and hear the result as speech.",
)

# Start the server only when run as a script (`python app.py`), so importing
# this module builds `iface` without launching it.
if __name__ == "__main__":
    iface.launch(share=True)