File size: 3,168 Bytes
f30b843
e6ea13d
 
602e80d
 
608498c
e6ea13d
 
 
 
ed4af8f
83cd235
e6ea13d
 
 
 
efa273d
 
e6ea13d
efa273d
 
e6ea13d
 
e9b130c
efa273d
602e80d
 
 
e6ea13d
602e80d
629e04f
e6ea13d
ed4af8f
602e80d
e6ea13d
efa273d
 
 
 
 
 
 
 
 
 
e6ea13d
 
 
 
f7203e8
e6ea13d
 
 
f7203e8
 
e9b130c
e6ea13d
 
 
602e80d
e6ea13d
 
602e80d
629e04f
e9b130c
26dbd13
e6ea13d
602e80d
26dbd13
602e80d
ed4af8f
602e80d
 
efa273d
e9b130c
e6ea13d
602e80d
e6ea13d
 
629e04f
5c86456
ed4af8f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import gradio as gr
from transformers import pipeline, AutoProcessor, AutoModelForCausalLM, AutoTokenizer, set_seed
from datasets import load_dataset
import torch
import numpy as np

# Fix the global RNG seed (same value on every start-up)
set_seed(42)

# Image captioning: BLIP
caption_model = pipeline(
    task="image-to-text",
    model="Salesforce/blip-image-captioning-base",
)

# Text-to-speech: SpeechT5
synthesiser = pipeline(
    task="text-to-speech",
    model="microsoft/speecht5_tts",
)

# OCR: Florence-2, placed on the GPU in half precision when CUDA is
# available, otherwise on the CPU in full precision
_FLORENCE_CHECKPOINT = "microsoft/Florence-2-large"
if torch.cuda.is_available():
    ocr_device, ocr_dtype = "cuda:0", torch.float16
else:
    ocr_device, ocr_dtype = "cpu", torch.float32
ocr_model = AutoModelForCausalLM.from_pretrained(
    _FLORENCE_CHECKPOINT,
    torch_dtype=ocr_dtype,
    trust_remote_code=True,
)
ocr_model = ocr_model.to(ocr_device)
ocr_processor = AutoProcessor.from_pretrained(
    _FLORENCE_CHECKPOINT,
    trust_remote_code=True,
)

# Text generation: GPT-2 (124M)
gpt2_generator = pipeline(task="text-generation", model="gpt2")

# Speaker x-vector handed to SpeechT5; a leading batch axis is added
embeddings_dataset = load_dataset(
    "Matthijs/cmu-arctic-xvectors",
    split="validation",
)
_xvector = embeddings_dataset[7306]["xvector"]
speaker_embedding = torch.tensor(_xvector).unsqueeze(0)


def process_image(image):
    """Caption an image, OCR it, generate a short context and speak it.

    The image is run through the module-level captioning pipeline and the
    Florence-2 OCR model; both results are folded into a GPT-2 prompt, and
    the generated continuation is synthesised to audio.

    Args:
        image: The uploaded image (the interface is configured to deliver
            a PIL image), or ``None`` when nothing was uploaded.

    Returns:
        A 4-tuple ``((sampling_rate, audio), caption, extracted_text,
        context)`` on success. On failure the tuple is
        ``(None, "Error: <message>", "", "")`` so the message appears in the
        caption field of the UI.
    """
    if image is None:
        # Submitting the form without an upload yields None; report that
        # plainly instead of surfacing an exception from a model call.
        return None, "Error: no image provided", "", ""

    # Budget for the generated continuation only (prompt tokens excluded).
    context_new_tokens = 64

    try:
        # Generate caption from the image
        caption = caption_model(image)[0]['generated_text']

        # Extract text (OCR) using Florence-2
        inputs = ocr_processor(text="<OCR>", images=image, return_tensors="pt").to(ocr_device, ocr_dtype)
        generated_ids = ocr_model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=4096,
            num_beams=3,
            do_sample=False
        )
        extracted_text = ocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        # Generate context using GPT-2 (124M)
        prompt = f"Determine the context of this image based on the caption and extracted text. Caption: {caption}. Extracted text: {extracted_text}. Context:"
        context_output = gpt2_generator(
            prompt,
            # max_length would count the prompt too, leaving no room to
            # generate once the caption/OCR text is long.
            max_new_tokens=context_new_tokens,
            num_return_sequences=1,
            # Keep only the continuation so the instruction prompt is neither
            # displayed as the "context" nor read aloud.
            return_full_text=False,
            # Left-truncate prompts that would overflow the model window.
            handle_long_generation="hole",
        )
        context = context_output[0]['generated_text'].strip()
        if not context:
            # Nothing was generated; fall back to the caption so the
            # synthesiser is not handed an empty string.
            context = caption

        # Convert context to speech
        speech = synthesiser(
            context,
            forward_params={"speaker_embeddings": speaker_embedding}
        )

        # Prepare audio data
        audio = np.array(speech["audio"])
        rate = speech["sampling_rate"]

        # Return audio, caption, extracted text, and context
        return (rate, audio), caption, extracted_text, context

    except Exception as e:
        # UI boundary: report the failure in the caption field rather than
        # letting the request crash.
        return None, f"Error: {str(e)}", "", ""


# Gradio Interface: one image input, four outputs matching the order of the
# tuple returned by process_image (audio, caption, OCR text, context)
_image_input = gr.Image(type="pil", label="Upload an Image")
_result_outputs = [
    gr.Audio(label="Generated Audio"),
    gr.Textbox(label="Generated Caption"),
    gr.Textbox(label="Extracted Text (OCR)"),
    gr.Textbox(label="Generated Context"),
]

iface = gr.Interface(
    fn=process_image,
    inputs=_image_input,
    outputs=_result_outputs,
    title="SeeSay Contextualizer with GPT-2 (124M)",
    description=(
        "Upload an image to generate a caption, extract text, "
        "create audio from context, and determine the context using GPT-2."
    ),
)

# Start the web app
iface.launch()