chrisvoncsefalvay commited on
Commit
146b507
·
1 Parent(s): 4a30ac2

Fix token_type_ids error and add pydantic dependency

Browse files
Files changed (1) hide show
  1. app.py +7 -2
app.py CHANGED
@@ -82,12 +82,17 @@ def generate_response(
82
 
83
  # Tokenize input
84
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
85
- inputs = {k: v.to(model.device) for k, v in inputs.items()}
 
 
 
 
 
86
 
87
  # Generate response
88
  with torch.no_grad():
89
  outputs = model.generate(
90
- **inputs,
91
  max_new_tokens=max_new_tokens,
92
  temperature=temperature,
93
  top_p=top_p,
 
82
 
83
  # Tokenize input
84
  inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
85
+
86
+ # Move to device and filter out token_type_ids if present
87
+ model_inputs = {}
88
+ for k, v in inputs.items():
89
+ if k != 'token_type_ids': # Filter out token_type_ids
90
+ model_inputs[k] = v.to(model.device)
91
 
92
  # Generate response
93
  with torch.no_grad():
94
  outputs = model.generate(
95
+ **model_inputs,
96
  max_new_tokens=max_new_tokens,
97
  temperature=temperature,
98
  top_p=top_p,