Commit 146b507
Parent(s): 4a30ac2
Fix token_type_ids error and add pydantic dependency
app.py CHANGED
```diff
@@ -82,12 +82,17 @@ def generate_response(
 
     # Tokenize input
     inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
-
+
+    # Move to device and filter out token_type_ids if present
+    model_inputs = {}
+    for k, v in inputs.items():
+        if k != 'token_type_ids':  # Filter out token_type_ids
+            model_inputs[k] = v.to(model.device)
 
     # Generate response
     with torch.no_grad():
         outputs = model.generate(
-            **
+            **model_inputs,
             max_new_tokens=max_new_tokens,
             temperature=temperature,
             top_p=top_p,
```
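
For readers hitting the same error, here is a minimal self-contained sketch of the pattern this commit applies. The model name, prompt, and sampling values are illustrative placeholders, not values taken from this Space's app.py. The underlying issue is that tokenizers for some models include token_type_ids in their output, and generate() in recent transformers versions raises a ValueError for model kwargs the model's forward() does not accept.

```python
# A minimal sketch of the pattern applied in this commit; the model name,
# prompt, and sampling settings below are illustrative placeholders, not
# values taken from this Space's app.py.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "distilgpt2"  # placeholder model, not the one this Space uses
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

prompt = "Hello, world"
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)

# Some tokenizers emit token_type_ids alongside input_ids and attention_mask;
# generate() forwards every kwarg to the model, and models whose forward()
# does not accept token_type_ids raise a ValueError. Filtering the key out
# while moving tensors to the model's device handles both concerns at once.
model_inputs = {k: v.to(model.device) for k, v in inputs.items() if k != "token_type_ids"}

with torch.no_grad():
    outputs = model.generate(
        **model_inputs,
        max_new_tokens=64,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,  # temperature/top_p only take effect when sampling
    )

print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

The dict comprehension is equivalent to the explicit loop in the diff; either way, the filtering and the device move happen once, before generate() is called.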