import os
import json
from PIL import Image
import torch
import gradio as gr
from ppt_parser import transfer_to_structure
from transformers import AutoProcessor, Llama4ForConditionalGeneration

# βœ… Hugging Face token
hf_token = os.getenv("HF_TOKEN")
model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"

# βœ… Load model & processor
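# device_map="auto" shards the weights across all visible GPUs, and bfloat16
# halves memory relative to float32.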
processor = AutoProcessor.from_pretrained(model_id, token=hf_token)
model = Llama4ForConditionalGeneration.from_pretrained(
    model_id,
    token=hf_token,
    attn_implementation="flex_attention",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# βœ… Global storage
extracted_text = ""
image_paths = []
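# NOTE: module-level globals are shared across concurrent Gradio sessions;
# fine for a single-user demo, but prefer gr.State for multi-user deployments.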

def extract_text_from_pptx_json(parsed_json: dict) -> str:
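    """Collect paragraph text from the slide JSON produced by ppt_parser.

    The structure nests slides -> shapes -> paragraphs; grouped shapes are
    unpacked one level so their inner text shapes are included as well.
    """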
    text = ""
    for slide in parsed_json.values():
        for shape in slide.values():
            if shape.get("type") == "group":
                for group_shape in shape.get("group_content", {}).values():
                    if group_shape.get("type") == "text":
                        for para_key, para in group_shape.items():
                            if para_key.startswith("paragraph_"):
                                text += para.get("text", "") + "\n"
            elif shape.get("type") == "text":
                for para_key, para in shape.items():
                    if para_key.startswith("paragraph_"):
                        text += para.get("text", "") + "\n"
    return text.strip()

# βœ… Handle uploaded PPTX
def handle_pptx_upload(pptx_file):
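    """Parse the uploaded PPTX into slide text and exported slide images."""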
    global extracted_text, image_paths
    if pptx_file is None:
        return "Please upload a PPTX file first."
    tmp_path = pptx_file.name
    parsed_json_str, image_paths = transfer_to_structure(tmp_path, "images")
    parsed_json = json.loads(parsed_json_str)
    extracted_text = extract_text_from_pptx_json(parsed_json)
    return extracted_text or "No readable text found in slides."

# βœ… Multimodal Q&A using Scout
def ask_llama(question):
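    """Answer a question using the extracted slide text plus slide images."""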
    global extracted_text, image_paths

    if not extracted_text and not image_paths:
        return "Please upload and extract a PPTX first."

    # 🧠 Build multimodal chat messages
    messages = [
        {
            "role": "user",
            "content": [],
        }
    ]

    # Add up to 2 images to prevent OOM
    for path in image_paths[:2]:
        messages[0]["content"].append({"type": "image", "image": Image.open(path)})

    messages[0]["content"].append({
        "type": "text",
        "text": f"{extracted_text}\n\nQuestion: {question}"
    })

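    # apply_chat_template renders the multimodal chat format and, with
    # tokenize=True / return_dict=True, also preprocesses the PIL images
    # into pixel tensors alongside the input_ids.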
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt"
    ).to(model.device)

    outputs = model.generate(**inputs, max_new_tokens=256)

    # Decode only the newly generated tokens, dropping the echoed prompt
    # and any special tokens (e.g. the end-of-turn marker).
    response = processor.batch_decode(
        outputs[:, inputs["input_ids"].shape[-1]:],
        skip_special_tokens=True,
    )[0]
    return response.strip()

# βœ… Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## 🧠 Multimodal Llama 4 Scout Study Assistant")

    pptx_input = gr.File(label="πŸ“‚ Upload PPTX File", file_types=[".pptx"])
    extract_btn = gr.Button("πŸ“œ Extract Text + Images")

    extracted_output = gr.Textbox(label="πŸ“„ Slide Text", lines=10, interactive=False)

    extract_btn.click(handle_pptx_upload, inputs=[pptx_input], outputs=[extracted_output])

    question = gr.Textbox(label="❓ Ask a Question")
    ask_btn = gr.Button("πŸ’¬ Ask Llama 4 Scout")
    ai_answer = gr.Textbox(label="πŸ€– Answer", lines=6)

    ask_btn.click(ask_llama, inputs=[question], outputs=[ai_answer])

if __name__ == "__main__":
    demo.launch()