AlexHung29629 committed (verified) · Commit 9d50662 · Parent(s): 7f16707

Update app.py

Files changed (1):
  app.py (+5 -4)
app.py CHANGED
@@ -145,19 +145,20 @@ def generate(
     chat_template: str,
     max_new_tokens: int = 1024,
     temperature: float = 0.6,
-    top_p: float = 0.9,
+    top_p: float = 0.95,
     top_k: int = 50,
-    repetition_penalty: float = 1.2,
+    repetition_penalty: float = 1.0,
 ) -> Iterator[str]:
     conversation = [*chat_history, {"role": "user", "content": message}]
 
-    input_ids = tokenizer.apply_chat_template(conversation, chat_template=chat_template, enable_thinking=False, return_tensors="pt")
+    #input_ids = tokenizer.apply_chat_template(conversation, chat_template=chat_template, enable_thinking=False, return_tensors="pt")
+    input_ids = tokenizer.apply_chat_template(conversation, chat_template=chat_template, return_tensors="pt")
     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
     input_ids = input_ids.to(model.device)
 
-    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=False)
+    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=False)
     generate_kwargs = dict(
         {"input_ids": input_ids},
         streamer=streamer,
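
For context, this hunk sits inside app.py's streaming generate() helper. The sketch below shows how the TextIteratorStreamer and the generate_kwargs dict touched here are typically consumed in a Transformers-based Gradio demo: generation runs on a background thread while the streamer is iterated to yield partial text. It is a minimal illustration, assuming the global model and tokenizer that app.py defines; stream_generate is a hypothetical stand-in for the surrounding function, not the repo's exact code.

# Minimal sketch, assuming `model` and `tokenizer` are already loaded as in app.py.
# `stream_generate` is a hypothetical helper name used only for this illustration.
from threading import Thread

from transformers import TextIteratorStreamer


def stream_generate(model, tokenizer, input_ids,
                    max_new_tokens=1024, temperature=0.6,
                    top_p=0.95, top_k=50, repetition_penalty=1.0):
    # The streamer yields decoded text incrementally as model.generate() produces tokens.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=False
    )
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
    )
    # Run generation on a background thread so this function can iterate the
    # streamer and yield partial output to the Gradio UI as it arrives.
    Thread(target=model.generate, kwargs=generate_kwargs).start()
    for new_text in streamer:
        yield new_text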