Update app.py
app.py CHANGED
@@ -137,15 +137,23 @@ def generate(
             yield buffer
     else:
         # Text-only input
-
+        # Ensure the chat history alternates between user and assistant roles
+        conversation = []
+        for i, entry in enumerate(chat_history):
+            if i % 2 == 0:
+                conversation.append({"role": "user", "content": entry["content"]})
+            else:
+                conversation.append({"role": "assistant", "content": entry["content"]})
         conversation.append({"role": "user", "content": message})
 
+        # Apply the chat template
         input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
         if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
             input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
             gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
         input_ids = input_ids.to(model.device)
 
+        # Stream the output
         streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
         generate_kwargs = dict(
             {"input_ids": input_ids},
@@ -223,4 +231,4 @@ demo = gr.ChatInterface(
 )
 
 if __name__ == "__main__":
-    demo.queue(max_size=20).launch()
+    demo.queue(max_size=20).launch(share=True)
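For context, the sketch below isolates the role-alternation logic this commit adds. It is illustrative only: build_conversation is a hypothetical helper name, and it assumes chat_history arrives as an ordered list of message dicts, user turn first, each exposing a "content" key, as in Gradio's messages-style history.

# Hypothetical helper mirroring the logic added in this commit (not the Space's own code).
# Assumes chat_history alternates user/assistant and each entry carries a "content" key.
def build_conversation(message: str, chat_history: list[dict]) -> list[dict]:
    conversation = []
    for i, entry in enumerate(chat_history):
        role = "user" if i % 2 == 0 else "assistant"
        conversation.append({"role": role, "content": entry["content"]})
    # Finally append the new user message, as the diff does.
    conversation.append({"role": "user", "content": message})
    return conversation


# Example usage with two prior turns:
history = [{"content": "Hi there"}, {"content": "Hello! How can I help?"}]
print(build_conversation("Tell me a joke", history))

With this shape, apply_chat_template receives a strictly alternating user/assistant sequence, which some chat templates enforce.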
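The templating and truncation step in the unchanged context lines is standard transformers usage. The sketch below is a stand-alone illustration with a placeholder model id and a hard-coded MAX_INPUT_TOKEN_LENGTH in place of the constant the app defines elsewhere; it also omits the gr.Warning the app shows when trimming.

# Illustrative only: placeholder model id and token limit, not the Space's configuration.
from transformers import AutoTokenizer

MAX_INPUT_TOKEN_LENGTH = 4096  # stand-in for the app's constant
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")  # any model with a chat template

conversation = [{"role": "user", "content": "Tell me a joke"}]
input_ids = tokenizer.apply_chat_template(
    conversation, add_generation_prompt=True, return_tensors="pt"
)

# Keep only the most recent tokens when the prompt is too long, as in the diff.
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
    input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]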
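The streamer and generate_kwargs context lines depend on transformers' TextIteratorStreamer, which only yields text while model.generate runs concurrently in another thread. A hedged sketch of that pattern follows; the model id and generation settings are placeholders, not the ones pinned by this Space.

# Sketch of the streaming pattern the surrounding code uses (placeholder model id).
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "HuggingFaceH4/zephyr-7b-beta"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

input_ids = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Tell me a joke"}],
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)

streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = dict(input_ids=input_ids, streamer=streamer, max_new_tokens=128, do_sample=False)

# Run generation in the background and consume decoded text as it arrives.
thread = Thread(target=model.generate, kwargs=generate_kwargs)
thread.start()

buffer = ""
for text in streamer:
    buffer += text
    print(buffer)  # the app yields buffer to Gradio instead of printing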
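The second hunk only changes the launch call: share=True asks Gradio to create a temporary public *.gradio.live link in addition to the local server when run outside of Spaces. A minimal stand-alone sketch of the queue/launch pattern, with a hypothetical echo() handler standing in for the Space's generate() generator:

# Minimal sketch; echo() is a placeholder handler, not the Space's real one.
import gradio as gr


def echo(message, history):
    return f"You said: {message}"


demo = gr.ChatInterface(fn=echo)

if __name__ == "__main__":
    # max_size bounds how many requests may wait in the queue; share=True adds
    # a temporary public link when running locally, which is what this commit enables.
    demo.queue(max_size=20).launch(share=True)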