jmercat commited on
Commit
3bbd70b
·
1 Parent(s): 2aa2690

stop on stop token

Browse files
Files changed (1) hide show
  1. app.py +5 -1
app.py CHANGED
@@ -50,7 +50,8 @@ def generate(
50
  pad_token_id=current_tokenizer.eos_token_id
51
  )
52
 
53
- streamer = TextIteratorStreamer(current_tokenizer, skip_prompt=True, skip_special_tokens=True)
 
54
  generate_kwargs["streamer"] = streamer
55
 
56
  thread = Thread(target=current_model.generate, kwargs=generate_kwargs)
@@ -61,6 +62,9 @@ def generate(
61
  for new_text in streamer:
62
  if isinstance(new_text, torch.Tensor):
63
  new_text = current_tokenizer.decode(new_text)
 
 
 
64
  output += new_text
65
  yield output
66
 
 
50
  pad_token_id=current_tokenizer.eos_token_id
51
  )
52
 
53
+ streamer = TextIteratorStreamer(current_tokenizer, skip_prompt=True, skip_special_tokens=False)
54
+ streamer.stop_signal = current_tokenizer.decode(current_tokenizer.eos_token_id)
55
  generate_kwargs["streamer"] = streamer
56
 
57
  thread = Thread(target=current_model.generate, kwargs=generate_kwargs)
 
62
  for new_text in streamer:
63
  if isinstance(new_text, torch.Tensor):
64
  new_text = current_tokenizer.decode(new_text)
65
+ if streamer.stop_signal in new_text:
66
+ output += new_text.split(streamer.stop_signal)[0]
67
+ break
68
  output += new_text
69
  yield output
70