# app.py
import base64
import json
import os
import uuid

import torch
import gradio as gr
from pathlib import Path
# === Your vision-LLM stack (imported from src/… as organised earlier) ===
from src.llm.chat import FunctionCallingChat # wrapper around Llama-3.2-1B
chatbot = FunctionCallingChat() # load once at start-up
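# FunctionCallingChat is expected to turn a natural-language request into a
# tool call (detection / segmentation), execute it, and hand back both the raw
# call and the tool results (consumed in inference() below).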
# -------- helpers --------------------------------------------------------
def image_to_base64(image_path: str) -> str:
    """Read an image file and return its base64-encoded contents."""
    with open(image_path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")
def save_uploaded_image(pil_img) -> Path:
"""Persist uploaded PIL image to ./static/ and return the file path."""
Path("static").mkdir(exist_ok=True)
filename = f"upload_{uuid.uuid4().hex[:8]}.png"
path = Path("static") / filename
pil_img.save(path)
return path
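# Uploads land in ./static, which is also whitelisted via allowed_paths in
# demo.launch() at the bottom of this file so Gradio may serve them.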
# -------- inference ------------------------------------------------------
def inference(pil_img, prompt, task):
"""
    • pil_img : uploaded PIL image
    • prompt  : optional free-form request
    • task    : "Detection" | "Segmentation" | "Auto"
Returns plain-text JSON with the LLM tool-call and its results.
"""
if pil_img is None:
        return "⚠️ Please upload an image first."
img_path = save_uploaded_image(pil_img)
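    # The saved file path is embedded in the user message so the vision tools
    # invoked by the LLM can locate the image on disk.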
# Build user message for the LLM
if task == "Detection":
user_msg = f"Please detect objects in the image '{img_path}'."
elif task == "Segmentation":
user_msg = f"Please segment objects in the image '{img_path}'."
else: # Auto / custom
prompt = prompt.strip() or "Analyse this image."
user_msg = f"{prompt} (image: '{img_path}')"
    # Run chat → tool calls → tool execution
out = chatbot(user_msg)
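    # `out` carries the model's raw tool call and the executed tool results
    # (used below as out['raw_tool_call'] and out['results']).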
    txt = (
        "### 🧠 Raw tool-call\n"
        f"{out['raw_tool_call']}\n\n"
        "### 📦 Tool results\n"
        f"{json.dumps(out['results'], indent=2)}"
    )
return txt
# -------- UI (unchanged shell) ------------------------------------------
def create_header():
with gr.Row():
with gr.Column(scale=1):
logo_base64 = image_to_base64("static/aivn_logo.png")
gr.HTML(
f"""<img src="data:image/png;base64,{logo_base64}"
alt="Logo"
style="height:120px;width:auto;margin-right:20px;margin-bottom:20px;">"""
)
with gr.Column(scale=4):
gr.Markdown(
"""
<div style="display:flex;justify-content:space-between;align-items:center;padding:0 15px;">
<div>
<h1 style="margin-bottom:0;">🖼️ Vision Tool-Calling Demo</h1>
<p style="margin-top:0.5em;color:#666;">LLM-driven Detection & Segmentation</p>
</div>
<div style="text-align:right;border-left:2px solid #ddd;padding-left:20px;">
<h3 style="margin:0;color:#2c3e50;">🎓 AIO2024 Module 10 Project 🤖</h3>
<p style="margin:0;color:#7f8c8d;">🚀 Using Llama 3.2-1B + YOLO + SAM</p>
</div>
</div>
"""
)
def create_footer():
footer_html = """
<style>
.sticky-footer{position:fixed;bottom:0;left:0;width:100%;background:white;
padding:10px;box-shadow:0 -2px 10px rgba(0,0,0,0.1);z-index:1000;}
.content-wrap{padding-bottom:60px;}
</style>
<div class="sticky-footer">
<div style="text-align:center;font-size:14px;">
Created by <a href="https://vlai.work" target="_blank"
style="color:#007BFF;text-decoration:none;">VLAI</a> • AI VIETNAM
</div>
</div>
"""
return gr.HTML(footer_html)
custom_css = """
.gradio-container {min-height:100vh;}
.content-wrap {padding-bottom:60px;}
.full-width-btn {width:100%!important;height:50px!important;font-size:18px!important;
margin-top:20px!important;background:linear-gradient(45deg,#FF6B6B,#4ECDC4)!important;
color:white!important;border:none!important;}
.full-width-btn:hover {background:linear-gradient(45deg,#FF5252,#3CB4AC)!important;}
"""
with gr.Blocks(css=custom_css) as demo:
create_header()
with gr.Row(equal_height=True, variant="panel"):
with gr.Column(scale=3):
upload_image = gr.Image(label="Upload image", type="pil")
prompt_input = gr.Textbox(label="Optional prompt", placeholder="e.g. Detect cats only")
task_choice = gr.Radio(
["Auto", "Detection", "Segmentation"], value="Auto", label="Task"
)
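            # "Auto" lets the free-form prompt drive the request; otherwise a
            # fixed detection / segmentation message is built in inference().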
            submit_btn = gr.Button("Run 🧠", elem_classes="full-width-btn")
with gr.Column(scale=4):
output_text = gr.Markdown(label="Result")
submit_btn.click(
inference,
inputs=[upload_image, prompt_input, task_choice],
outputs=output_text,
)
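    # inference() returns a Markdown string, which Gradio renders into output_text.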
create_footer()
if __name__ == "__main__":
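    # allowed_paths lets Gradio serve files from ./static (the logo and saved
    # uploads) in addition to its default temporary directory.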
demo.launch(allowed_paths=["static/aivn_logo.png", "static"])