oliver-aizip committed on
Commit 2062515 · 1 Parent(s): eb1a863

remove some unneeded lines, fix pipe issue

Files changed (1)
  1. utils/models.py +5 -16
utils/models.py CHANGED
@@ -135,15 +135,10 @@ def run_inference(model_name, context, question):
         # Common arguments for tokenizer loading
         tokenizer_load_args = {"padding_side": "left", "token": True}
 
-        # Determine the Hugging Face model name for the tokenizer
         actual_model_name_for_tokenizer = model_name
         if "icecream" in model_name.lower():
             actual_model_name_for_tokenizer = "meta-llama/llama-3.2-3b-instruct"
 
-        # Note: tokenizer_kwargs (defined earlier, with add_generation_prompt etc.)
-        # is intended for tokenizer.apply_chat_template, not for AutoTokenizer.from_pretrained generally.
-        # If a specific tokenizer (e.g., Qwen) needs special __init__ args that happen to be in tokenizer_kwargs,
-        # that would require more specific handling here. For now, we assume general constructor args.
         tokenizer = AutoTokenizer.from_pretrained(actual_model_name_for_tokenizer, **tokenizer_load_args)
         tokenizer_cache[model_name] = tokenizer
 
@@ -201,8 +196,6 @@ def run_inference(model_name, context, question):
     elif "icecream" in model_name.lower():
 
         print("ICECREAM")
-        # text_input is the list of messages from format_rag_prompt
-        # tokenizer_kwargs (e.g., {"add_generation_prompt": True}) are correctly passed to apply_chat_template
         model_inputs = tokenizer.apply_chat_template(
             text_input,
             tokenize=True,
@@ -211,38 +204,34 @@ def run_inference(model_name, context, question):
             **tokenizer_kwargs,
         )
 
-        # Move all tensors within the BatchEncoding (model_inputs) to the model's device
+
         model_inputs = model_inputs.to(model.device)
 
         input_ids = model_inputs.input_ids
-        attention_mask = model_inputs.attention_mask # Expecting this from a correctly configured tokenizer
+        attention_mask = model_inputs.attention_mask
 
-        prompt_tokens_length = input_ids.shape[1] # Get length of tokenized prompt
+        prompt_tokens_length = input_ids.shape[1]
 
         with torch.inference_mode():
             # Check interrupt before generation
             if generation_interrupt.is_set():
                 return ""
 
-            # Explicitly pass input_ids, attention_mask, and pad_token_id
-            # tokenizer.pad_token is set to tokenizer.eos_token if None, earlier in the code.
             output_sequences = model.generate(
                 input_ids=input_ids,
                 attention_mask=attention_mask,
                 max_new_tokens=512,
-                eos_token_id=tokenizer.eos_token_id, # Good practice for stopping generation
+                eos_token_id=tokenizer.eos_token_id,
                 pad_token_id=tokenizer.pad_token_id # Addresses the warning
             )
 
-            # output_sequences[0] contains the full sequence (prompt + generation)
-            # Decode only the newly generated tokens
             generated_token_ids = output_sequences[0][prompt_tokens_length:]
             result = tokenizer.decode(generated_token_ids, skip_special_tokens=True)
 
     else:  # For other models
         formatted = pipe.tokenizer.apply_chat_template(
             text_input,
-            tokenize=True,
+            tokenize=False,
             **tokenizer_kwargs,
         )
 
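For reference, a minimal sketch of the manual-generation path used in the "icecream" branch above. It assumes the lines between the hunks (not shown in this diff) call apply_chat_template with return_dict=True and return_tensors="pt", since input_ids and attention_mask are read off the result, and it stands in an illustrative prompt for the text_input built elsewhere in the file; the surrounding caching and interrupt logic is omitted.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/llama-3.2-3b-instruct"  # tokenizer/model source used in the diff (gated on the Hub)
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
model.to("cuda" if torch.cuda.is_available() else "cpu")

# Mirrors the fallback mentioned in the removed comment: pad_token defaults to eos_token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Illustrative stand-in for the chat-message list built elsewhere in utils/models.py.
text_input = [{"role": "user", "content": "Answer the question using the provided context."}]

model_inputs = tokenizer.apply_chat_template(
    text_input,
    tokenize=True,
    add_generation_prompt=True,
    return_dict=True,      # assumed: needed so input_ids/attention_mask can be read below
    return_tensors="pt",   # assumed: needed for model.generate
).to(model.device)

prompt_tokens_length = model_inputs.input_ids.shape[1]

with torch.inference_mode():
    output_sequences = model.generate(
        input_ids=model_inputs.input_ids,
        attention_mask=model_inputs.attention_mask,
        max_new_tokens=512,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )

# Decode only the newly generated tokens, skipping the echoed prompt.
result = tokenizer.decode(output_sequences[0][prompt_tokens_length:], skip_special_tokens=True)
print(result)

Slicing output_sequences[0] at prompt_tokens_length keeps only the newly generated tokens, which is what the decode step in the diff does.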
 
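And a minimal sketch of the pipeline branch after the tokenize=False change, assuming pipe is a standard transformers text-generation pipeline; the model name and generation settings here are illustrative and not taken from the repo.

from transformers import pipeline

pipe = pipeline("text-generation", model="Qwen/Qwen2.5-0.5B-Instruct")  # illustrative model choice

text_input = [{"role": "user", "content": "Answer the question using the provided context."}]

# tokenize=False returns the templated prompt as a plain string; the pipeline
# performs its own tokenization, so it should not be handed pre-tokenized IDs.
formatted = pipe.tokenizer.apply_chat_template(
    text_input,
    tokenize=False,
    add_generation_prompt=True,
)

outputs = pipe(formatted, max_new_tokens=512, return_full_text=False)
print(outputs[0]["generated_text"])

Passing the templated string instead of token IDs lets the pipeline tokenize the input itself, which appears to be the "pipe issue" the commit message refers to.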