prithivMLmods commited on
Commit
45359c2
Β·
verified Β·
1 Parent(s): 997fdad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -3
app.py CHANGED
@@ -55,15 +55,17 @@ def generate(
55
  conversation = chat_history.copy()
56
  conversation.append({"role": "user", "content": message})
57
 
 
58
  input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
59
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
60
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
61
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
62
  input_ids = input_ids.to(model.device)
63
 
 
64
  streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
65
  generate_kwargs = dict(
66
- {"input_ids": input_ids},
67
  streamer=streamer,
68
  max_new_tokens=max_new_tokens,
69
  do_sample=True,
@@ -73,14 +75,20 @@ def generate(
73
  num_beams=1,
74
  repetition_penalty=repetition_penalty,
75
  )
 
 
76
  t = Thread(target=model.generate, kwargs=generate_kwargs)
77
  t.start()
78
 
 
79
  outputs = []
80
  for text in streamer:
81
  outputs.append(text)
82
  yield "".join(outputs)
83
 
 
 
 
84
  # Generate the knowledge graph HTML file
85
  generate_knowledge_graph()
86
 
@@ -129,7 +137,8 @@ demo = gr.ChatInterface(
129
  ["Write a Python function to reverses a string if it's length is a multiple of 4. def reverse_string(str1): if len(str1) % 4 == 0: return ''.join(reversed(str1)) return str1 print(reverse_string('abcd')) print(reverse_string('python')) "],
130
  ["Rectangle $ABCD$ is the base of pyramid $PABCD$. If $AB = 10$, $BC = 5$, $\overline{PA}\perp \text{plane } ABCD$, and $PA = 8$, then what is the volume of $PABCD$?"],
131
  ["Difference between List comprehension and Lambda in Python lst = [x ** 2 for x in range (1, 11) if x % 2 == 1] print(lst)"],
132
- ["How Many R's in the Word 'STRAWBERRY' ?"],
 
133
  ],
134
  cache_examples=False,
135
  type="messages",
@@ -139,4 +148,4 @@ demo = gr.ChatInterface(
139
  )
140
 
141
  if __name__ == "__main__":
142
- demo.queue(max_size=20).launch()
 
55
  conversation = chat_history.copy()
56
  conversation.append({"role": "user", "content": message})
57
 
58
+ # Tokenize the input
59
  input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
60
  if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
61
  input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
62
  gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
63
  input_ids = input_ids.to(model.device)
64
 
65
+ # Set up the streamer
66
  streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
67
  generate_kwargs = dict(
68
+ input_ids=input_ids,
69
  streamer=streamer,
70
  max_new_tokens=max_new_tokens,
71
  do_sample=True,
 
75
  num_beams=1,
76
  repetition_penalty=repetition_penalty,
77
  )
78
+
79
+ # Start the generation in a separate thread
80
  t = Thread(target=model.generate, kwargs=generate_kwargs)
81
  t.start()
82
 
83
+ # Stream the output
84
  outputs = []
85
  for text in streamer:
86
  outputs.append(text)
87
  yield "".join(outputs)
88
 
89
+ # Ensure the thread is joined after completion
90
+ t.join()
91
+
92
  # Generate the knowledge graph HTML file
93
  generate_knowledge_graph()
94
 
 
137
  ["Write a Python function to reverses a string if it's length is a multiple of 4. def reverse_string(str1): if len(str1) % 4 == 0: return ''.join(reversed(str1)) return str1 print(reverse_string('abcd')) print(reverse_string('python')) "],
138
  ["Rectangle $ABCD$ is the base of pyramid $PABCD$. If $AB = 10$, $BC = 5$, $\overline{PA}\perp \text{plane } ABCD$, and $PA = 8$, then what is the volume of $PABCD$?"],
139
  ["Difference between List comprehension and Lambda in Python lst = [x ** 2 for x in range (1, 11) if x % 2 == 1] print(lst)"],
140
+ ["How many hours does it take a man to eat a Helicopter?"],
141
+ ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
142
  ],
143
  cache_examples=False,
144
  type="messages",
 
148
  )
149
 
150
  if __name__ == "__main__":
151
+ demo.queue(max_size=20).launch(share=True) # Set share=True for a public link