first init
- .gitignore +12 -0
- app.py +134 -0
- requirements.txt +7 -0
- src/chat.py +0 -0
- src/llm/__init__.py +0 -0
- src/llm/chat.py +56 -0
- src/tools/__init__.py +36 -0
- src/tools/detection_model.py +26 -0
- src/tools/segmentation_model.py +14 -0
- static/aivn_logo.png +0 -0
.gitignore
ADDED
@@ -0,0 +1,12 @@
+__pycache__
+*.pyc
+*.pyo
+*.pyd
+*.pdb
+*.egg-info
+*.egg
+*.whl
+*.zip
+*.tar.gz
+
+weights/
app.py
ADDED
@@ -0,0 +1,134 @@
+# app.py
+import os, base64, json, uuid, torch, gradio as gr
+from pathlib import Path
+
+# === Your vision-LLM stack (imported from src/… as organised earlier) ===
+from src.llm.chat import FunctionCallingChat   # wrapper around Llama-3.2-1B
+chatbot = FunctionCallingChat()                # load once at start-up
+
+# -------- helpers --------------------------------------------------------
+def image_to_base64(image_path: str):
+    with open(image_path, "rb") as f:
+        return base64.b64encode(f.read()).decode("utf-8")
+
+
+def save_uploaded_image(pil_img) -> Path:
+    """Persist uploaded PIL image to ./static/ and return the file path."""
+    Path("static").mkdir(exist_ok=True)
+    filename = f"upload_{uuid.uuid4().hex[:8]}.png"
+    path = Path("static") / filename
+    pil_img.save(path)
+    return path
+
+
+# -------- inference ------------------------------------------------------
+def inference(pil_img, prompt, task):
+    """
+    • pil_img : uploaded PIL image
+    • prompt  : optional free-form request
+    • task    : "Detection" | "Segmentation" | "Auto"
+    Returns plain-text JSON with the LLM tool-call and its results.
+    """
+    if pil_img is None:
+        return "❗ Please upload an image first."
+
+    img_path = save_uploaded_image(pil_img)
+
+    # Build user message for the LLM
+    if task == "Detection":
+        user_msg = f"Please detect objects in the image '{img_path}'."
+    elif task == "Segmentation":
+        user_msg = f"Please segment objects in the image '{img_path}'."
+    else:  # Auto / custom
+        prompt = prompt.strip() or "Analyse this image."
+        user_msg = f"{prompt} (image: '{img_path}')"
+
+    # Run chat → tool calls → tool execution
+    out = chatbot(user_msg)
+    txt = (
+        "### 🔧 Raw tool-call\n"
+        f"{out['raw_tool_call']}\n\n"
+        "### 📦 Tool results\n"
+        f"{json.dumps(out['results'], indent=2)}"
+    )
+    return txt
+
+
+# -------- UI (unchanged shell) ------------------------------------------
+def create_header():
+    with gr.Row():
+        with gr.Column(scale=1):
+            logo_base64 = image_to_base64("static/aivn_logo.png")
+            gr.HTML(
+                f"""<img src="data:image/png;base64,{logo_base64}"
+                     alt="Logo"
+                     style="height:120px;width:auto;margin-right:20px;margin-bottom:20px;">"""
+            )
+        with gr.Column(scale=4):
+            gr.Markdown(
+                """
+                <div style="display:flex;justify-content:space-between;align-items:center;padding:0 15px;">
+                  <div>
+                    <h1 style="margin-bottom:0;">🖼️ Vision Tool-Calling Demo</h1>
+                    <p style="margin-top:0.5em;color:#666;">LLM-driven Detection & Segmentation</p>
+                  </div>
+                  <div style="text-align:right;border-left:2px solid #ddd;padding-left:20px;">
+                    <h3 style="margin:0;color:#2c3e50;">🚀 AIO2024 Module 10 Project 🤗</h3>
+                    <p style="margin:0;color:#7f8c8d;">🔍 Using Llama 3.2-1B + YOLO + SAM</p>
+                  </div>
+                </div>
+                """
+            )
+
+
+def create_footer():
+    footer_html = """
+    <style>
+      .sticky-footer{position:fixed;bottom:0;left:0;width:100%;background:white;
+                     padding:10px;box-shadow:0 -2px 10px rgba(0,0,0,0.1);z-index:1000;}
+      .content-wrap{padding-bottom:60px;}
+    </style>
+    <div class="sticky-footer">
+      <div style="text-align:center;font-size:14px;">
+        Created by <a href="https://vlai.work" target="_blank"
+           style="color:#007BFF;text-decoration:none;">VLAI</a> • AI VIETNAM
+      </div>
+    </div>
+    """
+    return gr.HTML(footer_html)
+
+
+custom_css = """
+.gradio-container {min-height:100vh;}
+.content-wrap {padding-bottom:60px;}
+.full-width-btn {width:100%!important;height:50px!important;font-size:18px!important;
+  margin-top:20px!important;background:linear-gradient(45deg,#FF6B6B,#4ECDC4)!important;
+  color:white!important;border:none!important;}
+.full-width-btn:hover {background:linear-gradient(45deg,#FF5252,#3CB4AC)!important;}
+"""
+
+with gr.Blocks(css=custom_css) as demo:
+    create_header()
+
+    with gr.Row(equal_height=True, variant="panel"):
+        with gr.Column(scale=3):
+            upload_image = gr.Image(label="Upload image", type="pil")
+            prompt_input = gr.Textbox(label="Optional prompt", placeholder="e.g. Detect cats only")
+            task_choice = gr.Radio(
+                ["Auto", "Detection", "Segmentation"], value="Auto", label="Task"
+            )
+            submit_btn = gr.Button("Run 🔧", elem_classes="full-width-btn")
+
+        with gr.Column(scale=4):
+            output_text = gr.Markdown(label="Result")
+
+    submit_btn.click(
+        inference,
+        inputs=[upload_image, prompt_input, task_choice],
+        outputs=output_text,
+    )
+
+    create_footer()
+
+if __name__ == "__main__":
+    demo.launch(allowed_paths=["static/aivn_logo.png", "static"])
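The inference path can be exercised without the Gradio UI. A minimal sketch, assuming the project root is on PYTHONPATH (the test image path is hypothetical, and importing app loads the LLM once as a side effect):

    from PIL import Image
    from app import inference   # module import builds the UI and loads the model

    img = Image.open("static/test_cat.png")   # hypothetical local image
    print(inference(img, prompt="", task="Detection"))
    # Prints a Markdown string: the raw tool-call, then the JSON tool results.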
requirements.txt
ADDED
@@ -0,0 +1,7 @@
+ultralytics==8.3.130
+torch==2.6.0
+transformers==4.51.3
+matplotlib==3.10.3
+opencv-python==4.11.0.86
+gradio==5.29.0
+Pillow==11.2.1
src/chat.py
ADDED
File without changes
src/llm/__init__.py
ADDED
File without changes
src/llm/chat.py
ADDED
@@ -0,0 +1,56 @@
+import ast, torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
+from ..tools import run_detection, run_segmentation, FUNCTION_SCHEMA
+
+TOOLS = {"run_detection": run_detection, "run_segmentation": run_segmentation}
+
+SYSTEM_PROMPT = """
+You are an expert in composing functions. You are given a question and a set of possible functions.
+Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
+If none of the functions can be used, point it out. If the given question lacks the parameters required by a function,
+also point that out. You should only return the function call in the tool-call section.
+
+If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]\n
+You SHOULD NOT include any other text in the response.
+
+Here is a list of functions in JSON format that you can invoke.\n\n{functions}\n""".format(functions=FUNCTION_SCHEMA)
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+class FunctionCallingChat:
+    def __init__(self, model_id: str = "meta-llama/Llama-3.2-1B-Instruct"):
+        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_id, device_map=device, torch_dtype=torch.bfloat16
+        )
+
+    def __call__(self, user_msg: str) -> dict:
+        messages = [
+            {"role": "system", "content": SYSTEM_PROMPT},
+            {"role": "user", "content": user_msg},
+        ]
+
+        generation_cfg = GenerationConfig(
+            max_new_tokens=512, temperature=0.5, top_p=0.95, do_sample=True
+        )
+
+        tokenized = self.tokenizer.apply_chat_template(
+            messages, tokenize=True, add_generation_prompt=True,
+            return_attention_mask=True, return_tensors="pt"
+        ).to(device)
+
+        output = self.model.generate(tokenized, generation_config=generation_cfg)
+        raw = self.tokenizer.decode(output[0], skip_special_tokens=True)
+        tool_calls_str = raw.split("assistant")[-1].strip()
+
+        try:
+            # The reply is a Python-style call list, e.g. [run_detection(image_path='…')].
+            # ast.literal_eval cannot evaluate Call expressions, so parse into AST nodes instead.
+            calls = ast.parse(tool_calls_str, mode="eval").body.elts
+        except Exception as e:
+            raise RuntimeError(f"Cannot parse tool call: {e}\nRaw: {tool_calls_str}")
+
+        results = []
+        for call in calls:
+            fn_name = call.func.id
+            kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in call.keywords}
+            results.append(TOOLS[fn_name](**kwargs))
+        return {"raw_tool_call": tool_calls_str, "results": results}
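Since the model emits tool calls as a Python-style call list rather than JSON, the parsing step above leans on the ast module. A standalone sketch of just that step (the call string is illustrative):

    import ast

    tool_calls_str = "[run_detection(image_path='static/upload_ab12cd34.png', is_visualize=True)]"
    calls = ast.parse(tool_calls_str, mode="eval").body.elts   # one ast.Call per tool call
    for call in calls:
        fn_name = call.func.id                                 # 'run_detection'
        kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in call.keywords}
        print(fn_name, kwargs)
    # run_detection {'image_path': 'static/upload_ab12cd34.png', 'is_visualize': True}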
src/tools/__init__.py
ADDED
@@ -0,0 +1,36 @@
+from .detection_model import run_detection
+from .segmentation_model import run_segmentation
+
+__all__ = ["run_detection", "run_segmentation"]
+
+FUNCTION_SCHEMA = [
+    {
+        "type": "function",
+        "function": {
+            "name": "run_detection",
+            "description": "Detect objects in an image and return bounding boxes and labels.",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "image_path": {"type": "string", "description": "Local path to the image file."},
+                    "is_visualize": {"type": "boolean", "description": "If true, draw bboxes and save next to the image."}
+                },
+                "required": ["image_path"]
+            },
+        },
+    },
+    {
+        "type": "function",
+        "function": {
+            "name": "run_segmentation",
+            "description": "Segment objects in an image and return binary masks.",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "image_path": {"type": "string", "description": "Local path to the image file."}
+                },
+                "required": ["image_path"]
+            },
+        },
+    },
+]
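FUNCTION_SCHEMA is interpolated verbatim into the system prompt in src/llm/chat.py, and the same names key the TOOLS dispatch table there, so each "name" string must match its Python function exactly. A quick consistency check (a sketch, run from the project root):

    from src.tools import FUNCTION_SCHEMA

    registered = {f["function"]["name"] for f in FUNCTION_SCHEMA}
    assert registered == {"run_detection", "run_segmentation"}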
src/tools/detection_model.py
ADDED
@@ -0,0 +1,26 @@
+from ultralytics import YOLO
+
+detection_model_id = "yolo11n.pt"
+detection_model = YOLO(detection_model_id)
+
+def run_detection(image_path: str, is_visualize: bool = False):
+    """YOLO11: return list of {box, label, score} for a single image."""
+    results = detection_model(image_path)
+    r = results[0]
+
+    detections = []
+    for box in r.boxes:
+        # box.xyxy is a 1×4 tensor, box.conf is a 1-element tensor, box.cls likewise
+        coords = box.xyxy.cpu().numpy().flatten().tolist()
+        score = float(box.conf.cpu().numpy().item())
+        cls_id = int(box.cls.cpu().numpy().item())
+        detections.append({
+            "box": coords,
+            "label": r.names[cls_id],
+            "score": score,
+        })
+
+    if is_visualize:
+        r.save()
+        r.show()
+
+    return {"detections": detections}
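A direct-call sketch showing the shape of the return value (the image path is hypothetical; boxes are xyxy pixel coordinates):

    from src.tools.detection_model import run_detection

    result = run_detection("static/upload_ab12cd34.png")
    # e.g. {"detections": [{"box": [34.2, 50.1, 210.8, 300.4],
    #                       "label": "cat", "score": 0.91}, ...]}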
src/tools/segmentation_model.py
ADDED
@@ -0,0 +1,14 @@
+from transformers import SamModel, SamProcessor
+from PIL import Image
+
+segmentation_model_id = "facebook/sam-vit-base"
+sam_processor = SamProcessor.from_pretrained(segmentation_model_id)
+sam_model = SamModel.from_pretrained(segmentation_model_id)
+
+def run_segmentation(image_path: str):
+    """SAM: return binary masks as nested lists"""
+    img = Image.open(image_path).convert("RGB")
+    inputs = sam_processor(images=img, return_tensors="pt")
+    outputs = sam_model(**inputs)
+
+    masks = outputs.pred_masks.squeeze(0).cpu().detach().numpy().tolist()
+    return {"masks": masks}
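Note that SamModel's pred_masks are low-resolution mask logits rather than binary masks. If binarized, full-resolution masks are needed, one option is the processor's documented post-processing helper; a sketch, continuing from the variables inside run_segmentation:

    # Upscale the logits to the original image size and threshold to booleans.
    masks = sam_processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    binary = [m.numpy().tolist() for m in masks]   # nested lists of booleans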
static/aivn_logo.png
ADDED
(binary image: AIVN logo)