Spaces:
Running
on
Zero
Running
on
Zero
AbstractPhil
committed on
Commit
·
23ca4d8
1
Parent(s):
53d9a8e
yes
Browse files
app.py
CHANGED
@@ -189,7 +189,15 @@ def create_harmony_prompt(messages: List[Dict[str, str]], reasoning_effort: str
|
|
189 |
)
|
190 |
|
191 |
convo = Conversation.from_messages(harmony_messages)
|
192 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
193 |
|
194 |
# Fallback: tokenizer chat template -> string prompt
|
195 |
if not messages or messages[0].get("role") != "system":
|
@@ -341,6 +349,18 @@ def zerogpu_generate(full_prompt,
|
|
341 |
sc = StoppingCriteriaList([StopOnTokens(HARMONY_STOP_IDS)])
|
342 |
|
343 |
# Generate
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
344 |
out_ids = model.generate(
|
345 |
**inputs,
|
346 |
do_sample=bool(gen_kwargs.get("do_sample", True)),
|
@@ -349,6 +369,8 @@ def zerogpu_generate(full_prompt,
|
|
349 |
top_k=(int(gen_kwargs.get("top_k")) if gen_kwargs.get("top_k") and int(gen_kwargs.get("top_k")) > 0 else None),
|
350 |
max_new_tokens=int(gen_kwargs.get("max_new_tokens", MAX_DEF)),
|
351 |
pad_token_id=model.config.pad_token_id,
|
|
|
|
|
352 |
logits_processor=logits_processor,
|
353 |
repetition_penalty=float(gen_kwargs.get("repetition_penalty", 1.2)),
|
354 |
no_repeat_ngram_size=int(gen_kwargs.get("no_repeat_ngram_size", 8)),
|
@@ -437,6 +459,8 @@ def zerogpu_generate_debug(full_prompt, gen_kwargs: Dict[str, Any]) -> Dict[str,
|
|
437 |
top_k=(int(gen_kwargs.get("top_k")) if gen_kwargs.get("top_k") and int(gen_kwargs.get("top_k")) > 0 else None),
|
438 |
max_new_tokens=int(gen_kwargs.get("max_new_tokens", MAX_DEF)),
|
439 |
pad_token_id=model.config.pad_token_id,
|
|
|
|
|
440 |
stopping_criteria=sc,
|
441 |
repetition_penalty=float(gen_kwargs.get("repetition_penalty", 1.15)),
|
442 |
no_repeat_ngram_size=int(gen_kwargs.get("no_repeat_ngram_size", 6)),
|
|
|
189 |
)
|
190 |
|
191 |
convo = Conversation.from_messages(harmony_messages)
|
192 |
+
rendered = harmony_encoding.render_conversation_for_completion(convo, Role.ASSISTANT)
|
193 |
+
# Ensure assistant header includes a final channel + message start to avoid 'assistantassistant...' loops
|
194 |
+
try:
|
195 |
+
_tail = tokenizer.decode(list(rendered)[-64:], skip_special_tokens=False)
|
196 |
+
if '<|channel|>final<|message|>' not in _tail:
|
197 |
+
rendered = list(rendered) + tokenizer.encode('<|channel|>final<|message|>', add_special_tokens=False)
|
198 |
+
except Exception:
|
199 |
+
rendered = list(rendered)
|
200 |
+
return rendered
|
201 |
|
202 |
# Fallback: tokenizer chat template -> string prompt
|
203 |
if not messages or messages[0].get("role") != "system":
|
|
|
349 |
sc = StoppingCriteriaList([StopOnTokens(HARMONY_STOP_IDS)])
|
350 |
|
351 |
# Generate
|
352 |
+
# Disallow degenerate header loops
|
353 |
+
bad_words_ids = None
|
354 |
+
try:
|
355 |
+
_B = []
|
356 |
+
for s in ("assistantassistant", "assistant", "<|assistant|>"):
|
357 |
+
ids = tokenizer.encode(s, add_special_tokens=False)
|
358 |
+
if ids:
|
359 |
+
_B.append(ids)
|
360 |
+
bad_words_ids = _B if _B else None
|
361 |
+
except Exception:
|
362 |
+
pass
|
363 |
+
|
364 |
out_ids = model.generate(
|
365 |
**inputs,
|
366 |
do_sample=bool(gen_kwargs.get("do_sample", True)),
|
|
|
369 |
top_k=(int(gen_kwargs.get("top_k")) if gen_kwargs.get("top_k") and int(gen_kwargs.get("top_k")) > 0 else None),
|
370 |
max_new_tokens=int(gen_kwargs.get("max_new_tokens", MAX_DEF)),
|
371 |
pad_token_id=model.config.pad_token_id,
|
372 |
+
eos_token_id=tokenizer.eos_token_id,
|
373 |
+
bad_words_ids=bad_words_ids,
|
374 |
logits_processor=logits_processor,
|
375 |
repetition_penalty=float(gen_kwargs.get("repetition_penalty", 1.2)),
|
376 |
no_repeat_ngram_size=int(gen_kwargs.get("no_repeat_ngram_size", 8)),
|
|
|
459 |
top_k=(int(gen_kwargs.get("top_k")) if gen_kwargs.get("top_k") and int(gen_kwargs.get("top_k")) > 0 else None),
|
460 |
max_new_tokens=int(gen_kwargs.get("max_new_tokens", MAX_DEF)),
|
461 |
pad_token_id=model.config.pad_token_id,
|
462 |
+
eos_token_id=tokenizer.eos_token_id,
|
463 |
+
bad_words_ids=bad_words_ids,
|
464 |
stopping_criteria=sc,
|
465 |
repetition_penalty=float(gen_kwargs.get("repetition_penalty", 1.15)),
|
466 |
no_repeat_ngram_size=int(gen_kwargs.get("no_repeat_ngram_size", 6)),
|