prithivMLmods committed
Commit 68abcc8 · verified · 1 Parent(s): 4d0dad8

Update app.py

Files changed (1):
  1. app.py +10 -2
app.py CHANGED
@@ -137,15 +137,23 @@ def generate(
             yield buffer
     else:
         # Text-only input
-        conversation = chat_history.copy()
+        # Ensure the chat history alternates between user and assistant roles
+        conversation = []
+        for i, entry in enumerate(chat_history):
+            if i % 2 == 0:
+                conversation.append({"role": "user", "content": entry["content"]})
+            else:
+                conversation.append({"role": "assistant", "content": entry["content"]})
         conversation.append({"role": "user", "content": message})
 
+        # Apply the chat template
         input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
         if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
             input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
             gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
         input_ids = input_ids.to(model.device)
 
+        # Stream the output
         streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
         generate_kwargs = dict(
             {"input_ids": input_ids},
@@ -223,4 +231,4 @@ demo = gr.ChatInterface(
 )
 
 if __name__ == "__main__":
-    demo.queue(max_size=20).launch()
+    demo.queue(max_size=20).launch(share=True)
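The core of this commit is the history-normalization loop: instead of copying chat_history verbatim, the role of each entry is rebuilt from its position. A minimal standalone sketch of the same idea, assuming chat_history arrives as a flat list of {"content": ...} dicts in strict user-then-assistant order (the function name and sample data are hypothetical):

    # Hypothetical standalone version of the loop added in this commit.
    # Assumes chat_history alternates user/assistant purely by position.
    def normalize_history(chat_history):
        conversation = []
        for i, entry in enumerate(chat_history):
            role = "user" if i % 2 == 0 else "assistant"
            conversation.append({"role": role, "content": entry["content"]})
        return conversation

    history = [{"content": "Hi"}, {"content": "Hello! How can I help?"}]
    print(normalize_history(history))
    # [{'role': 'user', 'content': 'Hi'},
    #  {'role': 'assistant', 'content': 'Hello! How can I help?'}]

Deriving the role from the index rather than trusting the incoming entries presumably guards against histories whose role fields are missing or fail to alternate, which many chat templates reject with an error.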
 
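Downstream of the new loop, the unchanged context lines apply the tokenizer's chat template and left-trim the result to MAX_INPUT_TOKEN_LENGTH. A sketch of that step in isolation, assuming a Hugging Face tokenizer that ships a chat template; the model id and the 4096 limit are illustrative, not taken from this app:

    # Illustrative trimming step: keep only the newest tokens so the most
    # recent turns survive; the oldest context is dropped silently.
    from transformers import AutoTokenizer

    MAX_INPUT_TOKEN_LENGTH = 4096  # assumed value; defined elsewhere in app.py
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

    conversation = [{"role": "user", "content": "Hello there"}]
    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    )
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Slice from the right: keeps the generation prompt and latest turn.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]

Trimming from the left keeps the newest turns intact at the cost of older context, which is why the app raises gr.Warning when it happens.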
 
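The second hunk's only change is launch(share=True), which asks Gradio to open a temporary public *.gradio.live tunnel in addition to the local server; queue(max_size=20) continues to cap pending requests at 20. On Hugging Face Spaces the share flag is typically ignored, since the Space already serves a public URL, so it mainly matters when the app is run locally.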