g0th committed (verified)
Commit d405ed8 · Parent(s): 1dbc013

Update app.py

Files changed (1): app.py (+61, −48)
app.py CHANGED
@@ -1,93 +1,106 @@
- import gradio as gr
  import os
  import json
- from ppt_parser import transfer_to_structure
  from PIL import Image
  import torch
- from transformers import AutoProcessor, AutoModelForImageTextToText

- # ✅ Hugging Face Token for gated model access
  hf_token = os.getenv("HF_TOKEN")
-
- # ✅ Load Llama-4-Scout model and processor
- processor = AutoProcessor.from_pretrained("meta-llama/Llama-4-Scout-17B-16E-Instruct", token=hf_token)
- model = AutoModelForImageTextToText.from_pretrained(
-     "meta-llama/Llama-4-Scout-17B-16E-Instruct",
-     torch_dtype=torch.float16,
      device_map="auto",
-     token=hf_token
  )

- # ✅ Extracted data storage
  extracted_text = ""
- slide_images = []

  def extract_text_from_pptx_json(parsed_json: dict) -> str:
      text = ""
      for slide in parsed_json.values():
          for shape in slide.values():
-             if shape.get('type') == 'group':
-                 for group_shape in shape.get('group_content', {}).values():
-                     if group_shape.get('type') == 'text':
                          for para_key, para in group_shape.items():
                              if para_key.startswith("paragraph_"):
                                  text += para.get("text", "") + "\n"
-             elif shape.get('type') == 'text':
                  for para_key, para in shape.items():
                      if para_key.startswith("paragraph_"):
                          text += para.get("text", "") + "\n"
      return text.strip()

- # ✅ Handle uploaded .pptx
  def handle_pptx_upload(pptx_file):
-     global extracted_text, slide_images
      tmp_path = pptx_file.name
      parsed_json_str, image_paths = transfer_to_structure(tmp_path, "images")
      parsed_json = json.loads(parsed_json_str)
      extracted_text = extract_text_from_pptx_json(parsed_json)
-     slide_images = image_paths
      return extracted_text or "No readable text found in slides."

- # ✅ Ask a question using Llama 4 Scout
  def ask_llama(question):
-     global extracted_text, slide_images
-     if not extracted_text and not slide_images:
-         return "Please upload a PPTX file first."
-
-     inputs = {
-         "role": "user",
-         "content": []
-     }
-
-     # Add first image only (multimodal models may limit batch input size)
-     if slide_images:
-         image = Image.open(slide_images[0])
-         inputs["content"].append({"type": "image", "image": image})
-
-     # Add contextual text + question
-     context = f"{extracted_text}\n\nQuestion: {question}"
-     inputs["content"].append({"type": "text", "text": context})
-
-     outputs = processor(text=[inputs], return_tensors="pt").to(model.device)
-     with torch.no_grad():
-         generated_ids = model.generate(**outputs, max_new_tokens=512)
-     result = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
-     return result

  # ✅ Gradio UI
  with gr.Blocks() as demo:
-     gr.Markdown("## 🧠 Llama 4 Scout: PPTX-Based Multimodal Study Assistant")

      pptx_input = gr.File(label="📂 Upload PPTX File", file_types=[".pptx"])
-     extract_btn = gr.Button("📜 Extract Text + Slides")

-     extracted_output = gr.Textbox(label="📄 Extracted Text", lines=10, interactive=False)

      extract_btn.click(handle_pptx_upload, inputs=[pptx_input], outputs=[extracted_output])

      question = gr.Textbox(label="❓ Ask a Question")
      ask_btn = gr.Button("💬 Ask Llama 4 Scout")
-     ai_answer = gr.Textbox(label="🤖 Llama Answer", lines=4)

      ask_btn.click(ask_llama, inputs=[question], outputs=[ai_answer])

  import os
  import json
+ import requests
  from PIL import Image
  import torch
+ import gradio as gr
+ from ppt_parser import transfer_to_structure
+ from transformers import AutoProcessor, Llama4ForConditionalGeneration

+ # ✅ Hugging Face token
  hf_token = os.getenv("HF_TOKEN")
+ model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
+
+ # ✅ Load model & processor
+ processor = AutoProcessor.from_pretrained(model_id, token=hf_token)
+ model = Llama4ForConditionalGeneration.from_pretrained(
+     model_id,
+     token=hf_token,
+     attn_implementation="flex_attention",
      device_map="auto",
+     torch_dtype=torch.bfloat16,
  )

+ # ✅ Global storage
  extracted_text = ""
+ image_paths = []

  def extract_text_from_pptx_json(parsed_json: dict) -> str:
      text = ""
      for slide in parsed_json.values():
          for shape in slide.values():
+             if shape.get("type") == "group":
+                 for group_shape in shape.get("group_content", {}).values():
+                     if group_shape.get("type") == "text":
                          for para_key, para in group_shape.items():
                              if para_key.startswith("paragraph_"):
                                  text += para.get("text", "") + "\n"
+             elif shape.get("type") == "text":
                  for para_key, para in shape.items():
                      if para_key.startswith("paragraph_"):
                          text += para.get("text", "") + "\n"
      return text.strip()

+ # ✅ Handle uploaded PPTX
  def handle_pptx_upload(pptx_file):
+     global extracted_text, image_paths
      tmp_path = pptx_file.name
      parsed_json_str, image_paths = transfer_to_structure(tmp_path, "images")
      parsed_json = json.loads(parsed_json_str)
      extracted_text = extract_text_from_pptx_json(parsed_json)
      return extracted_text or "No readable text found in slides."

+ # ✅ Multimodal Q&A using Scout
  def ask_llama(question):
+     global extracted_text, image_paths
+
+     if not extracted_text and not image_paths:
+         return "Please upload and extract a PPTX first."
+
+     # 🧠 Build multimodal chat messages
+     messages = [
+         {
+             "role": "user",
+             "content": [],
+         }
+     ]
+
+     # Add up to 2 images to prevent OOM
+     for path in image_paths[:2]:
+         messages[0]["content"].append({"type": "image", "image": Image.open(path)})
+
+     messages[0]["content"].append({
+         "type": "text",
+         "text": f"{extracted_text}\n\nQuestion: {question}"
+     })
+
+     inputs = processor.apply_chat_template(
+         messages,
+         add_generation_prompt=True,
+         tokenize=True,
+         return_dict=True,
+         return_tensors="pt"
+     ).to(model.device)
+
+     outputs = model.generate(**inputs, max_new_tokens=256)
+
+     response = processor.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:], skip_special_tokens=True)[0]
+     return response.strip()

  # ✅ Gradio UI
  with gr.Blocks() as demo:
+     gr.Markdown("## 🧠 Multimodal Llama 4 Scout Study Assistant")

      pptx_input = gr.File(label="📂 Upload PPTX File", file_types=[".pptx"])
+     extract_btn = gr.Button("📜 Extract Text + Images")

+     extracted_output = gr.Textbox(label="📄 Slide Text", lines=10, interactive=False)

      extract_btn.click(handle_pptx_upload, inputs=[pptx_input], outputs=[extracted_output])

      question = gr.Textbox(label="❓ Ask a Question")
      ask_btn = gr.Button("💬 Ask Llama 4 Scout")
+     ai_answer = gr.Textbox(label="🤖 Answer", lines=6)

      ask_btn.click(ask_llama, inputs=[question], outputs=[ai_answer])
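
For readers tracing the unchanged extract_text_from_pptx_json helper: the traversal implies that ppt_parser.transfer_to_structure returns JSON keyed by slide, then by shape, with text shapes carrying paragraph_* entries and group shapes nesting further shapes under group_content. The sketch below is inferred from the traversal code alone; the actual key names are whatever ppt_parser emits.

# Hypothetical parsed_json layout, reverse-engineered from the traversal above;
# not taken from ppt_parser documentation.
parsed_json = {
    "slide_1": {
        "shape_1": {
            "type": "text",
            "paragraph_0": {"text": "Slide title"},
            "paragraph_1": {"text": "First bullet"},
        },
        "shape_2": {
            "type": "group",
            "group_content": {
                "shape_1": {
                    "type": "text",
                    "paragraph_0": {"text": "Text inside a grouped shape"},
                },
            },
        },
    },
}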
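
The decode step in the new ask_llama uses a common Transformers idiom: generate() returns the prompt tokens followed by the newly generated ones, so slicing at inputs["input_ids"].shape[-1] keeps only the answer. A minimal stand-alone illustration with dummy tensors:

import torch

# generate() output = prompt tokens + new tokens, concatenated along dim 1.
input_ids = torch.tensor([[101, 102, 103]])         # stand-in 3-token prompt
outputs = torch.tensor([[101, 102, 103, 7, 8, 9]])  # stand-in generate() result
answer_ids = outputs[:, input_ids.shape[-1]:]       # tensor([[7, 8, 9]])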
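
The hunk ends at new-file line 105; line 106 is not shown. Since a gr.Blocks app serves nothing until launched, the file presumably closes with a launch call along these lines (assumed, not visible in the diff):

demo.launch()  # assumed final line; not shown in the hunk above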