Sachi Wagaarachchi committed
Commit 0655268 · 1 Parent(s): 7778aa1

debug: attention

Files changed (1)
  1. src/chat_logic.py +11 -3
src/chat_logic.py CHANGED
```diff
@@ -27,11 +27,19 @@ class ChatProcessor:
             skip_special_tokens=True
         )
 
-        input_ids = pipe.tokenizer(prompt, return_tensors="pt").input_ids.to(pipe.model.device)
+        # Get full tokenizer output
+        tokenized_inputs = pipe.tokenizer(prompt, return_tensors="pt")
 
-        # Prepare generation kwargs
+        # Determine model device
+        device = pipe.model.device
+
+        # Move all tensors to the correct device
+        inputs_on_device = {k: v.to(device) for k, v in tokenized_inputs.items()}
+
+        # Prepare generation kwargs with attention_mask
         generate_kwargs = {
-            "input_ids": input_ids,
+            "input_ids": inputs_on_device["input_ids"],
+            "attention_mask": inputs_on_device["attention_mask"],
             "max_new_tokens": max_new_tokens,
             "temperature": temperature,
             "top_p": top_p,
```