Spaces: Building on Zero

AbstractPhil committed · Commit 51a55c1 · 1 Parent(s): 7f8b6c0
Commit message: "yes"
app.py
CHANGED
@@ -150,11 +150,17 @@ def _load_model_on(device_map: Optional[str]) -> AutoModelForCausalLM:
 # Harmony formatting
 # -----------------------
 
-def create_harmony_prompt(messages: List[Dict[str, str]], reasoning_effort: str = "high") ->
-    """
-    …
+def create_harmony_prompt(messages: List[Dict[str, str]], reasoning_effort: str = "high") -> str:
+    """Build Harmony-formatted prompt using the *tokenizer chat template* (per model card).
+    Always returns a string; HF will tokenize to ensure IDs match the checkpoint.
+    """
+    if not messages or messages[0].get("role") != "system":
+        messages = [{"role": "system", "content": SYSTEM_DEF}] + (messages or [])
+    return tokenizer.apply_chat_template(
+        messages,
+        add_generation_prompt=True,
+        tokenize=False
+    )
 
 # Map reasoning effort
 effort_map = {
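For context, a minimal self-contained sketch of the new prompt path. The checkpoint id and SYSTEM_DEF below are assumptions for illustration (app.py defines the real ones); any tokenizer that ships a Harmony chat template behaves the same way.

    from typing import Dict, List
    from transformers import AutoTokenizer

    MODEL_ID = "openai/gpt-oss-20b"              # assumed id; app.py loads the real checkpoint
    SYSTEM_DEF = "You are a helpful assistant."  # placeholder for the Space's default system prompt

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

    def build_prompt(messages: List[Dict[str, str]]) -> str:
        # Prepend the default system message, exactly as the new function does.
        if not messages or messages[0].get("role") != "system":
            messages = [{"role": "system", "content": SYSTEM_DEF}] + (messages or [])
        # tokenize=False returns the rendered template as a string; HF tokenizes it
        # later, so the ids are guaranteed to match the checkpoint's own template.
        return tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)

    print(build_prompt([{"role": "user", "content": "Hello!"}]))

Rendering to a string rather than ids keeps the prompt inspectable in logs while still guaranteeing the formatting comes from the checkpoint's own template instead of hand-rolled Harmony markup.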
@@ -223,24 +229,40 @@ def parse_harmony_response(tokens: List[int]) -> Dict[str, str]:
     return channels
 
 def extract_final_channel_fallback(text: str) -> str:
-    """
-    …
+    """Robustly extract the <final> channel from decoded Harmony text.
+    Works even if parsing fails or the model emits extra headers.
+    """
+    try:
+        chunks: Dict[str, str] = {}
+        pieces = text.split("<|channel|>")
+        for seg in pieces[1:]:
+            name_end = seg.find("<|message|>")
+            if name_end <= 0:
+                continue
+            ch = seg[:name_end].strip()
+            body_start = name_end + len("<|message|>")
+            # end at next channel/end/return marker
+            next_pos = len(seg)
+            for delim in ("<|channel|>", "<|end|>", "<|return|>"):
+                p = seg.find(delim, body_start)
+                if p != -1:
+                    next_pos = min(next_pos, p)
+            body = seg[body_start:next_pos]
+            chunks[ch] = chunks.get(ch, "") + body
+        final_txt = (chunks.get("final", "").strip())
+        if final_txt:
+            return final_txt
+        # Fallback: everything after last final marker up to a terminator
+        if "<|channel|>final<|message|>" in text:
+            tail = text.split("<|channel|>final<|message|>")[-1]
+            for delim in ("<|return|>", "<|end|>", "<|channel|>"):
+                idx = tail.find(delim)
+                if idx != -1:
+                    tail = tail[:idx]
+                    break
+            return tail.strip()
+    except Exception:
+        pass
     return text.strip()
 
 # -----------------------
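To see what the fallback does, here is a quick check against a synthetic Harmony transcript (assuming extract_final_channel_fallback from the hunk above is in scope); the markers mirror the ones the function splits on, and the expected output is the body of the final channel only.

    decoded = (
        "<|channel|>analysis<|message|>Reasoning about the question...<|end|>"
        "<|channel|>final<|message|>Paris is the capital of France.<|return|>"
    )
    print(extract_final_channel_fallback(decoded))
    # Paris is the capital of France.

Accumulating bodies per channel (chunks.get(ch, "") + body) means a model that re-opens the final channel mid-stream still yields one concatenated answer, and the bare try/except guarantees the function degrades to text.strip() rather than raising mid-generation.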
@@ -319,9 +341,7 @@ def zerogpu_generate(full_prompt,
         top_p=float(gen_kwargs.get("top_p", 0.9)),
         top_k=(int(gen_kwargs.get("top_k")) if gen_kwargs.get("top_k") and int(gen_kwargs.get("top_k")) > 0 else None),
         max_new_tokens=int(gen_kwargs.get("max_new_tokens", MAX_DEF)),
-        pad_token_id=model.config.pad_token_id,
-        eos_token_id=eos_ids,
-        logits_processor=logits_processor,
+        pad_token_id=model.config.pad_token_id, logits_processor=logits_processor,
         repetition_penalty=float(gen_kwargs.get("repetition_penalty", 1.1)),
         no_repeat_ngram_size=int(gen_kwargs.get("no_repeat_ngram_size", 6)),
         min_new_tokens=1,
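One behavioral note on this hunk: eos_token_id=eos_ids is dropped, so stopping now falls back to whatever eos ids the checkpoint's generation config carries. A quick way to inspect those (model and tokenizer are whatever app.py loaded; the expected tokens are an assumption):

    eos = model.generation_config.eos_token_id
    ids = eos if isinstance(eos, list) else [eos]
    print(ids, tokenizer.convert_ids_to_tokens(ids))
    # Expected (assumption): Harmony terminators such as <|return|>; if the config
    # lacked them, eos_token_id would need to be passed to generate() explicitly again.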