g0th committed on
Commit
df690e6
·
verified ·
1 Parent(s): 3b3c05a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -44
app.py CHANGED
@@ -1,33 +1,26 @@
1
  import gradio as gr
2
  import os
3
  import json
4
- import torch
5
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
6
  from ppt_parser import transfer_to_structure
7
- from functools import lru_cache
 
 
8
 
9
- # ✅ Get Hugging Face token from Space Secrets
10
  hf_token = os.getenv("HF_TOKEN")
11
 
12
- # ✅ Load summarization model (BART)
13
- summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
14
-
15
- # ✅ Load Mistral model (memoized to avoid reloading)
16
- @lru_cache(maxsize=1)
17
- def load_mistral():
18
- tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1", token=hf_token)
19
- model = AutoModelForCausalLM.from_pretrained(
20
- "mistralai/Mistral-7B-Instruct-v0.1",
21
- torch_dtype=torch.float16,
22
- device_map="auto",
23
- token=hf_token
24
- )
25
- return pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=512)
26
-
27
- mistral_pipe = load_mistral()
28
 
29
- # ✅ Global variable to hold extracted content
30
  extracted_text = ""
 
31
 
32
  def extract_text_from_pptx_json(parsed_json: dict) -> str:
33
  text = ""
@@ -45,47 +38,58 @@ def extract_text_from_pptx_json(parsed_json: dict) -> str:
45
  text += para.get("text", "") + "\n"
46
  return text.strip()
47
 
 
48
  def handle_pptx_upload(pptx_file):
49
- global extracted_text
50
  tmp_path = pptx_file.name
51
- parsed_json_str, _ = transfer_to_structure(tmp_path, "images")
52
  parsed_json = json.loads(parsed_json_str)
53
  extracted_text = extract_text_from_pptx_json(parsed_json)
 
54
  return extracted_text or "No readable text found in slides."
55
 
56
- def summarize_text():
57
- global extracted_text
58
- if not extracted_text:
59
- return "Please upload and extract text from a PPTX file first."
60
- summary = summarizer(extracted_text, max_length=200, min_length=50, do_sample=False)[0]['summary_text']
61
- return summary
62
-
63
- def clarify_concept(question):
64
- global extracted_text
65
- if not extracted_text:
66
- return "Please upload and extract text from a PPTX file first."
67
- prompt = f"[INST] Use the following context to answer the question:\n\n{extracted_text}\n\nQuestion: {question} [/INST]"
68
- response = mistral_pipe(prompt)[0]["generated_text"]
69
- return response.replace(prompt, "").strip()
 
 
 
 
 
 
 
 
 
 
 
70
 
71
  # ✅ Gradio UI
72
  with gr.Blocks() as demo:
73
- gr.Markdown("## 🧠 AI-Powered Study Assistant for PowerPoint Lectures (Mistral 7B)")
74
 
75
  pptx_input = gr.File(label="πŸ“‚ Upload PPTX File", file_types=[".pptx"])
76
- extract_btn = gr.Button("πŸ“œ Extract & Summarize")
77
 
78
  extracted_output = gr.Textbox(label="πŸ“„ Extracted Text", lines=10, interactive=False)
79
- summary_output = gr.Textbox(label="πŸ“ Summary", interactive=False)
80
 
81
  extract_btn.click(handle_pptx_upload, inputs=[pptx_input], outputs=[extracted_output])
82
- extract_btn.click(summarize_text, outputs=[summary_output])
83
 
84
  question = gr.Textbox(label="❓ Ask a Question")
85
- ask_btn = gr.Button("πŸ’¬ Ask Mistral")
86
- ai_answer = gr.Textbox(label="πŸ€– Mistral Answer", lines=4)
87
 
88
- ask_btn.click(clarify_concept, inputs=[question], outputs=[ai_answer])
89
 
90
  if __name__ == "__main__":
91
  demo.launch()
 
1
  import gradio as gr
2
  import os
3
  import json
 
 
4
  from ppt_parser import transfer_to_structure
5
+ from PIL import Image
6
+ import torch
7
+ from transformers import AutoProcessor, AutoModelForImageTextToText
8
 
# Hugging Face token for gated model access; None when HF_TOKEN is unset.
hf_token = os.getenv("HF_TOKEN")

# Single source of truth for the checkpoint, so the processor and the model
# can never be loaded from two different repositories by accident.
MODEL_ID = "meta-llama/Llama-4-Scout-17B-16E-Instruct"

# Load the Llama 4 Scout processor and model once, at import time.
processor = AutoProcessor.from_pretrained(MODEL_ID, token=hf_token)
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    device_map="auto",
    token=hf_token,
)

# Module-level state written by the upload handler and read when answering.
extracted_text = ""
slide_images = []
24
 
25
  def extract_text_from_pptx_json(parsed_json: dict) -> str:
26
  text = ""
 
38
  text += para.get("text", "") + "\n"
39
  return text.strip()
40
 
def handle_pptx_upload(pptx_file):
    """Parse an uploaded .pptx and cache its text and slide images.

    Args:
        pptx_file: The upload from the file input: an object whose ``name``
            attribute holds the file path, a plain path string, or ``None``
            when nothing has been uploaded.

    Returns:
        The text extracted from the slides, a notice when the deck has no
        readable text, or a prompt to upload a file when ``pptx_file`` is
        ``None``.
    """
    global extracted_text, slide_images

    # The button can be clicked before a file is chosen; without this guard
    # the path lookup below raises AttributeError on None.
    if pptx_file is None:
        return "Please upload a PPTX file first."

    # Accept both upload objects (path in ``.name``) and bare path strings.
    tmp_path = getattr(pptx_file, "name", pptx_file)
    parsed_json_str, image_paths = transfer_to_structure(tmp_path, "images")
    parsed_json = json.loads(parsed_json_str)

    extracted_text = extract_text_from_pptx_json(parsed_json)
    slide_images = image_paths
    return extracted_text or "No readable text found in slides."
50
 
def ask_llama(question):
    """Answer a question about the uploaded deck with Llama 4 Scout.

    Builds a single user turn from the first slide image (when one exists)
    plus the extracted slide text and the question, runs generation, and
    returns only the newly generated answer.

    Args:
        question: The user's question as plain text.

    Returns:
        The model's answer, or a prompt to upload a file when no deck has
        been processed yet.
    """
    if not extracted_text and not slide_images:
        return "Please upload a PPTX file first."

    content = []

    # Only the first slide image is sent, to keep the multimodal input small.
    if slide_images:
        with Image.open(slide_images[0]) as slide:
            # copy() loads the pixels, so the image stays usable after the
            # underlying file handle is closed by the context manager.
            image = slide.copy()
        content.append({"type": "image", "image": image})

    # Slide text is supplied as context ahead of the question itself.
    context = f"{extracted_text}\n\nQuestion: {question}"
    content.append({"type": "text", "text": context})
    messages = [{"role": "user", "content": content}]

    # Chat messages must go through the chat template: it adds the role
    # markers and image placeholder and tokenizes text and image together.
    # Passing the message dict as raw ``text`` skips all of that.
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=512)

    # generate() returns prompt + completion; drop the prompt tokens so the
    # slide text and question are not echoed back as part of the answer.
    prompt_length = inputs["input_ids"].shape[-1]
    answer_ids = generated_ids[:, prompt_length:]
    return processor.batch_decode(answer_ids, skip_special_tokens=True)[0].strip()
76
 
# ✅ Gradio UI
# Layout: upload + extract controls first, then the question/answer controls.
with gr.Blocks() as demo:
    gr.Markdown("## 🧠 Llama 4 Scout: PPTX-Based Multimodal Study Assistant")

    pptx_input = gr.File(label="📂 Upload PPTX File", file_types=[".pptx"])
    extract_btn = gr.Button("📜 Extract Text + Slides")

    extracted_output = gr.Textbox(label="📄 Extracted Text", lines=10, interactive=False)

    # Extraction shows the slide text in the read-only textbox above.
    extract_btn.click(handle_pptx_upload, inputs=[pptx_input], outputs=[extracted_output])

    question = gr.Textbox(label="❓ Ask a Question")
    ask_btn = gr.Button("💬 Ask Llama 4 Scout")
    ai_answer = gr.Textbox(label="🤖 Llama Answer", lines=4)

    ask_btn.click(ask_llama, inputs=[question], outputs=[ai_answer])

# Launch the app only when this file is run as a script, not on import.
if __name__ == "__main__":
    demo.launch()