preston-cell committed on
Commit f7203e8 · verified · 1 Parent(s): e9b130c

Update app.py

Files changed (1):
1. app.py +18 -18
app.py CHANGED
@@ -16,10 +16,9 @@ ocr_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
 ocr_model = AutoModelForCausalLM.from_pretrained("microsoft/Florence-2-large", torch_dtype=ocr_dtype, trust_remote_code=True).to(ocr_device)
 ocr_processor = AutoProcessor.from_pretrained("microsoft/Florence-2-large", trust_remote_code=True)
 
-# Load TxGemma model for text generation
-text_gen_device = "cuda:0" if torch.cuda.is_available() else "cpu"
-text_gen_tokenizer = AutoTokenizer.from_pretrained("google/txgemma-9b-predict")
-text_gen_model = AutoModelForCausalLM.from_pretrained("google/txgemma-9b-predict", device_map="auto")
+# Load Llama 3.2 model for text generation
+llama_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
+llama_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct", device_map="auto", torch_dtype=torch.bfloat16)
 
 # Load speaker embedding
 embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
@@ -31,12 +30,6 @@ def process_image(image):
         # Generate caption from the image
         caption = caption_model(image)[0]['generated_text']
 
-        # Convert caption to speech
-        speech = synthesiser(
-            caption,
-            forward_params={"speaker_embeddings": speaker_embedding}
-        )
-
         # Extract text (OCR) using Florence-2
         inputs = ocr_processor(text="<OCR>", images=image, return_tensors="pt").to(ocr_device, ocr_dtype)
         generated_ids = ocr_model.generate(
@@ -48,17 +41,23 @@ def process_image(image):
         )
         extracted_text = ocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
-        # Generate context from caption and extracted text using TxGemma
-        prompt = f"Instructions: Determine the context of the image based on the caption and extracted text.\nCaption: {caption}\nExtracted Text: {extracted_text}\nContext:"
-        input_ids = text_gen_tokenizer(prompt, return_tensors="pt").to(text_gen_device)
-        outputs = text_gen_model.generate(**input_ids, max_new_tokens=50)
-        context = text_gen_tokenizer.decode(outputs[0], skip_special_tokens=True)
+        # Generate context using Llama 3.2
+        llama_prompt = f"Determine the context of this image. Caption: {caption}. Extracted text: {extracted_text}. Context:"
+        inputs_llama = llama_tokenizer(llama_prompt, return_tensors="pt").to(llama_model.device)
+        llama_output_ids = llama_model.generate(**inputs_llama, max_new_tokens=100)
+        context = llama_tokenizer.decode(llama_output_ids[0], skip_special_tokens=True)
+
+        # Convert context to speech
+        speech = synthesiser(
+            context,
+            forward_params={"speaker_embeddings": speaker_embedding}
+        )
 
         # Prepare audio data
         audio = np.array(speech["audio"])
         rate = speech["sampling_rate"]
 
-        # Return audio, caption, extracted text, and generated context
+        # Return audio, caption, extracted text, and context
         return (rate, audio), caption, extracted_text, context
 
     except Exception as e:
@@ -75,9 +74,10 @@ iface = gr.Interface(
         gr.Textbox(label="Extracted Text (OCR)"),
         gr.Textbox(label="Generated Context")
     ],
-    title="SeeSay Contextualizer",
-    description="Upload an image to generate a caption, extract text, create audio, and determine the context using TxGemma."
+    title="SeeSay Contextualizer with Llama 3.2",
+    description="Upload an image to generate a caption, extract text, create audio from context, and determine the context using Llama 3.2."
 )
 
 iface.launch()
 
+
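
For anyone who wants to exercise the new context step outside the Space, here is a minimal sketch based only on the lines added in this commit. The caption and extracted_text values below are hard-coded placeholders (in app.py they come from the captioning and Florence-2 OCR stages of process_image), and downloading meta-llama/Llama-3.2-3B-Instruct may require accepting Meta's license on the Hugging Face Hub.

# Minimal standalone sketch of the context-generation step added in this commit.
# `caption` and `extracted_text` are placeholders here; in app.py they are produced
# by the earlier captioning and Florence-2 OCR stages of process_image.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

llama_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
llama_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-3B-Instruct", device_map="auto", torch_dtype=torch.bfloat16
)

caption = "a street sign at a busy intersection"   # placeholder caption
extracted_text = "STOP"                            # placeholder OCR result

llama_prompt = (
    f"Determine the context of this image. Caption: {caption}. "
    f"Extracted text: {extracted_text}. Context:"
)
inputs_llama = llama_tokenizer(llama_prompt, return_tensors="pt").to(llama_model.device)
llama_output_ids = llama_model.generate(**inputs_llama, max_new_tokens=100)
context = llama_tokenizer.decode(llama_output_ids[0], skip_special_tokens=True)
print(context)

Note that, as in the committed code, decode() is applied to the full output sequence, so the returned context string begins with the prompt text itself.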