preston-cell committed
Commit 9d274d8 · verified · 1 Parent(s): c0daae4

Update app.py

Files changed (1)
  1. app.py +45 -32
app.py CHANGED
@@ -1,38 +1,42 @@
 import gradio as gr
-from transformers import pipeline, AutoProcessor, AutoModelForCausalLM, AutoTokenizer, set_seed
-from datasets import load_dataset
+from transformers import pipeline, AutoProcessor, AutoModelForCausalLM
 import torch
 import numpy as np
+from datasets import load_dataset
+from PIL import Image
 
-# Set seed for reproducibility
-set_seed(42)
-
-# Load BLIP model for image captioning
+# 1) IMAGE CAPTIONING MODEL
 caption_model = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
 
-# Load SpeechT5 model for text-to-speech
-synthesiser = pipeline("text-to-speech", model="microsoft/speecht5_tts")
-
-# Load Florence-2 model for OCR
+# 2) OCR MODEL (Florence-2)
 ocr_device = "cuda:0" if torch.cuda.is_available() else "cpu"
 ocr_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
-ocr_model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-large", torch_dtype=ocr_dtype, trust_remote_code=True).to(ocr_device)
+ocr_model = AutoModelForCausalLM.from_pretrained(
+    "microsoft/Florence-2-large",
+    torch_dtype=ocr_dtype,
+    trust_remote_code=True
+).to(ocr_device)
 ocr_processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)
 
-# Load GPT-2 (124M) model for text generation
-gpt2_generator = pipeline('text-generation', model='gpt2')
+# 3) QUESTION-ANSWERING MODEL
+qa_model = pipeline(
+    "question-answering",
+    model="timpal0l/mdeberta-v3-base-squad2"
+)
+
+# 4) TEXT-TO-SPEECH MODEL
+tts_pipeline = pipeline("text-to-speech", model="microsoft/speecht5_tts")
 
 # Load speaker embedding
 embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
 speaker_embedding = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)
 
-
 def process_image(image):
     try:
-        # Generate caption from the image
+        # 1) Generate caption from the image
        caption = caption_model(image)[0]['generated_text']
 
-        # Extract text (OCR) using Florence-2
+        # 2) Extract text from the image using Florence-2
         inputs = ocr_processor(text="<OCR>", images=image, return_tensors="pt").to(ocr_device, ocr_dtype)
         generated_ids = ocr_model.generate(
             input_ids=inputs["input_ids"],
@@ -43,28 +47,32 @@ def process_image(image):
         )
         extracted_text = ocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
-        # Generate context using GPT-2 (124M)
-        prompt = f"Determine the context of this image based on the caption and extracted text. Caption: {caption}. Extracted text: {extracted_text}. Context:"
-        context_output = gpt2_generator(prompt, max_length=100, num_return_sequences=1)
-        context = context_output[0]['generated_text']
+        # 3) Use QA model to derive context from caption + extracted text
+        # We treat the "context" string as the knowledge base and ask a question about it
+        question = "What is the context of this image?"
+        combined_context = f"Caption: {caption}\nExtracted Text: {extracted_text}"
+        qa_result = qa_model(question=question, context=combined_context)
 
-        # Convert context to speech
-        speech = synthesiser(
-            context,
+        # The QA model returns an extracted "answer" from the combined context
+        # If the model can't find a direct span, it may return an empty string or a short phrase
+        final_context = qa_result["answer"]
+
+        # 4) Convert the final context to speech
+        speech_data = tts_pipeline(
+            final_context,
             forward_params={"speaker_embeddings": speaker_embedding}
         )
 
-        # Prepare audio data
-        audio = np.array(speech["audio"])
-        rate = speech["sampling_rate"]
+        # Prepare audio data for Gradio
+        audio = np.array(speech_data["audio"])
+        rate = speech_data["sampling_rate"]
 
-        # Return audio, caption, extracted text, and context
-        return (rate, audio), caption, extracted_text, context
+        # Return audio, caption, extracted text, and final context
+        return (rate, audio), caption, extracted_text, final_context
 
     except Exception as e:
         return None, f"Error: {str(e)}", "", ""
 
-
 # Gradio Interface
 iface = gr.Interface(
     fn=process_image,
@@ -73,10 +81,15 @@ iface = gr.Interface(
         gr.Audio(label="Generated Audio"),
         gr.Textbox(label="Generated Caption"),
         gr.Textbox(label="Extracted Text (OCR)"),
-        gr.Textbox(label="Generated Context")
+        gr.Textbox(label="QA-derived Context")
     ],
-    title="SeeSay Contextualizer with GPT-2 (124M)",
-    description="Upload an image to generate a caption, extract text, create audio from context, and determine the context using GPT-2."
+    title="Contextual Image QA with SpeechT5",
+    description=(
+        "1) Generate a caption via BLIP. "
+        "2) Extract text using Florence-2. "
+        "3) Use QA with mDeBERTa to find a 'context' from caption + text. "
+        "4) Convert it to audio via SpeechT5."
+    ),
 )
 
 iface.launch()
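The diff elides the Florence-2 `generate()` call after `input_ids`, so the OCR step cannot be reproduced from this page alone. Below is a minimal standalone sketch of that step, following the usage pattern on the Florence-2 model card; the `pixel_values`, `max_new_tokens`, and `num_beams` arguments and the `sample.jpg` path are assumptions, not part of the committed code.

```python
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

device = "cuda:0" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if torch.cuda.is_available() else torch.float32

ocr_model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Florence-2-large", torch_dtype=dtype, trust_remote_code=True
).to(device)
ocr_processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)

image = Image.open("sample.jpg")  # hypothetical local image path

# "<OCR>" is the task prompt Florence-2 uses for plain text extraction.
inputs = ocr_processor(text="<OCR>", images=image, return_tensors="pt").to(device, dtype)
generated_ids = ocr_model.generate(
    input_ids=inputs["input_ids"],
    pixel_values=inputs["pixel_values"],  # assumed; elided in the diff
    max_new_tokens=1024,                  # assumed
    num_beams=3,                          # assumed
)
extracted_text = ocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(extracted_text)
```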
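The core change in this commit swaps GPT-2 text generation for extractive question answering: the caption and OCR output are concatenated into a QA "context" string and mdeberta-v3-base-squad2 pulls an answer span out of it. A minimal sketch of that step in isolation, with placeholder caption/OCR strings and an assumed score-based fallback that is not in the committed code:

```python
from transformers import pipeline

# Extractive QA over the caption + OCR text, as in the updated app.py.
qa_model = pipeline("question-answering", model="timpal0l/mdeberta-v3-base-squad2")

# Placeholder values standing in for the BLIP caption and Florence-2 OCR output.
caption = "a red stop sign at a busy intersection"
extracted_text = "STOP"

combined_context = f"Caption: {caption}\nExtracted Text: {extracted_text}"
result = qa_model(question="What is the context of this image?", context=combined_context)

# The pipeline returns a dict with the extracted span and a confidence score.
print(result["answer"], result["score"])

# Assumed fallback (not in the commit): if the span score is very low, reuse the
# caption so the TTS step never receives an empty string.
final_context = result["answer"] if result["score"] > 0.1 else caption
```

Because squad2-style models are extractive, `final_context` is always a substring of `combined_context`; the spoken output stays grounded in the caption and OCR text, at the cost of never paraphrasing.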
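The SpeechT5 pipeline expects a speaker x-vector passed through `forward_params` and returns a dict holding the waveform and its sampling rate, the same `(rate, audio)` pair handed to `gr.Audio`. A standalone sketch, assuming `soundfile` is available for writing a WAV file (it is not a dependency declared by this commit):

```python
import numpy as np
import soundfile as sf  # assumed extra dependency, used only to save the output
import torch
from datasets import load_dataset
from transformers import pipeline

tts_pipeline = pipeline("text-to-speech", model="microsoft/speecht5_tts")

# Same speaker embedding as app.py: x-vector 7306 from cmu-arctic-xvectors.
embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker_embedding = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)

speech = tts_pipeline(
    "a red stop sign at a busy intersection",
    forward_params={"speaker_embeddings": speaker_embedding},
)

# The pipeline output carries the waveform and its sampling rate.
sf.write("context.wav", np.array(speech["audio"]), speech["sampling_rate"])
```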
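The new description string spells out the four-step flow (BLIP caption → Florence-2 OCR → mDeBERTa QA → SpeechT5 audio). For a quick check without launching the Gradio UI, the callback can be invoked directly on a PIL image once app.py's models are loaded in the session; `example.jpg` below is a placeholder path, not a file shipped with the Space.

```python
from PIL import Image

image = Image.open("example.jpg")  # placeholder path

audio_out, caption, extracted_text, final_context = process_image(image)
if audio_out is None:
    # process_image returns (None, error_message, "", "") on failure
    print(caption)
else:
    rate, audio = audio_out
    print("Caption:", caption)
    print("OCR text:", extracted_text)
    print("QA-derived context:", final_context)
    print(f"Audio: {len(audio)} samples at {rate} Hz")
```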