AbstractPhil committed on
Commit
23ca4d8
·
1 Parent(s): 53d9a8e
Files changed (1) hide show
  1. app.py +25 -1
app.py CHANGED
@@ -189,7 +189,15 @@ def create_harmony_prompt(messages: List[Dict[str, str]], reasoning_effort: str
189
  )
190
 
191
  convo = Conversation.from_messages(harmony_messages)
192
- return harmony_encoding.render_conversation_for_completion(convo, Role.ASSISTANT)
 
 
 
 
 
 
 
 
193
 
194
  # Fallback: tokenizer chat template -> string prompt
195
  if not messages or messages[0].get("role") != "system":
@@ -341,6 +349,18 @@ def zerogpu_generate(full_prompt,
341
  sc = StoppingCriteriaList([StopOnTokens(HARMONY_STOP_IDS)])
342
 
343
  # Generate
 
 
 
 
 
 
 
 
 
 
 
 
344
  out_ids = model.generate(
345
  **inputs,
346
  do_sample=bool(gen_kwargs.get("do_sample", True)),
@@ -349,6 +369,8 @@ def zerogpu_generate(full_prompt,
349
  top_k=(int(gen_kwargs.get("top_k")) if gen_kwargs.get("top_k") and int(gen_kwargs.get("top_k")) > 0 else None),
350
  max_new_tokens=int(gen_kwargs.get("max_new_tokens", MAX_DEF)),
351
  pad_token_id=model.config.pad_token_id,
 
 
352
  logits_processor=logits_processor,
353
  repetition_penalty=float(gen_kwargs.get("repetition_penalty", 1.2)),
354
  no_repeat_ngram_size=int(gen_kwargs.get("no_repeat_ngram_size", 8)),
@@ -437,6 +459,8 @@ def zerogpu_generate_debug(full_prompt, gen_kwargs: Dict[str, Any]) -> Dict[str,
437
  top_k=(int(gen_kwargs.get("top_k")) if gen_kwargs.get("top_k") and int(gen_kwargs.get("top_k")) > 0 else None),
438
  max_new_tokens=int(gen_kwargs.get("max_new_tokens", MAX_DEF)),
439
  pad_token_id=model.config.pad_token_id,
 
 
440
  stopping_criteria=sc,
441
  repetition_penalty=float(gen_kwargs.get("repetition_penalty", 1.15)),
442
  no_repeat_ngram_size=int(gen_kwargs.get("no_repeat_ngram_size", 6)),
 
189
  )
190
 
191
  convo = Conversation.from_messages(harmony_messages)
192
+ rendered = harmony_encoding.render_conversation_for_completion(convo, Role.ASSISTANT)
193
+ # Ensure assistant header includes a final channel + message start to avoid 'assistantassistant...' loops
194
+ try:
195
+ _tail = tokenizer.decode(list(rendered)[-64:], skip_special_tokens=False)
196
+ if '<|channel|>final<|message|>' not in _tail:
197
+ rendered = list(rendered) + tokenizer.encode('<|channel|>final<|message|>', add_special_tokens=False)
198
+ except Exception:
199
+ rendered = list(rendered)
200
+ return rendered
201
 
202
  # Fallback: tokenizer chat template -> string prompt
203
  if not messages or messages[0].get("role") != "system":
 
349
  sc = StoppingCriteriaList([StopOnTokens(HARMONY_STOP_IDS)])
350
 
351
  # Generate
352
+ # Disallow degenerate header loops
353
+ bad_words_ids = None
354
+ try:
355
+ _B = []
356
+ for s in ("assistantassistant", "assistant", "<|assistant|>"):
357
+ ids = tokenizer.encode(s, add_special_tokens=False)
358
+ if ids:
359
+ _B.append(ids)
360
+ bad_words_ids = _B if _B else None
361
+ except Exception:
362
+ pass
363
+
364
  out_ids = model.generate(
365
  **inputs,
366
  do_sample=bool(gen_kwargs.get("do_sample", True)),
 
369
  top_k=(int(gen_kwargs.get("top_k")) if gen_kwargs.get("top_k") and int(gen_kwargs.get("top_k")) > 0 else None),
370
  max_new_tokens=int(gen_kwargs.get("max_new_tokens", MAX_DEF)),
371
  pad_token_id=model.config.pad_token_id,
372
+ eos_token_id=tokenizer.eos_token_id,
373
+ bad_words_ids=bad_words_ids,
374
  logits_processor=logits_processor,
375
  repetition_penalty=float(gen_kwargs.get("repetition_penalty", 1.2)),
376
  no_repeat_ngram_size=int(gen_kwargs.get("no_repeat_ngram_size", 8)),
 
459
  top_k=(int(gen_kwargs.get("top_k")) if gen_kwargs.get("top_k") and int(gen_kwargs.get("top_k")) > 0 else None),
460
  max_new_tokens=int(gen_kwargs.get("max_new_tokens", MAX_DEF)),
461
  pad_token_id=model.config.pad_token_id,
462
+ eos_token_id=tokenizer.eos_token_id,
463
+ bad_words_ids=bad_words_ids,
464
  stopping_criteria=sc,
465
  repetition_penalty=float(gen_kwargs.get("repetition_penalty", 1.15)),
466
  no_repeat_ngram_size=int(gen_kwargs.get("no_repeat_ngram_size", 6)),