AbstractPhil committed
Commit 51a55c1
1 Parent(s): 7f8b6c0
Files changed (1): app.py (+46 -26)

app.py CHANGED
@@ -150,11 +150,17 @@ def _load_model_on(device_map: Optional[str]) -> AutoModelForCausalLM:
 # Harmony formatting
 # -----------------------
 
-def create_harmony_prompt(messages: List[Dict[str, str]], reasoning_effort: str = "high") -> Any:
-    """Create a proper Harmony-formatted prompt using openai_harmony."""
-    if not HARMONY_AVAILABLE:
-        # Fallback to tokenizer's chat template
-        return tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+def create_harmony_prompt(messages: List[Dict[str, str]], reasoning_effort: str = "high") -> str:
+    """Build Harmony-formatted prompt using the *tokenizer chat template* (per model card).
+    Always returns a string; HF will tokenize to ensure IDs match the checkpoint.
+    """
+    if not messages or messages[0].get("role") != "system":
+        messages = [{"role": "system", "content": SYSTEM_DEF}] + (messages or [])
+    return tokenizer.apply_chat_template(
+        messages,
+        add_generation_prompt=True,
+        tokenize=False
+    )
 
     # Map reasoning effort
     effort_map = {
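For context, a minimal sketch of what the rewritten create_harmony_prompt now relies on: the checkpoint's own chat template instead of the openai_harmony library, with a system message prepended when the caller omits one. The checkpoint id and SYSTEM_DEF value below are illustrative placeholders, not taken from app.py:

    from transformers import AutoTokenizer

    SYSTEM_DEF = "You are a helpful assistant."                       # placeholder default system prompt
    tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")   # illustrative Harmony-format checkpoint

    messages = [{"role": "user", "content": "Hi!"}]
    # The new code prepends the system message itself, then defers to the template.
    prompt = tokenizer.apply_chat_template(
        [{"role": "system", "content": SYSTEM_DEF}] + messages,
        add_generation_prompt=True,
        tokenize=False,
    )
    print(prompt[:80])  # Harmony-style string starting with "<|start|>system<|message|>..."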
@@ -223,24 +229,40 @@ def parse_harmony_response(tokens: List[int]) -> Dict[str, str]:
     return channels
 
 def extract_final_channel_fallback(text: str) -> str:
-    """Fallback extraction when harmony library isn't available."""
-    # Look for the final channel marker
-    final_marker = "<|channel|>final<|message|>"
-
-    if final_marker in text:
-        parts = text.split(final_marker)
-        if len(parts) > 1:
-            final_text = parts[-1]
-
-            # Clean up end markers
-            end_markers = ["<|return|>", "<|end|>", "<|endoftext|>"]
-            for marker in end_markers:
-                if marker in final_text:
-                    final_text = final_text.split(marker)[0]
-
-            return final_text.strip()
-
-    # If no channel markers found, return cleaned text
+    """Robustly extract the <final> channel from decoded Harmony text.
+    Works even if parsing fails or the model emits extra headers.
+    """
+    try:
+        chunks: Dict[str, str] = {}
+        pieces = text.split("<|channel|>")
+        for seg in pieces[1:]:
+            name_end = seg.find("<|message|>")
+            if name_end <= 0:
+                continue
+            ch = seg[:name_end].strip()
+            body_start = name_end + len("<|message|>")
+            # end at next channel/end/return marker
+            next_pos = len(seg)
+            for delim in ("<|channel|>", "<|end|>", "<|return|>"):
+                p = seg.find(delim, body_start)
+                if p != -1:
+                    next_pos = min(next_pos, p)
+            body = seg[body_start:next_pos]
+            chunks[ch] = chunks.get(ch, "") + body
+        final_txt = (chunks.get("final", "").strip())
+        if final_txt:
+            return final_txt
+        # Fallback: everything after last final marker up to a terminator
+        if "<|channel|>final<|message|>" in text:
+            tail = text.split("<|channel|>final<|message|>")[-1]
+            for delim in ("<|return|>", "<|end|>", "<|channel|>"):
+                idx = tail.find(delim)
+                if idx != -1:
+                    tail = tail[:idx]
+                    break
+            return tail.strip()
+    except Exception:
+        pass
     return text.strip()
 
 # -----------------------
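A quick sanity check of what the new extractor is meant to do, on a hand-written Harmony transcript (the sample string is illustrative, not taken from the commit): only the final channel should come back, and text without any channel markers should pass through unchanged:

    sample = (
        "<|channel|>analysis<|message|>reasoning goes here<|end|>"
        "<|channel|>final<|message|>Hello there!<|return|>"
    )
    assert extract_final_channel_fallback(sample) == "Hello there!"
    # With no channel markers at all, the function degrades to returning the stripped input:
    assert extract_final_channel_fallback("plain text") == "plain text"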
@@ -319,9 +341,7 @@ def zerogpu_generate(full_prompt,
         top_p=float(gen_kwargs.get("top_p", 0.9)),
         top_k=(int(gen_kwargs.get("top_k")) if gen_kwargs.get("top_k") and int(gen_kwargs.get("top_k")) > 0 else None),
         max_new_tokens=int(gen_kwargs.get("max_new_tokens", MAX_DEF)),
-        pad_token_id=model.config.pad_token_id,
-        eos_token_id=eos_ids,
-        logits_processor=logits_processor,
+        pad_token_id=model.config.pad_token_id, logits_processor=logits_processor,
         repetition_penalty=float(gen_kwargs.get("repetition_penalty", 1.1)),
         no_repeat_ngram_size=int(gen_kwargs.get("no_repeat_ngram_size", 6)),
         min_new_tokens=1,
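One behavioral consequence of this hunk, sketched under the assumption that model, inputs, gen_kwargs, logits_processor and MAX_DEF are the objects already defined earlier in app.py: with eos_token_id no longer passed explicitly, model.generate stops on whatever end-of-sequence ids the checkpoint's generation config declares:

    out = model.generate(
        **inputs,                                        # tokenized Harmony prompt
        max_new_tokens=int(gen_kwargs.get("max_new_tokens", MAX_DEF)),
        pad_token_id=model.config.pad_token_id,          # still pinned explicitly
        logits_processor=logits_processor,               # still applied
        # no eos_token_id here: stopping now follows the checkpoint's defaults
    )
    print(model.generation_config.eos_token_id)          # ids that now terminate generation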
 