wjnwjn59 committed on
Commit
7cc25b5
·
1 Parent(s): 0f47a3e

change temperature

Browse files
Files changed (2) hide show
  1. app.py +0 -1
  2. src/llm/chat.py +3 -2
app.py CHANGED
@@ -49,7 +49,6 @@ def inference(pil_img, prompt, task, temperature):
49
  except Exception:
50
  pass # if deletion fails we just move on
51
 
52
- # ──────────────────────────── UI ────────────────────────────
53
  def create_header():
54
  with gr.Row():
55
  with gr.Column(scale=1):
 
49
  except Exception:
50
  pass # if deletion fails we just move on
51
 
 
52
  def create_header():
53
  with gr.Row():
54
  with gr.Column(scale=1):
src/llm/chat.py CHANGED
@@ -18,11 +18,12 @@ Here is a list of functions in JSON format that you can invoke.\n\n{functions}\n
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
19
 
20
  class FunctionCallingChat:
21
- def __init__(self, model_id: str = "meta-llama/Llama-3.2-1B-Instruct"):
22
  self.tokenizer = AutoTokenizer.from_pretrained(model_id)
23
  self.model = AutoModelForCausalLM.from_pretrained(
24
  model_id, device_map=device, torch_dtype=torch.bfloat16
25
  )
 
26
 
27
  def __call__(self, user_msg: str) -> dict:
28
  messages = [
@@ -31,7 +32,7 @@ class FunctionCallingChat:
31
  ]
32
 
33
  generation_cfg = GenerationConfig(
34
- max_new_tokens=128, temperature=0.2, top_p=0.95, do_sample=True
35
  )
36
 
37
  tokenized = self.tokenizer.apply_chat_template(
 
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
19
 
20
  class FunctionCallingChat:
21
+ def __init__(self, model_id: str = "meta-llama/Llama-3.2-1B-Instruct", temperature: float = 0.7):
22
  self.tokenizer = AutoTokenizer.from_pretrained(model_id)
23
  self.model = AutoModelForCausalLM.from_pretrained(
24
  model_id, device_map=device, torch_dtype=torch.bfloat16
25
  )
26
+ self.temperature = temperature
27
 
28
  def __call__(self, user_msg: str) -> dict:
29
  messages = [
 
32
  ]
33
 
34
  generation_cfg = GenerationConfig(
35
+ max_new_tokens=128, temperature=self.temperature, top_p=0.95, do_sample=True
36
  )
37
 
38
  tokenized = self.tokenizer.apply_chat_template(