# Studymaker2 / app.py
import os
import json
from PIL import Image
import torch
import gradio as gr
from ppt_parser import transfer_to_structure
from transformers import AutoProcessor, Llama4ForConditionalGeneration
# ✅ Hugging Face token
hf_token = os.getenv("HF_TOKEN")
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
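# NOTE: Llama 4 repos on the Hub are gated, so HF_TOKEN must belong to an account
# that has been granted access to this checkpoint.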
# ✅ Load model & processor
processor = AutoProcessor.from_pretrained(model_id, token=hf_token)
model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    token=hf_token,
    attn_implementation="flex_attention",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
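# device_map="auto" lets Accelerate shard the checkpoint across the available GPUs
# (with CPU offload if needed), and bfloat16 roughly halves memory versus float32;
# the Scout checkpoint is still large, so multi-GPU hardware is assumed here.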
# ✅ Global storage
extracted_text = ""
image_paths = []
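# NOTE: module-level globals are shared by every visitor of the Gradio app, so
# concurrent users would overwrite each other's slides. That is fine for a
# single-user demo; gr.State would be the per-session alternative.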
def extract_text_from_pptx_json(parsed_json: dict) -> str:
    text = ""
    for slide in parsed_json.values():
        for shape in slide.values():
            if shape.get("type") == "group":
                for group_shape in shape.get("group_content", {}).values():
                    if group_shape.get("type") == "text":
                        for para_key, para in group_shape.items():
                            if para_key.startswith("paragraph_"):
                                text += para.get("text", "") + "\n"
            elif shape.get("type") == "text":
                for para_key, para in shape.items():
                    if para_key.startswith("paragraph_"):
                        text += para.get("text", "") + "\n"
    return text.strip()
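# The traversal above assumes ppt_parser emits JSON shaped roughly like
#   {<slide>: {<shape>: {"type": "text", "paragraph_0": {"text": "..."}, ...},
#              <shape>: {"type": "group", "group_content": {...}}, ...}, ...}
# i.e. slides -> shapes -> (optional groups) -> "paragraph_*" entries holding text.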
# ✅ Handle uploaded PPTX
def handle_pptx_upload(pptx_file):
    global extracted_text, image_paths
    if pptx_file is None:
        return "Please upload a PPTX file first."
    tmp_path = pptx_file.name
    parsed_json_str, image_paths = transfer_to_structure(tmp_path, "images")
    parsed_json = json.loads(parsed_json_str)
    extracted_text = extract_text_from_pptx_json(parsed_json)
    return extracted_text or "No readable text found in slides."
# ✅ Multimodal Q&A using Scout
def ask_llama(question):
    global extracted_text, image_paths
    if not extracted_text and not image_paths:
        return "Please upload and extract a PPTX first."

    # 🧠 Build multimodal chat messages
    messages = [
        {
            "role": "user",
            "content": [],
        }
    ]
    # Add up to 2 images to prevent OOM
    for path in image_paths[:2]:
        messages[0]["content"].append({"type": "image", "image": Image.open(path)})

    messages[0]["content"].append({
        "type": "text",
        "text": f"{extracted_text}\n\nQuestion: {question}"
    })
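    # The prompt ends up as [image, image, <all slide text>, "Question: ..."]; if
    # decks are very long, truncating extracted_text here would be a reasonable
    # guard against exceeding the context window.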
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt"
    ).to(model.device)
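    # With tokenize=True and return_dict=True, apply_chat_template returns a dict of
    # tensors (input_ids, attention_mask and the processed image inputs) that can be
    # passed straight to generate().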
    outputs = model.generate(**inputs, max_new_tokens=256)
    # Decode only the newly generated tokens, skipping the echoed prompt
    response = processor.batch_decode(
        outputs[:, inputs["input_ids"].shape[-1]:], skip_special_tokens=True
    )[0]
    return response.strip()
# ✅ Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## 🧠 Multimodal Llama 4 Scout Study Assistant")

    pptx_input = gr.File(label="📂 Upload PPTX File", file_types=[".pptx"])
    extract_btn = gr.Button("📜 Extract Text + Images")
    extracted_output = gr.Textbox(label="📄 Slide Text", lines=10, interactive=False)
    extract_btn.click(handle_pptx_upload, inputs=[pptx_input], outputs=[extracted_output])

    question = gr.Textbox(label="❓ Ask a Question")
    ask_btn = gr.Button("💬 Ask Llama 4 Scout")
    ai_answer = gr.Textbox(label="🤖 Answer", lines=6)
    ask_btn.click(ask_llama, inputs=[question], outputs=[ai_answer])
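# Calling demo.queue() before launch() is worth considering on shared hardware,
# since each answer ties up the GPU for several seconds.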
if __name__ == "__main__":
    demo.launch()