Update app.py
app.py CHANGED
@@ -53,6 +53,7 @@ if torch.cuda.is_available():
 def generate(
     message: str,
     chat_history: list[dict],
+    chat_template: str,
     max_new_tokens: int = 1024,
     temperature: float = 0.6,
     top_p: float = 0.9,
@@ -61,7 +62,7 @@ def generate(
 ) -> Iterator[str]:
     conversation = [*chat_history, {"role": "user", "content": message}]
 
-    input_ids = tokenizer.apply_chat_template(conversation, chat_template=CHAT_TEMPLATE, return_tensors="pt")
+    input_ids = tokenizer.apply_chat_template(conversation, chat_template=chat_template, return_tensors="pt")
     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
@@ -91,6 +92,7 @@ def generate(
 demo = gr.ChatInterface(
     fn=generate,
     additional_inputs=[
+        gr.Textbox(placeholder=CHAT_TEMPLATE, label="Chat template"),
         gr.Slider(
             label="Max new tokens",
             minimum=1,
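For context, here is a minimal sketch of the pattern this commit implements, assuming the surrounding app follows the usual Gradio/Transformers layout (the `generate`, `tokenizer`, `CHAT_TEMPLATE`, and `MAX_INPUT_TOKEN_LENGTH` names come from the diff; the model choice, template contents, and fallback logic are illustrative assumptions). `gr.ChatInterface` passes each component in `additional_inputs`, in order, as an extra positional argument to `fn` after the message and history, which is why the new `gr.Textbox` is listed first in `additional_inputs` and `chat_template: str` is declared immediately after `chat_history`. The template string is then forwarded to `tokenizer.apply_chat_template`, whose `chat_template` keyword accepts a Jinja template to use in place of the tokenizer's built-in one.

```python
# Minimal sketch of the commit's pattern, not the full app: a user-supplied
# Jinja chat template flows from a gr.Textbox, through gr.ChatInterface's
# additional_inputs, into tokenizer.apply_chat_template.
from collections.abc import Iterator

import gradio as gr
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in model (assumption)

# Illustrative default template; a real app would ship the model's own.
CHAT_TEMPLATE = (
    "{% for message in messages %}"
    "{{ message['role'] }}: {{ message['content'] }}\n"
    "{% endfor %}"
)

MAX_INPUT_TOKEN_LENGTH = 4096


def generate(
    message: str,
    chat_history: list[dict],
    chat_template: str,  # filled from the gr.Textbox below
) -> Iterator[str]:
    conversation = [*chat_history, {"role": "user", "content": message}]
    # Passing chat_template overrides the template bundled with the tokenizer.
    # The `or CHAT_TEMPLATE` fallback is this sketch's addition: an empty
    # textbox submits "", which would otherwise be used as the template.
    input_ids = tokenizer.apply_chat_template(
        conversation,
        chat_template=chat_template or CHAT_TEMPLATE,
        return_tensors="pt",
    )
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    # Real token-by-token generation elided; report the prompt size instead.
    yield f"Prompt is {input_ids.shape[1]} tokens long."


demo = gr.ChatInterface(
    fn=generate,
    type="messages",  # history arrives as a list of role/content dicts
    additional_inputs=[
        # Extra inputs are passed to fn positionally, after message and
        # history, so this textbox maps onto the chat_template parameter.
        gr.Textbox(placeholder=CHAT_TEMPLATE, label="Chat template"),
    ],
)

if __name__ == "__main__":
    demo.launch()
```

One detail worth noting: `placeholder` only renders hint text and is never submitted as the value, so an untouched textbox yields an empty string; the sketch falls back to `CHAT_TEMPLATE` in that case so `apply_chat_template` always receives a usable template.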
|