Commit 2062515
1 Parent(s): eb1a863

remove some unneeded lines, fix pipe issue

Files changed: utils/models.py (+5 -16)
utils/models.py CHANGED

@@ -135,15 +135,10 @@ def run_inference(model_name, context, question):
    # Common arguments for tokenizer loading
    tokenizer_load_args = {"padding_side": "left", "token": True}

-   # Determine the Hugging Face model name for the tokenizer
    actual_model_name_for_tokenizer = model_name
    if "icecream" in model_name.lower():
        actual_model_name_for_tokenizer = "meta-llama/llama-3.2-3b-instruct"

-   # Note: tokenizer_kwargs (defined earlier, with add_generation_prompt etc.)
-   # is intended for tokenizer.apply_chat_template, not for AutoTokenizer.from_pretrained generally.
-   # If a specific tokenizer (e.g., Qwen) needs special __init__ args that happen to be in tokenizer_kwargs,
-   # that would require more specific handling here. For now, we assume general constructor args.
    tokenizer = AutoTokenizer.from_pretrained(actual_model_name_for_tokenizer, **tokenizer_load_args)
    tokenizer_cache[model_name] = tokenizer

@@ -201,8 +196,6 @@ def run_inference(model_name, context, question):
    elif "icecream" in model_name.lower():

        print("ICECREAM")
-       # text_input is the list of messages from format_rag_prompt
-       # tokenizer_kwargs (e.g., {"add_generation_prompt": True}) are correctly passed to apply_chat_template
        model_inputs = tokenizer.apply_chat_template(
            text_input,
            tokenize=True,
@@ -211,38 +204,34 @@ def run_inference(model_name, context, question):
            **tokenizer_kwargs,
        )

-
+
        model_inputs = model_inputs.to(model.device)

        input_ids = model_inputs.input_ids
-       attention_mask = model_inputs.attention_mask
+       attention_mask = model_inputs.attention_mask

-       prompt_tokens_length = input_ids.shape[1]
+       prompt_tokens_length = input_ids.shape[1]

        with torch.inference_mode():
            # Check interrupt before generation
            if generation_interrupt.is_set():
                return ""

-           # Explicitly pass input_ids, attention_mask, and pad_token_id
-           # tokenizer.pad_token is set to tokenizer.eos_token if None, earlier in the code.
            output_sequences = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=512,
-               eos_token_id=tokenizer.eos_token_id,
+               eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id  # Addresses the warning
            )

-           # output_sequences[0] contains the full sequence (prompt + generation)
-           # Decode only the newly generated tokens
            generated_token_ids = output_sequences[0][prompt_tokens_length:]
            result = tokenizer.decode(generated_token_ids, skip_special_tokens=True)

    else:  # For other models
        formatted = pipe.tokenizer.apply_chat_template(
            text_input,
-           tokenize=
+           tokenize=False,
            **tokenizer_kwargs,
        )

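The first hunk only trims explanatory comments around the tokenizer setup: the tokenizer is loaded with left padding and the locally stored Hugging Face token, cached in tokenizer_cache, and the "icecream" display name is mapped to the underlying meta-llama/llama-3.2-3b-instruct checkpoint. A minimal standalone sketch of that lookup pattern follows; the get_tokenizer wrapper is illustrative, since in the Space this logic sits inline in run_inference.

from transformers import AutoTokenizer

tokenizer_cache = {}

def get_tokenizer(model_name: str):
    """Return a cached tokenizer for a model display name (illustrative wrapper)."""
    if model_name in tokenizer_cache:
        return tokenizer_cache[model_name]

    # Left padding keeps prompts right-aligned for decoder-only generation;
    # token=True uses the locally configured Hugging Face access token (gated repo).
    tokenizer_load_args = {"padding_side": "left", "token": True}

    # Map the display name to the underlying checkpoint.
    actual_model_name_for_tokenizer = model_name
    if "icecream" in model_name.lower():
        actual_model_name_for_tokenizer = "meta-llama/llama-3.2-3b-instruct"

    tokenizer = AutoTokenizer.from_pretrained(
        actual_model_name_for_tokenizer, **tokenizer_load_args
    )
    tokenizer_cache[model_name] = tokenizer
    return tokenizer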
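The third hunk keeps the manual generation path for the "icecream" branch: chat-template the messages into tensors, generate with an explicit attention mask and pad token id, then decode only the tokens after the prompt. Below is a runnable sketch of the same pattern, with a small instruct model as a placeholder for the Space's model and text_input; return_tensors="pt" and return_dict=True are assumptions made here so that input_ids and attention_mask come back as tensors, matching how the diff accesses model_inputs.input_ids.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder for the Space's model
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
model = AutoModelForCausalLM.from_pretrained(model_id)

# The Space sets pad_token to eos_token earlier in the code when it is missing,
# which is what makes the explicit pad_token_id below warning-free.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

messages = [  # placeholder for text_input from format_rag_prompt
    {"role": "user", "content": "Context: ...\n\nQuestion: ..."},
]

model_inputs = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
).to(model.device)

input_ids = model_inputs.input_ids
attention_mask = model_inputs.attention_mask
prompt_tokens_length = input_ids.shape[1]

with torch.inference_mode():
    output_sequences = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=512,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )

# output_sequences[0] is prompt + generation; keep only the newly generated tokens.
generated_token_ids = output_sequences[0][prompt_tokens_length:]
result = tokenizer.decode(generated_token_ids, skip_special_tokens=True)
print(result)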
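The "pipe issue" named in the commit message is the final change: for the pipeline branch, apply_chat_template(..., tokenize=False) returns a single formatted prompt string, which is what a transformers text-generation pipeline expects, whereas tokenized output would hand it token ids instead of text. A small sketch of that flow, with the checkpoint and messages as placeholders for the Space's pipe and text_input:

from transformers import pipeline

# Placeholder checkpoint and messages standing in for the Space's pipe and text_input.
pipe = pipeline("text-generation", model="Qwen/Qwen2.5-0.5B-Instruct")

messages = [
    {"role": "system", "content": "Answer using only the provided context."},
    {"role": "user", "content": "Context: ...\n\nQuestion: ..."},
]

# tokenize=False yields one formatted prompt string; tokenize=True would
# return token ids, which is not what the text-generation pipeline expects here.
formatted = pipe.tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True,
)

outputs = pipe(formatted, max_new_tokens=512, return_full_text=False)
print(outputs[0]["generated_text"])

Here return_full_text=False keeps only the continuation, playing the same role as slicing off the prompt tokens in the manual branch above.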