preston-cell committed
Commit f9a1f04 · verified · 1 Parent(s): d1f1b90

Update app.py

Files changed (1)
  1. app.py +12 -14
app.py CHANGED
@@ -1,9 +1,12 @@
 import gradio as gr
-from transformers import pipeline, AutoProcessor, AutoModelForCausalLM, AutoTokenizer
+from transformers import pipeline, AutoProcessor, AutoModelForCausalLM, AutoTokenizer, set_seed
 from datasets import load_dataset
 import torch
 import numpy as np
 
+# Set seed for reproducibility
+set_seed(42)
+
 # Load BLIP model for image captioning
 caption_model = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
 
@@ -16,10 +19,8 @@ ocr_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
 ocr_model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-large", torch_dtype=ocr_dtype, trust_remote_code=True).to(ocr_device)
 ocr_processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)
 
-# Load QwQ model for text generation
-qwq_model_name = "Qwen/QwQ-32B"
-qwq_model = AutoModelForCausalLM.from_pretrained(qwq_model_name, torch_dtype="auto", device_map="auto")
-qwq_tokenizer = AutoTokenizer.from_pretrained(qwq_model_name)
+# Load GPT-2 XL model for text generation
+gpt2_generator = pipeline('text-generation', model='gpt2-xl')
 
 # Load speaker embedding
 embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
@@ -42,12 +43,10 @@ def process_image(image):
     )
     extracted_text = ocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
-    # Generate context using QwQ
-    messages = [{"role": "user", "content": f"Determine the context of this image based on the caption and extracted text. Caption: {caption}. Extracted text: {extracted_text}."}]
-    qwq_prompt = qwq_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-    inputs_qwq = qwq_tokenizer(qwq_prompt, return_tensors="pt").to(qwq_model.device)
-    qwq_output_ids = qwq_model.generate(**inputs_qwq, max_new_tokens=100)
-    context = qwq_tokenizer.batch_decode(qwq_output_ids[:, inputs_qwq.input_ids.shape[-1]:], skip_special_tokens=True)[0]
+    # Generate context using GPT-2 XL
+    prompt = f"Determine the context of this image based on the caption and extracted text. Caption: {caption}. Extracted text: {extracted_text}. Context:"
+    context_output = gpt2_generator(prompt, max_length=150, num_return_sequences=1)
+    context = context_output[0]['generated_text']
 
     # Convert context to speech
     speech = synthesiser(
@@ -76,9 +75,8 @@ iface = gr.Interface(
         gr.Textbox(label="Extracted Text (OCR)"),
         gr.Textbox(label="Generated Context")
     ],
-    title="SeeSay Contextualizer with QwQ-32B",
-    description="Upload an image to generate a caption, extract text, create audio from context, and determine the context using QwQ-32B."
+    title="SeeSay Contextualizer with GPT-2 XL",
+    description="Upload an image to generate a caption, extract text, create audio from context, and determine the context using GPT-2 XL."
 )
 
 iface.launch()
-
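
One behavioral note on the new GPT-2 step: by default the transformers text-generation pipeline returns the prompt together with the continuation in generated_text, so the context that gets displayed and spoken will start with the instruction sentence itself. Below is a minimal sketch of the same call with the prompt stripped; the stand-in caption and extracted_text values are illustrative, and max_new_tokens / return_full_text are standard pipeline arguments, not part of the committed code.

# Sketch only, not the committed code: same pipeline, keeping just the continuation.
# return_full_text=False drops the echoed prompt; max_new_tokens bounds the newly
# generated tokens directly (max_length counts the prompt tokens as well).
from transformers import pipeline, set_seed

set_seed(42)
gpt2_generator = pipeline("text-generation", model="gpt2-xl")

caption = "a city street at night"   # stand-in for the BLIP caption
extracted_text = "OPEN 24 HOURS"     # stand-in for the Florence-2 OCR output

prompt = (
    "Determine the context of this image based on the caption and extracted text. "
    f"Caption: {caption}. Extracted text: {extracted_text}. Context:"
)
context_output = gpt2_generator(
    prompt,
    max_new_tokens=100,
    num_return_sequences=1,
    return_full_text=False,
)
context = context_output[0]["generated_text"].strip()
print(context)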
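
For reference, the OCR step surrounding the change (only its closing lines appear as diff context above) follows the usual Florence-2 pattern: the processor builds inputs from the "<OCR>" task prompt and the image, generate produces token ids, and batch_decode recovers the text. The sketch below is based on the Florence-2 model card, not on the hidden lines of this app; the generation arguments are illustrative, and it assumes the ocr_model, ocr_processor, ocr_device, and ocr_dtype objects loaded earlier in app.py.

# Sketch of a typical Florence-2 OCR call; the app's exact arguments are not shown
# in this diff, so max_new_tokens and num_beams here are illustrative values.
from PIL import Image

image = Image.open("example.jpg").convert("RGB")  # hypothetical input image
inputs = ocr_processor(text="<OCR>", images=image, return_tensors="pt").to(ocr_device, ocr_dtype)
generated_ids = ocr_model.generate(
    input_ids=inputs["input_ids"],
    pixel_values=inputs["pixel_values"],
    max_new_tokens=1024,
    num_beams=3,
)
extracted_text = ocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]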