Spaces · Running on Zero

Commit ec0268d: "yes"
AbstractPhil committed · Parent: 6eb225b
app.py CHANGED
@@ -401,6 +401,91 @@ def zerogpu_generate(full_prompt,
         if torch.cuda.is_available():
             torch.cuda.empty_cache()

+# -----------------------
+# GPU Debug: Harmony Inspector
+# -----------------------
+@spaces.GPU(duration=120)
+def zerogpu_generate_debug(full_prompt, gen_kwargs: Dict[str, Any]) -> Dict[str, Any]:
+    """Minimal GPU path to run a single prompt and return Harmony-parsed output
+    along with short token previews for debugging. Does not use Rose for clarity."""
+    model = None
+    try:
+        model = _load_model_on("auto")
+        device = next(model.parameters()).device
+
+        # Prepare inputs (tokens if Harmony renderer used, else string -> encode)
+        if HARMONY_AVAILABLE and not isinstance(full_prompt, str):
+            token_list = list(full_prompt)
+            if not token_list:
+                raise ValueError("Harmony prompt produced no tokens")
+            input_ids = torch.tensor([token_list], dtype=torch.long, device=device)
+            attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=device)
+            inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
+            prompt_len = input_ids.shape[1]
+        else:
+            enc = tokenizer(full_prompt, return_tensors="pt")
+            inputs = {k: v.to(device) for k, v in enc.items()}
+            if "attention_mask" not in inputs:
+                inputs["attention_mask"] = torch.ones_like(inputs["input_ids"], dtype=torch.long, device=device)
+            prompt_len = int(inputs["input_ids"].shape[1])
+
+        # Harmony stop via stopping criteria
+        sc = StoppingCriteriaList([StopOnTokens(HARMONY_STOP_IDS)]) if (HARMONY_AVAILABLE and HARMONY_STOP_IDS) else None
+
+        out_ids = model.generate(
+            **inputs,
+            do_sample=bool(gen_kwargs.get("do_sample", True)),
+            temperature=float(gen_kwargs.get("temperature", 0.7)),
+            top_p=float(gen_kwargs.get("top_p", 0.9)),
+            top_k=(int(gen_kwargs.get("top_k")) if gen_kwargs.get("top_k") and int(gen_kwargs.get("top_k")) > 0 else None),
+            max_new_tokens=int(gen_kwargs.get("max_new_tokens", MAX_DEF)),
+            pad_token_id=model.config.pad_token_id,
+            stopping_criteria=sc,
+            repetition_penalty=float(gen_kwargs.get("repetition_penalty", 1.15)),
+            no_repeat_ngram_size=int(gen_kwargs.get("no_repeat_ngram_size", 6)),
+        )
+
+        out_list = out_ids[0].tolist()
+        gen_ids = out_list[prompt_len:]
+        # Truncate at first Harmony stop token if present
+        if HARMONY_AVAILABLE and HARMONY_STOP_IDS:
+            for sid in HARMONY_STOP_IDS:
+                if sid in gen_ids:
+                    gen_ids = gen_ids[:gen_ids.index(sid)]
+                    break
+
+        # Parse channels
+        if HARMONY_AVAILABLE:
+            try:
+                channels = parse_harmony_response(gen_ids)
+            except Exception:
+                decoded = tokenizer.decode(gen_ids, skip_special_tokens=False)
+                channels = {"final": extract_final_channel_fallback(decoded), "raw": decoded}
+        else:
+            decoded = tokenizer.decode(gen_ids, skip_special_tokens=False)
+            channels = {"final": extract_final_channel_fallback(decoded), "raw": decoded}
+
+        # Small previews (avoid flooding logs/UI)
+        preview = {
+            "prompt_len": int(prompt_len),
+            "stop_ids": list(HARMONY_STOP_IDS) if HARMONY_AVAILABLE else [],
+            "gen_len": int(len(gen_ids)),
+            "gen_ids_head": gen_ids[:48],
+            "decoded_head": tokenizer.decode(gen_ids[:256], skip_special_tokens=False),
+            "channels": channels,
+        }
+        return preview
+    except Exception as e:
+        return {"error": f"{type(e).__name__}: {e}"}
+    finally:
+        try:
+            del model
+        except Exception:
+            pass
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
 # -----------------------
 # Gradio handlers
 # -----------------------
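The debug generator above halts on Harmony stop tokens through `StoppingCriteriaList([StopOnTokens(HARMONY_STOP_IDS)])`. `StopOnTokens` is defined earlier in app.py and is not part of this diff; the sketch below shows the shape such a criterion typically has, assuming it subclasses `transformers.StoppingCriteria` and fires when the most recent token is one of the stop ids. This is a hypothetical reconstruction, not the file's actual code.

```python
# Sketch only: a criterion of this shape is assumed to already exist in app.py.
from typing import List

import torch
from transformers import StoppingCriteria


class StopOnTokens(StoppingCriteria):
    """Stop generation once the last generated token is one of `stop_ids`."""

    def __init__(self, stop_ids: List[int]):
        self.stop_ids = {int(i) for i in stop_ids}

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # input_ids has shape (batch, seq_len); inspect the newest token of the first sequence.
        return int(input_ids[0, -1]) in self.stop_ids
```

Stopping inside `generate` this way makes the later in-Python truncation of `gen_ids` a safety net rather than the primary stop mechanism.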
@@ -498,6 +583,21 @@ def generate_response(message: str, history: List[List[str]], system_prompt: str
     except Exception as e:
         return f"[Error] {type(e).__name__}: {str(e)}"

+# -----------------------
+# Extra handler: Harmony Inspector wrapper
+# -----------------------
+
+def harmony_inspect_handler(user_prompt: str, system_prompt: str, reasoning_effort: str):
+    try:
+        msgs = [{"role": "system", "content": system_prompt or SYSTEM_DEF}, {"role": "user", "content": user_prompt or "What is 2+2?"}]
+        prompt = create_harmony_prompt(msgs, reasoning_effort)
+        return zerogpu_generate_debug(
+            prompt,
+            {"do_sample": True, "temperature": 0.7, "top_p": 0.9, "top_k": 0, "max_new_tokens": MAX_DEF}
+        )
+    except Exception as e:
+        return {"error": f"{type(e).__name__}: {e}"}
+
 # -----------------------
 # UI
 # -----------------------
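For reference, `harmony_inspect_handler` can also be exercised outside the Gradio UI. On success it returns the preview dict built in `zerogpu_generate_debug`; on any failure either function returns a single-key `{"error": ...}` dict. A usage sketch; the `"low"` effort value is only an assumed example of whatever the Space's `reasoning_effort` control provides.

```python
# Sketch: calling the inspector handler directly, e.g. from a local session with the model loaded.
preview = harmony_inspect_handler(
    user_prompt="What is 2+2? Reply with just the number.",
    system_prompt="",        # empty string falls back to SYSTEM_DEF inside the handler
    reasoning_effort="low",  # assumed example value
)

if "error" in preview:
    print("generation failed:", preview["error"])
else:
    # Keys produced by zerogpu_generate_debug:
    # prompt_len, stop_ids, gen_len, gen_ids_head, decoded_head, channels
    print("prompt tokens:", preview["prompt_len"])
    print("generated tokens:", preview["gen_len"])
    print("channels:", preview["channels"])
```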
@@ -560,7 +660,14 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
     )

     # Chat interface - using only valid parameters
-
+    # --- Harmony Inspector UI ---
+    with gr.Accordion("Harmony Inspector", open=False):
+        debug_prompt = gr.Textbox(label="Debug prompt", value="What is 2+2? Reply with just the number.")
+        run_debug = gr.Button("Run Harmony Inspect")
+        debug_out = gr.JSON(label="Parsed Harmony output", value={})
+        run_debug.click(harmony_inspect_handler, inputs=[debug_prompt, system_prompt, reasoning_effort], outputs=[debug_out])
+
+    chat = gr.ChatInterface(
         fn=generate_response,
         type="messages",
         additional_inputs=[
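The accordion wires the handler's dict straight into a `gr.JSON` component through `Button.click`, with `system_prompt` and `reasoning_effort` reusing controls defined earlier in the Blocks. The same wiring pattern in isolation, as a minimal standalone sketch (illustration only, not the Space's code):

```python
import gradio as gr


def inspect(prompt: str) -> dict:
    # Stand-in for harmony_inspect_handler: any JSON-serializable dict renders in gr.JSON.
    return {"prompt": prompt, "gen_len": 0, "channels": {"final": ""}}


with gr.Blocks() as demo:
    with gr.Accordion("Inspector", open=False):
        box = gr.Textbox(label="Debug prompt")
        btn = gr.Button("Run")
        out = gr.JSON(label="Parsed output", value={})
        btn.click(inspect, inputs=[box], outputs=[out])

if __name__ == "__main__":
    demo.launch()
```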
@@ -579,18 +686,18 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
         cache_examples=False,
     )

-
-
-
-
-
-
-
-
-
-
-
-
+    gr.Markdown(
+        """
+        ---
+        ### Configuration:
+        - **Model**: Set `MODEL_ID` env var (default: openai/gpt-oss-20b)
+        - **Adapter**: Set `ADAPTER_ID` and optionally `ADAPTER_SUBFOLDER`
+        - **Auth**: Set `HF_TOKEN` in Space secrets for private model access
+        - **Harmony**: Install with `pip install openai-harmony` for proper channel support
+
+        The model uses Harmony format with thinking channels (`thinking`, `analysis`, `final`).
+        """
+    )

 if __name__ == "__main__":
     demo.queue(max_size=8 if ZEROGPU else 32).launch(
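Both the main path and this debug path fall back to `extract_final_channel_fallback` when the Harmony parser is unavailable or raises; that helper also lives outside this diff. A rough sketch of such a fallback, assuming the standard Harmony channel markers used by gpt-oss (`<|channel|>final<|message|>` terminated by `<|return|>` or `<|end|>`); the Space's actual implementation may differ.

```python
import re


def extract_final_channel_fallback(decoded: str) -> str:
    """Best-effort extraction of the `final` channel from raw Harmony-formatted text."""
    # Take the text following the `final` channel marker, up to a terminator token.
    match = re.search(
        r"<\|channel\|>final<\|message\|>(.*?)(?:<\|return\|>|<\|end\|>|$)",
        decoded,
        flags=re.DOTALL,
    )
    if match:
        return match.group(1).strip()
    # No channel markers found: return the decoded text unchanged (minus surrounding whitespace).
    return decoded.strip()
```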