Update generate.py
Browse files
- generate.py +4 -0
generate.py
CHANGED
@@ -24,6 +24,10 @@ def generate_response(prompt, use_web=False, max_length=100, top_k=40):
|
|
24 |
input_text = context + f"User: {prompt}\nAssistant:"
|
25 |
input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
|
26 |
|
|
|
|
|
|
|
|
|
27 |
for _ in range(max_length):
|
28 |
with torch.no_grad():
|
29 |
logits = model(input_ids)
|
|
|
24 |
input_text = context + f"User: {prompt}\nAssistant:"
|
25 |
input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
|
26 |
|
27 |
+
# Truncate to avoid exceeding model's positional encoding limit
|
28 |
+
if input_ids.size(1) > 512:
|
29 |
+
input_ids = input_ids[:, -512:]
|
30 |
+
|
31 |
for _ in range(max_length):
|
32 |
with torch.no_grad():
|
33 |
logits = model(input_ids)
|