preston-cell committed
Commit 66d96fc · verified · 1 Parent(s): eda8103

Update app.py

Files changed (1)
  1. app.py +46 -43
app.py CHANGED
@@ -1,79 +1,81 @@
 import gradio as gr
+import torch
+import numpy as np
+import requests
+import io
+from PIL import Image
 from transformers import (
     pipeline,
     AutoProcessor,
     AutoModelForCausalLM,
     AutoTokenizer,
-    GenerationConfig
+    GenerationConfig,
 )
-import torch
-import numpy as np
-from PIL import Image
-import requests
-import io
+from datasets import load_dataset

-# Device setup
-device = "cuda:0" if torch.cuda.is_available() else "cpu"
-torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+# Set device and dtype
+device = "cuda" if torch.cuda.is_available() else "cpu"
+dtype = torch.float16 if torch.cuda.is_available() else torch.float32

-# Load image captioning model (BLIP)
+# Load BLIP for image captioning
 caption_model = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")

-# Load TTS model (SpeechT5)
-synthesiser = pipeline("text-to-speech", model="microsoft/speecht5_tts")
-
 # Load Florence-2-base for OCR
-ocr_model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-base", torch_dtype=torch_dtype, trust_remote_code=True).to(device)
+ocr_model = AutoModelForCausalLM.from_pretrained(
+    "microsoft/Florence-2-base", trust_remote_code=True, torch_dtype=dtype
+).to(device)
 ocr_processor = AutoProcessor.from_pretrained("microsoft/Florence-2-base", trust_remote_code=True)

-# Load Doge-320M-Instruct for contextual reasoning
+# Load SmallDoge for context generation
 doge_tokenizer = AutoTokenizer.from_pretrained("SmallDoge/Doge-320M-Instruct")
 doge_model = AutoModelForCausalLM.from_pretrained("SmallDoge/Doge-320M-Instruct", trust_remote_code=True).to(device)

-generation_config = GenerationConfig(
+doge_config = GenerationConfig(
     max_new_tokens=100,
-    use_cache=True,
     do_sample=True,
     temperature=0.8,
     top_p=0.9,
-    repetition_penalty=1.0
+    repetition_penalty=1.0,
 )

+# Load SpeechT5 for TTS
+synthesiser = pipeline("text-to-speech", model="microsoft/speecht5_tts")
+
+# Load speaker embedding from .npy using BytesIO
+SPEAKER_EMBEDDING_URL = "https://huggingface.co/datasets/Matthijs/cmu-arctic-xvectors/resolve/main/spkemb/fn0012.npy"
+response = requests.get(SPEAKER_EMBEDDING_URL)
+buffer = io.BytesIO(response.content)
+speaker_embedding = torch.tensor(np.load(buffer)).unsqueeze(0)  # Shape: [1, 600]
+
+
 def process_image(image):
     try:
         # Captioning
-        caption = caption_model(image)[0]['generated_text']
+        caption = caption_model(image)[0]["generated_text"]

         # OCR
-        inputs = ocr_processor(text="<OCR>", images=image, return_tensors="pt").to(device, torch_dtype)
+        inputs = ocr_processor(text="<OCR>", images=image, return_tensors="pt").to(device, dtype)
         generated_ids = ocr_model.generate(
             input_ids=inputs["input_ids"],
             pixel_values=inputs["pixel_values"],
             max_new_tokens=1024,
             num_beams=3,
-            do_sample=False
+            do_sample=False,
         )
         extracted_text = ocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

-        # Context generation
-        prompt = f"Determine the context of this image based on the caption and extracted text. Caption: {caption}. Extracted text: {extracted_text}. Context:"
+        # Context generation with Doge
+        prompt = f"Determine the context of this image based on the caption and extracted text.\nCaption: {caption}\nExtracted text: {extracted_text}\nContext:"
         conversation = [{"role": "user", "content": prompt}]
-        doge_inputs = doge_tokenizer.apply_chat_template(
-            conversation=conversation,
-            tokenize=True,
-            return_tensors="pt"
-        ).to(device)
-
-        output_ids = doge_model.generate(
-            doge_inputs,
-            generation_config=generation_config
-        )
-
+        inputs = doge_tokenizer.apply_chat_template(conversation, tokenize=True, return_tensors="pt").to(device)
+        output_ids = doge_model.generate(inputs, generation_config=doge_config)
         context = doge_tokenizer.decode(output_ids[0], skip_special_tokens=True)

-        # Text-to-speech (no speaker embedding required)
-        speech = synthesiser(context)
-
+        # TTS
+        speech = synthesiser(
+            context,
+            forward_params={"speaker_embeddings": speaker_embedding}
+        )
         audio = np.array(speech["audio"])
         rate = speech["sampling_rate"]

@@ -82,18 +84,19 @@ def process_image(image):
     except Exception as e:
         return None, f"Error: {str(e)}", "", ""

-# Gradio UI
+
+# Gradio interface
 iface = gr.Interface(
     fn=process_image,
-    inputs=gr.Image(type='pil', label="Upload an Image"),
+    inputs=gr.Image(type="pil", label="Upload an Image"),
     outputs=[
         gr.Audio(label="Generated Audio"),
         gr.Textbox(label="Generated Caption"),
         gr.Textbox(label="Extracted Text (OCR)"),
-        gr.Textbox(label="Generated Context")
+        gr.Textbox(label="Generated Context"),
     ],
-    title="SeeSay",
-    description="Upload an image to generate a caption, extract text, convert context to speech, and understand the image using Doge-320M."
+    title="SeeSay Contextualizer",
+    description="Upload an image to generate a caption, extract text (OCR), generate context using Doge, and convert to audio with SpeechT5.",
 )

-iface.launch(share=True)
+iface.launch()
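
The updated file imports load_dataset but never calls it; the speaker x-vector is fetched over HTTP instead. For reference, a minimal sketch of the datasets-based loading from the SpeechT5 examples (not part of this commit; index 7306 is the voice used in those examples) would look like this:

# Sketch only, not in the commit: load a speaker x-vector with the datasets library.
import torch
from datasets import load_dataset

embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
speaker_embedding = torch.tensor(embeddings_dataset[7306]["xvector"]).unsqueeze(0)

Either way, the resulting tensor is what the diff passes to the TTS pipeline via forward_params={"speaker_embeddings": speaker_embedding}.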