Sachi Wagaarachchi committed on
Commit · 0655268
Parent(s): 7778aa1
debug: attention
src/chat_logic.py +11 -3
src/chat_logic.py
CHANGED
@@ -27,11 +27,19 @@ class ChatProcessor:
 skip_special_tokens=True
 )
 
-
+# Get full tokenizer output
+tokenized_inputs = pipe.tokenizer(prompt, return_tensors="pt")
 
-#
+# Determine model device
+device = pipe.model.device
+
+# Move all tensors to the correct device
+inputs_on_device = {k: v.to(device) for k, v in tokenized_inputs.items()}
+
+# Prepare generation kwargs with attention_mask
 generate_kwargs = {
-    "input_ids": input_ids,
+    "input_ids": inputs_on_device["input_ids"],
+    "attention_mask": inputs_on_device["attention_mask"],
     "max_new_tokens": max_new_tokens,
     "temperature": temperature,
     "top_p": top_p,
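For context, the change keeps the tokenizer's full output (rather than only input_ids) so that the attention_mask can be passed to generate along with the input IDs, with all tensors moved to the model's device first. Below is a minimal standalone sketch of the same pattern, assuming a transformers text-generation pipeline; the model name ("gpt2"), the prompt text, and the sampling values are illustrative placeholders, not values taken from this repository.

# Minimal sketch of the pattern used in this commit (assumptions noted above).
from transformers import pipeline

pipe = pipeline("text-generation", model="gpt2")  # placeholder model
prompt = "Hello, how are you?"                    # placeholder prompt

# Tokenize once and keep the full output so the attention_mask is available.
tokenized_inputs = pipe.tokenizer(prompt, return_tensors="pt")

# Move every tensor to the device the model lives on (CPU or GPU).
device = pipe.model.device
inputs_on_device = {k: v.to(device) for k, v in tokenized_inputs.items()}

# Pass attention_mask alongside input_ids when generating.
output_ids = pipe.model.generate(
    input_ids=inputs_on_device["input_ids"],
    attention_mask=inputs_on_device["attention_mask"],
    max_new_tokens=32,
    do_sample=True,
    temperature=0.7,
    top_p=0.9,
)
print(pipe.tokenizer.decode(output_ids[0], skip_special_tokens=True))

Supplying the attention_mask explicitly removes the ambiguity between padding and real tokens that transformers otherwise warns about when only input_ids are provided and the pad token falls back to the EOS token.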