import gradio as gr
import os
import json
from ppt_parser import transfer_to_structure
from PIL import Image
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText

# ✅ Hugging Face Token for gated model access
hf_token = os.getenv("HF_TOKEN")

# ✅ Load Llama-4-Scout model and processor
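# Note: this is a large mixture-of-experts checkpoint; device_map="auto"
# lets accelerate shard it across the available GPUs, and loading it
# assumes substantial GPU memory.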
processor = AutoProcessor.from_pretrained("meta-llama/Llama-4-Scout-17B-16E-Instruct", token=hf_token)
model = AutoModelForImageTextToText.from_pretrained(
    "meta-llama/Llama-4-Scout-17B-16E-Instruct",
    torch_dtype=torch.float16,
    device_map="auto",
    token=hf_token
)

# ✅ Extracted data storage
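# Note: module-level globals are shared across the Gradio callbacks below,
# and therefore across all concurrent users of the app.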
extracted_text = ""
slide_images = []

def extract_text_from_pptx_json(parsed_json: dict) -> str:
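    """Walk the parsed slide JSON (slide -> shape -> paragraph_*) and
    collect the text of every paragraph, including text shapes nested
    inside groups."""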
    text = ""
    for slide in parsed_json.values():
        for shape in slide.values():
            if shape.get('type') == 'group':
                for group_shape in shape.get('group_content', {}).values():
                    if group_shape.get('type') == 'text':
                        for para_key, para in group_shape.items():
                            if para_key.startswith("paragraph_"):
                                text += para.get("text", "") + "\n"
            elif shape.get('type') == 'text':
                for para_key, para in shape.items():
                    if para_key.startswith("paragraph_"):
                        text += para.get("text", "") + "\n"
    return text.strip()

# ✅ Handle uploaded .pptx
def handle_pptx_upload(pptx_file):
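    """Parse the uploaded .pptx with ppt_parser and cache its text and
    slide images for later questions."""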
    global extracted_text, slide_images
    if pptx_file is None:
        return "Please upload a PPTX file first."
    tmp_path = pptx_file.name
    parsed_json_str, image_paths = transfer_to_structure(tmp_path, "images")
    parsed_json = json.loads(parsed_json_str)
    extracted_text = extract_text_from_pptx_json(parsed_json)
    slide_images = image_paths
    return extracted_text or "No readable text found in slides."

# ✅ Ask a question using Llama 4 Scout
def ask_llama(question):
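    """Answer a question about the uploaded deck, using the extracted
    slide text plus the first slide image as multimodal context."""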
    global extracted_text, slide_images
    if not extracted_text and not slide_images:
        return "Please upload a PPTX file first."

    # Build a single chat message whose content mixes image and text parts
    messages = [{
        "role": "user",
        "content": []
    }]

    # Add first image only (multimodal models may limit batch input size)
    if slide_images:
        image = Image.open(slide_images[0])
        messages[0]["content"].append({"type": "image", "image": image})

    # Add contextual text + question
    context = f"{extracted_text}\n\nQuestion: {question}"
    messages[0]["content"].append({"type": "text", "text": context})

    # The chat template tokenizes the text and preprocesses the image together
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)
    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=512)
    # Decode only the newly generated tokens, not the echoed prompt
    result = processor.batch_decode(
        generated_ids[:, inputs["input_ids"].shape[-1]:],
        skip_special_tokens=True,
    )[0]
    return result.strip()

# ✅ Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## 🧠 Llama 4 Scout: PPTX-Based Multimodal Study Assistant")

    pptx_input = gr.File(label="📂 Upload PPTX File", file_types=[".pptx"])
    extract_btn = gr.Button("📜 Extract Text + Slides")

    extracted_output = gr.Textbox(label="📄 Extracted Text", lines=10, interactive=False)

    extract_btn.click(handle_pptx_upload, inputs=[pptx_input], outputs=[extracted_output])

    question = gr.Textbox(label="❓ Ask a Question")
    ask_btn = gr.Button("💬 Ask Llama 4 Scout")
    ai_answer = gr.Textbox(label="🤖 Llama Answer", lines=4)

    ask_btn.click(ask_llama, inputs=[question], outputs=[ai_answer])

if __name__ == "__main__":
    demo.launch()