wjnwjn59 committed
Commit 487fa97 · 1 Parent(s): 6408fab

update sampling

Files changed (3)
  1. app.py +3 -20
  2. src/chat.py +0 -0
  3. src/llm/chat.py +1 -1
app.py CHANGED
@@ -1,49 +1,34 @@
-# app.py
 import os, base64, json, uuid, torch, gradio as gr
 from pathlib import Path
+from src.llm.chat import FunctionCallingChat
 
-# === Your vision-LLM stack (imported from src/… as organised earlier) ===
-from src.llm.chat import FunctionCallingChat  # wrapper around Llama-3.2-1B
-chatbot = FunctionCallingChat()  # load once at start-up
+chatbot = FunctionCallingChat()
 
-# -------- helpers --------------------------------------------------------
 def image_to_base64(image_path: str):
     with open(image_path, "rb") as f:
         return base64.b64encode(f.read()).decode("utf-8")
 
-
 def save_uploaded_image(pil_img) -> Path:
-    """Persist uploaded PIL image to ./static/ and return the file path."""
     Path("static").mkdir(exist_ok=True)
     filename = f"upload_{uuid.uuid4().hex[:8]}.png"
     path = Path("static") / filename
     pil_img.save(path)
     return path
 
-
-# -------- inference ------------------------------------------------------
 def inference(pil_img, prompt, task):
-    """
-    • pil_img : uploaded PIL image
-    • prompt  : optional free-form request
-    • task    : "Detection" | "Segmentation" | "Auto"
-    Returns plain-text JSON with the LLM tool-call and its results.
-    """
     if pil_img is None:
         return "❗ Please upload an image first."
 
     img_path = save_uploaded_image(pil_img)
 
-    # Build user message for the LLM
     if task == "Detection":
         user_msg = f"Please detect objects in the image '{img_path}'."
     elif task == "Segmentation":
         user_msg = f"Please segment objects in the image '{img_path}'."
-    else:  # Auto / custom
+    else:
         prompt = prompt.strip() or "Analyse this image."
         user_msg = f"{prompt} (image: '{img_path}')"
 
-    # Run chat → tool calls → tool execution
     out = chatbot(user_msg)
     txt = (
         "### 🔧 Raw tool-call \n"
@@ -53,8 +38,6 @@ def inference(pil_img, prompt, task):
     )
     return txt
 
-
-# -------- UI (unchanged shell) ------------------------------------------
 def create_header():
     with gr.Row():
         with gr.Column(scale=1):
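
For reviewers, a quick smoke test of the slimmed-down inference path. This is a sketch, not part of the commit: it assumes app.py can be imported without launching the Gradio UI, and that (as the diff shows) FunctionCallingChat is instantiated at import time.

# Hypothetical smoke test for the refactored app.py (not in this commit).
from PIL import Image
from app import inference  # importing app.py also loads FunctionCallingChat

# A blank image stands in for a real upload.
img = Image.new("RGB", (224, 224), "white")

# "Detection"/"Segmentation" route to canned prompts; any other task takes the
# free-form branch, which falls back to "Analyse this image." when prompt is empty.
print(inference(img, "", "Detection"))
print(inference(img, "", "Auto"))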
src/chat.py DELETED
File without changes
src/llm/chat.py CHANGED
@@ -31,7 +31,7 @@ class FunctionCallingChat:
         ]
 
         generation_cfg = GenerationConfig(
-            max_new_tokens=512, temperature=0.5, top_p=0.95, do_sample=True
+            max_new_tokens=512, temperature=0.2, top_p=0.95, do_sample=True
         )
 
         tokenized = self.tokenizer.apply_chat_template(
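
The substantive change is the sampling temperature, lowered from 0.5 to 0.2 while keeping top_p=0.95 and do_sample=True: a sharper token distribution makes the model's tool-call output more repeatable, at the cost of some diversity. Below is a minimal standalone sketch of the new settings; the model id is a placeholder assumption (the diff only shows that the app wraps a Llama-3.2-1B), not something this commit specifies.

# Standalone sketch of the updated GenerationConfig (model id is an assumption).
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

model_id = "meta-llama/Llama-3.2-1B-Instruct"  # placeholder; substitute the repo's checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

generation_cfg = GenerationConfig(
    max_new_tokens=512,
    temperature=0.2,  # was 0.5; lower temperature -> more deterministic tool calls
    top_p=0.95,
    do_sample=True,
)

inputs = tokenizer("Please detect objects in the image 'static/demo.png'.", return_tensors="pt")
output = model.generate(**inputs, generation_config=generation_cfg)
print(tokenizer.decode(output[0], skip_special_tokens=True))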