雷娃 commited on
Commit
d287896
·
1 Parent(s): 2493f19

add interactive mode

Browse files
Files changed (1) hide show
  1. app.py +2 -16
app.py CHANGED
@@ -30,31 +30,17 @@ def chat(user_input, max_new_tokens=512):
30
  #create streamer
31
  streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
32
 
33
- def get_start_idx(response, input):
34
- match = re.search(re.escape(response), input)
35
- if not match:
36
- return -1
37
- return match.end()
38
-
39
  def generate():
40
  model.generate(**inputs, max_new_tokens=max_new_tokens, streamer=streamer)
41
 
42
  thread = Thread(target=generate)
43
  thread.start()
44
 
45
- start_idx = -1
46
  generated_text = ""
47
  for new_text in streamer:
48
  generated_text += new_text
49
-
50
- if (start_idx == -1):
51
- start_idx = get_start_idx(generated_text, user_input)
52
- if (start_idx != -1):
53
- start_idx += len("ASSISTANT")
54
- #print(generated_text)
55
- #yield generated_text
56
- if (start_idx > 0):
57
- yield generated_text[start_idx:]
58
 
59
  thread.join()
60
 
 
30
  #create streamer
31
  streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
32
 
 
 
 
 
 
 
33
  def generate():
34
  model.generate(**inputs, max_new_tokens=max_new_tokens, streamer=streamer)
35
 
36
  thread = Thread(target=generate)
37
  thread.start()
38
 
39
+ start_idx = len(inputs)
40
  generated_text = ""
41
  for new_text in streamer:
42
  generated_text += new_text
43
+ yield generated_text[start_idx:]
 
 
 
 
 
 
 
 
44
 
45
  thread.join()
46