HemanM committed on
Commit
738a56e
·
verified ·
1 Parent(s): defaa9b

Update generate.py

Browse files
Files changed (1) hide show
  1. generate.py +4 -0
generate.py CHANGED
@@ -24,6 +24,10 @@ def generate_response(prompt, use_web=False, max_length=100, top_k=40):
24
  input_text = context + f"User: {prompt}\nAssistant:"
25
  input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
26
 
 
 
 
 
27
  for _ in range(max_length):
28
  with torch.no_grad():
29
  logits = model(input_ids)
 
24
  input_text = context + f"User: {prompt}\nAssistant:"
25
  input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
26
 
27
+ # Truncate to avoid exceeding model's positional encoding limit
28
+ if input_ids.size(1) > 512:
29
+ input_ids = input_ids[:, -512:]
30
+
31
  for _ in range(max_length):
32
  with torch.no_grad():
33
  logits = model(input_ids)