Spaces:
Runtime error
Runtime error
Update main.py
Browse files
main.py
CHANGED
@@ -54,8 +54,10 @@ async def generate_response_chunks(prompt: str):
|
|
54 |
else:
|
55 |
input_ids = new_input_ids
|
56 |
|
|
|
57 |
output_ids = model.generate(
|
58 |
input_ids,
|
|
|
59 |
max_new_tokens=200,
|
60 |
do_sample=True,
|
61 |
top_p=0.9,
|
@@ -63,6 +65,7 @@ async def generate_response_chunks(prompt: str):
|
|
63 |
pad_token_id=tokenizer.eos_token_id
|
64 |
)
|
65 |
|
|
|
66 |
chat_history_ids = output_ids # update history
|
67 |
|
68 |
response = tokenizer.decode(output_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
|
|
|
54 |
else:
|
55 |
input_ids = new_input_ids
|
56 |
|
57 |
+
attention_mask = input_ids.ne(tokenizer.pad_token_id).long()
|
58 |
output_ids = model.generate(
|
59 |
input_ids,
|
60 |
+
attention_mask=attention_mask,
|
61 |
max_new_tokens=200,
|
62 |
do_sample=True,
|
63 |
top_p=0.9,
|
|
|
65 |
pad_token_id=tokenizer.eos_token_id
|
66 |
)
|
67 |
|
68 |
+
|
69 |
chat_history_ids = output_ids # update history
|
70 |
|
71 |
response = tokenizer.decode(output_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
|