Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -11,7 +11,7 @@ from threading import Thread
|
|
| 11 |
import torch
|
| 12 |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
| 13 |
|
| 14 |
-
|
| 15 |
DEFAULT_MAX_NEW_TOKENS = 1024
|
| 16 |
|
| 17 |
|
|
@@ -34,8 +34,8 @@ def predict(message, history, system_prompt, temperature, max_tokens):
|
|
| 34 |
input_ids = enc.input_ids
|
| 35 |
attention_mask = enc.attention_mask
|
| 36 |
|
| 37 |
-
if input_ids.shape[1] >
|
| 38 |
-
input_ids = input_ids[:, -
|
| 39 |
|
| 40 |
input_ids = input_ids.to(device)
|
| 41 |
attention_mask = attention_mask.to(device)
|
|
|
|
| 11 |
import torch
|
| 12 |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
| 13 |
|
| 14 |
+
MAX_LENGTH = 4096
|
| 15 |
DEFAULT_MAX_NEW_TOKENS = 1024
|
| 16 |
|
| 17 |
|
|
|
|
| 34 |
input_ids = enc.input_ids
|
| 35 |
attention_mask = enc.attention_mask
|
| 36 |
|
| 37 |
+
if input_ids.shape[1] > MAX_LENGTH:
|
| 38 |
+
input_ids = input_ids[:, -MAX_LENGTH:]
|
| 39 |
|
| 40 |
input_ids = input_ids.to(device)
|
| 41 |
attention_mask = attention_mask.to(device)
|