# app.py
import base64, json, uuid, gradio as gr
from pathlib import Path

# === Your vision-LLM stack (imported from src/… as organised earlier) ===
from src.llm.chat import FunctionCallingChat     # wrapper around Llama-3.2-1B
chatbot = FunctionCallingChat()                  # load once at start-up
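# As used in inference() below, calling chatbot(user_msg) is expected to return
# a dict with at least "raw_tool_call" (the model's raw function-call text) and
# "results" (the JSON-serialisable outputs of the executed tools).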

# -------- helpers --------------------------------------------------------
def image_to_base64(image_path: str) -> str:
    """Read an image file and return its contents as a base64 string."""
    with open(image_path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")


def save_uploaded_image(pil_img) -> Path:
    """Persist uploaded PIL image to ./static/ and return the file path."""
    Path("static").mkdir(exist_ok=True)
    filename = f"upload_{uuid.uuid4().hex[:8]}.png"
    path = Path("static") / filename
    pil_img.save(path)
    return path
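
# Note: uploads accumulate in ./static under unique names; nothing here removes
# them, so a long-running deployment may want its own cleanup policy.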


# -------- inference ------------------------------------------------------
def inference(pil_img, prompt, task):
    """
    β€’ pil_img : uploaded PIL image
    β€’ prompt  : optional free-form request
    β€’ task    : "Detection" | "Segmentation" | "Auto"
    Returns plain-text JSON with the LLM tool-call and its results.
    """
    if pil_img is None:
        return "❗ Please upload an image first."

    img_path = save_uploaded_image(pil_img)

    # Build user message for the LLM
    if task == "Detection":
        user_msg = f"Please detect objects in the image '{img_path}'."
    elif task == "Segmentation":
        user_msg = f"Please segment objects in the image '{img_path}'."
    else:  # Auto / custom
        prompt = (prompt or "").strip() or "Analyse this image."
        user_msg = f"{prompt} (image: '{img_path}')"

    # Run chat → tool calls → tool execution
    out = chatbot(user_msg)
    txt = (
        "### 🔧 Raw tool-call\n"
        f"{out['raw_tool_call']}\n\n"
        "### 📦 Tool results\n"
        f"{json.dumps(out['results'], indent=2)}"
    )
    return txt
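
# Illustrative example (the upload filename is random): with task="Detection" the
# model receives "Please detect objects in the image 'static/upload_1a2b3c4d.png'."
# and the rendered markdown embeds the raw tool-call plus the JSON tool results.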


# -------- UI (unchanged shell) ------------------------------------------
def create_header():
    with gr.Row():
        with gr.Column(scale=1):
            logo_base64 = image_to_base64("static/aivn_logo.png")
            gr.HTML(
                f"""<img src="data:image/png;base64,{logo_base64}" 
                        alt="Logo"
                        style="height:120px;width:auto;margin-right:20px;margin-bottom:20px;">"""
            )
        with gr.Column(scale=4):
            gr.Markdown(
                """
<div style="display:flex;justify-content:space-between;align-items:center;padding:0 15px;">
  <div>
    <h1 style="margin-bottom:0;">🖼️ Vision Tool-Calling Demo</h1>
    <p style="margin-top:0.5em;color:#666;">LLM-driven Detection & Segmentation</p>
  </div>
  <div style="text-align:right;border-left:2px solid #ddd;padding-left:20px;">
    <h3 style="margin:0;color:#2c3e50;">🚀 AIO2024 Module 10 Project 🤗</h3>
    <p style="margin:0;color:#7f8c8d;">🔍 Using Llama 3.2-1B + YOLO + SAM</p>
  </div>
</div>
"""
            )


def create_footer():
    footer_html = """
<style>
  .sticky-footer{position:fixed;bottom:0;left:0;width:100%;background:white;
                 padding:10px;box-shadow:0 -2px 10px rgba(0,0,0,0.1);z-index:1000;}
  .content-wrap{padding-bottom:60px;}
</style>
<div class="sticky-footer">
  <div style="text-align:center;font-size:14px;">
    Created by <a href="https://vlai.work" target="_blank"
    style="color:#007BFF;text-decoration:none;">VLAI</a> β€’ AI VIETNAM
  </div>
</div>
"""
    return gr.HTML(footer_html)


custom_css = """
.gradio-container {min-height:100vh;}
.content-wrap {padding-bottom:60px;}
.full-width-btn {width:100%!important;height:50px!important;font-size:18px!important;
                 margin-top:20px!important;background:linear-gradient(45deg,#FF6B6B,#4ECDC4)!important;
                 color:white!important;border:none!important;}
.full-width-btn:hover {background:linear-gradient(45deg,#FF5252,#3CB4AC)!important;}
"""

with gr.Blocks(css=custom_css) as demo:
    create_header()

    with gr.Row(equal_height=True, variant="panel"):
        with gr.Column(scale=3):
            upload_image = gr.Image(label="Upload image", type="pil")
            prompt_input = gr.Textbox(label="Optional prompt", placeholder="e.g. Detect cats only")
            task_choice  = gr.Radio(
                ["Auto", "Detection", "Segmentation"], value="Auto", label="Task"
            )
            submit_btn   = gr.Button("Run 🔧", elem_classes="full-width-btn")

        with gr.Column(scale=4):
            output_text = gr.Markdown(label="Result")

        submit_btn.click(
            inference,
            inputs=[upload_image, prompt_input, task_choice],
            outputs=output_text,
        )

    create_footer()

if __name__ == "__main__":
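    # Gradio serves on http://127.0.0.1:7860 by default; allowed_paths lets the
    # app serve files from ./static (the logo and saved uploads).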
    demo.launch(allowed_paths=["static"])