AlexHung29629 committed (verified)
Commit 019333f · 1 Parent(s): c420ecb

Update app.py

Files changed (1): app.py (+3, -1)
app.py CHANGED

@@ -53,6 +53,7 @@ if torch.cuda.is_available():
 def generate(
     message: str,
     chat_history: list[dict],
+    chat_template: str,
     max_new_tokens: int = 1024,
     temperature: float = 0.6,
     top_p: float = 0.9,
@@ -61,7 +62,7 @@ def generate(
 ) -> Iterator[str]:
     conversation = [*chat_history, {"role": "user", "content": message}]
 
-    input_ids = tokenizer.apply_chat_template(conversation, chat_template=CHAT_TEMPLATE, return_tensors="pt")
+    input_ids = tokenizer.apply_chat_template(conversation, chat_template=chat_template, return_tensors="pt")
     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
@@ -91,6 +92,7 @@ def generate(
 demo = gr.ChatInterface(
     fn=generate,
     additional_inputs=[
+        gr.Textbox(placeholder=CHAT_TEMPLATE, label="Chat template"),
         gr.Slider(
             label="Max new tokens",
             minimum=1,
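
For context, a minimal runnable sketch of the pattern this commit introduces, with two assumptions worth flagging: gr.ChatInterface appends each widget in additional_inputs, in order, as an extra positional argument to fn (so the Textbox value lands in the new chat_template parameter), and transformers' apply_chat_template accepts a chat_template= string that overrides the template stored on the tokenizer. The template string, the gpt2 tokenizer, and the MAX_INPUT_TOKEN_LENGTH value below are illustrative stand-ins, not the Space's actual definitions.

# Sketch only -- not the Space's full app.py.
from collections.abc import Iterator

import gradio as gr
from transformers import AutoTokenizer

# Stand-in default template; the real CHAT_TEMPLATE lives elsewhere in app.py.
CHAT_TEMPLATE = (
    "{% for m in messages %}<|{{ m['role'] }}|>{{ m['content'] }}\n{% endfor %}"
    "<|assistant|>"
)
MAX_INPUT_TOKEN_LENGTH = 4096  # illustrative value

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # illustrative tokenizer


def generate(
    message: str,
    chat_history: list[dict],
    chat_template: str,
    max_new_tokens: int = 1024,
) -> Iterator[str]:
    conversation = [*chat_history, {"role": "user", "content": message}]

    # An untouched Textbox submits "", so this sketch falls back to the default.
    input_ids = tokenizer.apply_chat_template(
        conversation, chat_template=chat_template or CHAT_TEMPLATE, return_tensors="pt"
    )
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")

    # The real app streams model output here; this stub just reports the prompt size.
    yield f"Prompt is {input_ids.shape[1]} tokens under the supplied template."


demo = gr.ChatInterface(
    fn=generate,
    type="messages",  # history as list[dict], matching chat_history's annotation
    additional_inputs=[
        gr.Textbox(placeholder=CHAT_TEMPLATE, label="Chat template"),
        gr.Slider(label="Max new tokens", minimum=1, maximum=2048, step=1, value=1024),
    ],
)

if __name__ == "__main__":
    demo.launch()

One caveat the sketch hedges around: a Textbox placeholder only displays text, so an untouched box submits an empty string; the `or CHAT_TEMPLATE` fallback above guards that case, whereas the commit as written passes the Textbox value straight through to apply_chat_template.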