File size: 4,892 Bytes
811ffe8
 
0f47a3e
811ffe8
0f47a3e
 
811ffe8
 
 
 
 
 
0f47a3e
811ffe8
 
 
 
 
 
0f47a3e
811ffe8
 
 
0f47a3e
 
811ffe8
0f47a3e
811ffe8
 
 
 
0f47a3e
811ffe8
 
 
0f47a3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
811ffe8
 
 
 
 
0f47a3e
811ffe8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f47a3e
811ffe8
 
 
 
 
 
 
0f47a3e
 
 
 
 
 
 
811ffe8
 
 
 
 
 
 
0f47a3e
811ffe8
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import os, base64, json, uuid, torch, gradio as gr
from pathlib import Path
from src.llm.chat import FunctionCallingChat

# Single shared chat engine for the whole app. Gradio handlers mutate its
# sampling temperature per request (see inference()), so concurrent requests
# share this one instance's state.
chatbot = FunctionCallingChat()
chatbot.temperature = 0.7  # default sampling temperature until the UI slider overrides it

def image_to_base64(image_path: str) -> str:
    """Return the file at *image_path* encoded as a base64 UTF-8 string."""
    raw = Path(image_path).read_bytes()
    return base64.b64encode(raw).decode("utf-8")

def save_uploaded_image(pil_img, dest_dir="static") -> Path:
    """Save a PIL image to *dest_dir* under a unique name and return its path.

    Args:
        pil_img: object exposing a PIL-style ``save(path)`` method.
        dest_dir: target directory (default ``"static"``); created, including
            missing parents, if it does not exist yet.

    Returns:
        Path to the written PNG file, named ``upload_<8-hex-chars>.png`` so
        concurrent uploads never collide.
    """
    dest = Path(dest_dir)
    dest.mkdir(parents=True, exist_ok=True)  # parents=True: tolerate nested dirs
    filename = f"upload_{uuid.uuid4().hex[:8]}.png"
    path = dest / filename
    pil_img.save(path)
    return path

def inference(pil_img, prompt, task, temperature):
    """Run the LLM tool-calling pipeline on an uploaded image.

    Args:
        pil_img: uploaded PIL image from the Gradio ``Image`` component,
            or ``None`` when nothing was uploaded.
        prompt: optional free-text instruction (used for the "Auto" task).
        task: one of ``"Auto"``, ``"Detection"``, ``"Segmentation"``.
        temperature: sampling temperature forwarded to the chat model.

    Returns:
        A markdown string — either the formatted tool-call results or a
        human-readable error message.
    """
    if pil_img is None:
        return "❗ Please upload an image first."

    img_path = save_uploaded_image(pil_img)
    chatbot.temperature = temperature  # propagate the UI slider to the model

    # Build the user message for the selected task.
    if task == "Detection":
        user_msg = f"Please detect objects in the image '{img_path}'."
    elif task == "Segmentation":
        user_msg = f"Please segment objects in the image '{img_path}'."
    else:
        prompt = prompt.strip() or "Analyse this image."
        user_msg = f"{prompt} (image: '{img_path}')"

    try:
        out = chatbot(user_msg)
        return (
            "### 🔧 Raw tool-call\n"
            f"{out['raw_tool_call']}\n\n"
            "### 📦 Tool results\n"
            f"{json.dumps(out['results'], indent=2)}"
        )
    except Exception as exc:
        # Surface failures as a readable message instead of letting Gradio
        # show a raw stack trace to the user.
        return f"❌ Inference failed: {exc}"
    finally:
        # Always delete the temp image, success or failure.
        try:
            img_path.unlink(missing_ok=True)
        except Exception:
            pass  # best-effort cleanup; ignore deletion errors

def create_header():
    with gr.Row():
        with gr.Column(scale=1):
            logo_base64 = image_to_base64("static/aivn_logo.png")
            gr.HTML(
                f"""<img src="data:image/png;base64,{logo_base64}"
                        alt="Logo"
                        style="height:120px;width:auto;margin-right:20px;margin-bottom:20px;">"""
            )
        with gr.Column(scale=4):
            gr.Markdown(
                """
<div style="display:flex;justify-content:space-between;align-items:center;padding:0 15px;">
  <div>
    <h1 style="margin-bottom:0;">πŸ–ΌοΈ Vision Tool-Calling Demo</h1>
    <p style="margin-top:0.5em;color:#666;">LLM-driven Detection & Segmentation</p>
  </div>
  <div style="text-align:right;border-left:2px solid #ddd;padding-left:20px;">
    <h3 style="margin:0;color:#2c3e50;">πŸš€ AIO2024 Module 10 Project πŸ€—</h3>
    <p style="margin:0;color:#7f8c8d;">πŸ” Using Llama 3.2-1B + YOLO + SAM</p>
  </div>
</div>
"""
            )

def create_footer():
    """Build and return the sticky footer pinned to the bottom of the page."""
    return gr.HTML(
        """
<style>
  .sticky-footer{position:fixed;bottom:0;left:0;width:100%;background:white;
                 padding:10px;box-shadow:0 -2px 10px rgba(0,0,0,0.1);z-index:1000;}
  .content-wrap{padding-bottom:60px;}
</style>
<div class="sticky-footer">
  <div style="text-align:center;font-size:14px;">
    Created by <a href="https://vlai.work" target="_blank"
    style="color:#007BFF;text-decoration:none;">VLAI</a> • AI VIETNAM
  </div>
</div>
"""
    )

# App-wide CSS injected into gr.Blocks: full-height container, bottom padding
# so content clears the fixed footer, and gradient styling (plus hover state)
# for the full-width submit button.
custom_css = """
.gradio-container {min-height:100vh;}
.content-wrap {padding-bottom:60px;}
.full-width-btn {width:100%!important;height:50px!important;font-size:18px!important;
                 margin-top:20px!important;background:linear-gradient(45deg,#FF6B6B,#4ECDC4)!important;
                 color:white!important;border:none!important;}
.full-width-btn:hover {background:linear-gradient(45deg,#FF5252,#3CB4AC)!important;}
"""

# ────────────────────────────  Blocks  ─────────────────────────
with gr.Blocks(css=custom_css) as demo:
    create_header()

    with gr.Row(equal_height=True, variant="panel"):
        with gr.Column(scale=3):
            upload_image = gr.Image(label="Upload image", type="pil")
            prompt_input = gr.Textbox(label="Optional prompt", placeholder="e.g. Detect cats only")
            task_choice  = gr.Radio(["Auto", "Detection", "Segmentation"],
                                     value="Auto", label="Task")

            # NEW temperature slider
            temp_slider  = gr.Slider(minimum=0.1, maximum=1.5, step=0.1,
                                     value=0.7, label="Temperature (sampling)")

            submit_btn   = gr.Button("Run πŸ”§", elem_classes="full-width-btn")

        with gr.Column(scale=4):
            output_text = gr.Markdown(label="Result")

        submit_btn.click(
            inference,
            inputs=[upload_image, prompt_input, task_choice, temp_slider],
            outputs=output_text,
        )

    create_footer()

if __name__ == "__main__":
    # allowed_paths whitelists local files Gradio may serve over HTTP:
    # the logo and the ./static upload directory used by inference().
    demo.launch(allowed_paths=["static/aivn_logo.png", "static"])