Commit · 146b507
1 Parent(s): 4a30ac2

Fix token_type_ids error and add pydantic dependency
app.py CHANGED
@@ -82,12 +82,17 @@ def generate_response(
 
     # Tokenize input
     inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
-
+
+    # Move to device and filter out token_type_ids if present
+    model_inputs = {}
+    for k, v in inputs.items():
+        if k != 'token_type_ids':  # Filter out token_type_ids
+            model_inputs[k] = v.to(model.device)
 
     # Generate response
     with torch.no_grad():
         outputs = model.generate(
-            **
+            **model_inputs,
             max_new_tokens=max_new_tokens,
             temperature=temperature,
             top_p=top_p,
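
For reference, a minimal standalone sketch of the fixed tokenize-and-generate path follows. It assumes a causal LM loaded with transformers; the model name, prompt, and sampling values are placeholders rather than the Space's actual configuration. Some tokenizers (BERT-style ones in particular) return a token_type_ids tensor, and unpacking it into model.generate() raises a TypeError for models whose forward() does not accept that argument, which is the error this commit works around by filtering the dict before generation.

# Minimal sketch of the fixed flow; model name, prompt, and sampling
# parameters are placeholders, not values taken from the Space itself.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

prompt = "Hello, world"
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)

# Drop token_type_ids if the tokenizer produced it, and move the remaining
# tensors to the model's device before generation.
model_inputs = {
    k: v.to(model.device) for k, v in inputs.items() if k != "token_type_ids"
}

with torch.no_grad():
    outputs = model.generate(
        **model_inputs,
        max_new_tokens=64,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
    )

print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Depending on the tokenizer, an equivalent fix is to pass return_token_type_ids=False in the tokenizer call so the key is never produced in the first place.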