Update fine_tune_inference_test_mistral.py
fine_tune_inference_test_mistral.py CHANGED
@@ -76,8 +76,7 @@ def chat(msg: Message):
         return {"error": "Boş giriş"}
 
     messages = [{"role": "user", "content": user_input}]
-
-    inputs = inputs.to(model.device)
+    input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(model.device)
 
     generate_args = {
         "max_new_tokens": 128,
@@ -94,7 +93,7 @@ def chat(msg: Message):
     })
 
     with torch.no_grad():
-        output = model.generate(
+        output = model.generate(input_ids=input_ids, **generate_args)
 
     decoded = tokenizer.decode(output.sequences[0], skip_special_tokens=True)
     answer = decoded.split("</s>")[-1].strip()