ciyidogan committed on
Commit
6d135b8
·
verified ·
1 Parent(s): d6690e0

Update fine_tune_inference_test_mistral.py

Browse files
fine_tune_inference_test_mistral.py CHANGED
@@ -76,10 +76,11 @@ def chat(msg: Message):
76
  return {"error": "Boş giriş"}
77
 
78
  messages = [{"role": "user", "content": user_input}]
79
- input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(model.device)
 
80
 
81
  generate_args = {
82
- "max_new_tokens": 1024,
83
  "return_dict_in_generate": True,
84
  "output_scores": True,
85
  "do_sample": USE_SAMPLING
@@ -93,10 +94,11 @@ def chat(msg: Message):
93
  })
94
 
95
  with torch.no_grad():
96
- output = model.generate(input_ids=input_ids, **generate_args)
97
 
 
98
  decoded = tokenizer.decode(output.sequences[0], skip_special_tokens=True)
99
- answer = decoded.split("</s>")[-1].strip()
100
 
101
  if output.scores and len(output.scores) > 0:
102
  first_token_score = output.scores[0][0]
 
76
  return {"error": "Boş giriş"}
77
 
78
  messages = [{"role": "user", "content": user_input}]
79
+ inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True)
80
+ inputs = {k: v.to(model.device) for k, v in inputs.items()}
81
 
82
  generate_args = {
83
+ "max_new_tokens": 512,
84
  "return_dict_in_generate": True,
85
  "output_scores": True,
86
  "do_sample": USE_SAMPLING
 
94
  })
95
 
96
  with torch.no_grad():
97
+ output = model.generate(**inputs, **generate_args)
98
 
99
+ prompt_text = tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
100
  decoded = tokenizer.decode(output.sequences[0], skip_special_tokens=True)
101
+ answer = decoded.replace(prompt_text, "").strip()
102
 
103
  if output.scores and len(output.scores) > 0:
104
  first_token_score = output.scores[0][0]