AbstractPhil committed
Commit 40b9211 · 1 Parent(s): ae231bc

probably works-ish

Files changed (1): app.py (+123 -28)
app.py CHANGED
@@ -73,8 +73,8 @@ def _hf_login() -> None:
     else:
         print("[HF Auth] No token found in environment variables")
 
-# Login before loading any models
-_hf_login()
+# Login is handled by Space OAuth/session; avoid explicit CLI login here to prevent OAuth var errors
+# _hf_login()
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
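This hunk swaps the unconditional startup login for a comment; the Space is expected to pick up credentials from its own OAuth/session handling. If an explicit fallback is still wanted, a guarded environment-token login avoids the OAuth-variable errors the comment mentions. A minimal sketch, assuming the token is exposed via the usual `HF_TOKEN` / `HUGGING_FACE_HUB_TOKEN` variables (names not taken from app.py):

```python
import os
from huggingface_hub import login

def maybe_hf_login() -> None:
    """Log in only when an env token is present; never crash startup on auth issues."""
    token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
    if not token:
        print("[HF Auth] No token found; relying on Space OAuth/session")
        return
    try:
        login(token=token, add_to_git_credential=False)
        print("[HF Auth] Logged in with environment token")
    except Exception as e:
        print(f"[HF Auth] Login skipped: {type(e).__name__}: {e}")
```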
 
@@ -364,13 +364,18 @@ def zerogpu_generate(full_prompt,
         out_ids = model.generate(
             **inputs,
             do_sample=bool(gen_kwargs.get("do_sample", True)),
-            temperature=float(gen_kwargs.get("temperature", 0.7)),
-            top_p=float(gen_kwargs.get("top_p", 0.9)),
-            top_k=(int(gen_kwargs.get("top_k")) if gen_kwargs.get("top_k") and int(gen_kwargs.get("top_k")) > 0 else None),
+            temperature=float(gen_kwargs.get("temperature", 0.6)),
+            top_p=(float(gen_kwargs.get("top_p")) if gen_kwargs.get("top_p") is not None else None),
+            top_k=(int(gen_kwargs.get("top_k")) if gen_kwargs.get("top_k") else None),
             max_new_tokens=int(gen_kwargs.get("max_new_tokens", MAX_DEF)),
             pad_token_id=model.config.pad_token_id,
             eos_token_id=tokenizer.eos_token_id,
-            bad_words_ids=bad_words_ids,
+            repetition_penalty=float(gen_kwargs.get("repetition_penalty", 1.1)),
+            no_repeat_ngram_size=int(gen_kwargs.get("no_repeat_ngram_size", 6)),
+            logits_processor=logits_processor,
+        )
+        eos_token_id=tokenizer.eos_token_id,
+
             logits_processor=logits_processor,
             repetition_penalty=float(gen_kwargs.get("repetition_penalty", 1.2)),
             no_repeat_ngram_size=int(gen_kwargs.get("no_repeat_ngram_size", 8)),
@@ -420,6 +425,59 @@ def zerogpu_generate(full_prompt,
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
 
+# -----------------------
+# Simple (non-Harmony) GPU path — matches your minimal example
+# -----------------------
+@spaces.GPU(duration=120)
+def zerogpu_generate_simple(prompt_str: str, gen_kwargs: Dict[str, Any], rose_map: Optional[Dict[str, float]], rose_alpha: float, rose_score: Optional[float], seed: Optional[int]) -> Dict[str, str]:
+    """Straight chat_template path. No Harmony tokens. Slices completion from prompt_len.
+    Mirrors the minimal HF example and avoids header loops entirely."""
+    model = None
+    try:
+        if seed is not None:
+            torch.manual_seed(int(seed))
+        model = _load_model_on("auto")
+        device = next(model.parameters()).device
+
+        # Encode prompt string
+        enc = tokenizer(prompt_str, return_tensors="pt")
+        inputs = {k: v.to(device) for k, v in enc.items()}
+        prompt_len = int(inputs["input_ids"].shape[1])
+        if "attention_mask" not in inputs:
+            inputs["attention_mask"] = torch.ones_like(inputs["input_ids"], dtype=torch.long, device=device)
+
+        # Optional Rose bias
+        logits_processor = None
+        if rose_map:
+            bias = build_bias_from_tokens(tokenizer, rose_map).to(device)
+            eff_alpha = float(rose_alpha) * (float(rose_score) if rose_score is not None else 1.0)
+            logits_processor = [RoseGuidedLogits(bias, eff_alpha)]
+
+        out_ids = model.generate(
+            **inputs,
+            do_sample=bool(gen_kwargs.get("do_sample", True)),
+            temperature=float(gen_kwargs.get("temperature", 0.6)),
+            top_p=(float(gen_kwargs.get("top_p")) if gen_kwargs.get("top_p") is not None else None),
+            top_k=(int(gen_kwargs.get("top_k")) if gen_kwargs.get("top_k") else None),
+            max_new_tokens=int(gen_kwargs.get("max_new_tokens", MAX_DEF)),
+            pad_token_id=model.config.pad_token_id,
+            logits_processor=logits_processor,
+        )
+        # Slice generated continuation only
+        new_ids = out_ids[0, prompt_len:]
+        text = tokenizer.decode(new_ids, skip_special_tokens=True)
+        return {"final": text}
+    except Exception as e:
+        return {"final": f"[Error] {type(e).__name__}: {e}"}
+    finally:
+        try:
+            del model
+        except Exception:
+            pass
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
 # -----------------------
 # GPU Debug: Harmony Inspector
 # -----------------------
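The new `zerogpu_generate_simple` follows the standard minimal transformers chat loop: apply the tokenizer's chat template, generate, then slice the output at the prompt length so only the continuation is decoded. A self-contained sketch of that same pattern outside the Space (the checkpoint name is a placeholder, not the model this app loads):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "your-org/your-chat-model"  # placeholder checkpoint
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Say hello in one sentence."},
]
prompt = tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
inputs = tok(prompt, return_tensors="pt").to(model.device)
prompt_len = inputs["input_ids"].shape[1]

out = model.generate(**inputs, max_new_tokens=64, do_sample=True, temperature=0.6)
print(tok.decode(out[0, prompt_len:], skip_special_tokens=True))  # continuation only
```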
@@ -460,7 +518,6 @@ def zerogpu_generate_debug(full_prompt, gen_kwargs: Dict[str, Any]) -> Dict[str,
             max_new_tokens=int(gen_kwargs.get("max_new_tokens", MAX_DEF)),
             pad_token_id=model.config.pad_token_id,
             eos_token_id=tokenizer.eos_token_id,
-            bad_words_ids=bad_words_ids,
             stopping_criteria=sc,
             repetition_penalty=float(gen_kwargs.get("repetition_penalty", 1.15)),
             no_repeat_ngram_size=int(gen_kwargs.get("no_repeat_ngram_size", 6)),
@@ -517,29 +574,45 @@ def generate_response(message: str, history: List[List[str]], system_prompt: str
                       rose_enable: bool, rose_alpha: float, rose_score: Optional[float],
                       rose_tokens: str, rose_json: str,
                       show_thinking: bool = False,
+                      simple_mode: bool = True,  # NEW: default to simple chat_template path
                       reasoning_effort: str = "high") -> str:
     """
     Generate response with proper CoT handling using Harmony format.
     """
     try:
-        # Build message list
+        # Build messages robustly for Gradio type='messages' or legacy tuple format
         messages = [{"role": "system", "content": system_prompt or SYSTEM_DEF}]
-
-        # Add history
+
+        # Add prior turns
         if history:
-            for turn in history:
-                if isinstance(turn, (list, tuple)) and len(turn) >= 2:
-                    user_msg, assistant_msg = turn[0], turn[1]
-                    if user_msg:
-                        messages.append({"role": "user", "content": str(user_msg)})
-                    if assistant_msg:
-                        messages.append({"role": "assistant", "content": str(assistant_msg)})
-
-        # Add current message
-        messages.append({"role": "user", "content": str(message)})
-
-        # Create Harmony-formatted prompt
-        if HARMONY_AVAILABLE:
+            if isinstance(history, list) and history and isinstance(history[0], dict):
+                # history is already a flat list of {'role','content'} dicts
+                for m in history:
+                    role = m.get("role")
+                    content = m.get("content", "")
+                    if role in ("user", "assistant"):
+                        messages.append({"role": role, "content": str(content)})
+            else:
+                for turn in history:
+                    if isinstance(turn, (list, tuple)) and len(turn) >= 2:
+                        u, a = turn[0], turn[1]
+                        if u:
+                            messages.append({"role": "user", "content": str(u)})
+                        if a:
+                            messages.append({"role": "assistant", "content": str(a)})
+
+        # Current user message
+        if isinstance(message, dict):
+            user_text = message.get("content", "")
+        else:
+            user_text = str(message)
+        messages.append({"role": "user", "content": user_text})
+
+        # FAST PATH: simple chat_template prompt (recommended)
+        if simple_mode:
+            prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+        # Harmony path (optional)
+        elif HARMONY_AVAILABLE:
             prompt = create_harmony_prompt(messages, reasoning_effort) # returns token IDs
         else:
             # Fallback to tokenizer template (string)
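The rewritten history handling accepts both shapes Gradio can hand over: the `type='messages'` format (a flat list of role/content dicts) and the legacy list of `[user, assistant]` pairs. A small standalone sketch of that normalization, equivalent in spirit to the branch above (the helper name is just for illustration):

```python
from typing import Any, Dict, List

def normalize_history(history: Any) -> List[Dict[str, str]]:
    """Flatten either Gradio history format into role/content messages."""
    msgs: List[Dict[str, str]] = []
    if not history:
        return msgs
    if isinstance(history[0], dict):  # type='messages'
        for m in history:
            if m.get("role") in ("user", "assistant"):
                msgs.append({"role": m["role"], "content": str(m.get("content", ""))})
    else:  # legacy [[user, assistant], ...] pairs
        for user_msg, assistant_msg in history:
            if user_msg:
                msgs.append({"role": "user", "content": str(user_msg)})
            if assistant_msg:
                msgs.append({"role": "assistant", "content": str(assistant_msg)})
    return msgs

# Both inputs normalize to the same two messages:
normalize_history([{"role": "user", "content": "hi"}, {"role": "assistant", "content": "hello"}])
normalize_history([["hi", "hello"]])
```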
@@ -573,7 +646,23 @@ def generate_response(message: str, history: List[List[str]], system_prompt: str
             rose_map = None
 
         # Generate with model
-        channels = zerogpu_generate(
+        if simple_mode:
+            channels = zerogpu_generate_simple(
+                prompt,
+                {
+                    "do_sample": bool(do_sample),
+                    "temperature": float(temperature),
+                    "top_p": float(top_p) if top_p is not None else None,
+                    "top_k": int(top_k) if top_k > 0 else None,
+                    "max_new_tokens": int(max_new_tokens),
+                },
+                rose_map,
+                float(rose_alpha),
+                float(rose_score) if rose_score is not None else None,
+                int(seed) if seed is not None else None,
+            )
+        else:
+            channels = zerogpu_generate(
             prompt,
             {
                 "do_sample": bool(do_sample),
@@ -641,7 +730,13 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
                 lines=2
             )
 
-            with gr.Accordion("Generation Settings", open=False):
+            with gr.Accordion("Generation Settings ", open=False):
+                # NEW: toggle to bypass Harmony and use plain chat_template like your minimal script
+                simple_mode = gr.Checkbox(
+                    value=True,
+                    label="Use simple chat_template (no Harmony)",
+                    info="Matches the minimal HF example; safest path for now"
+                )
                 with gr.Row():
                     temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
                     top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.01, label="Top-p")
@@ -692,9 +787,9 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
         fn=generate_response,
         type="messages",
         additional_inputs=[
-            system_prompt, temperature, top_p, top_k, max_new,
-            do_sample, seed, rose_enable, rose_alpha, rose_score,
-            rose_tokens, rose_json, show_thinking, reasoning_effort
+            system_prompt, temperature, top_p, top_k, max_new,
+            do_sample, seed, rose_enable, rose_alpha, rose_score,
+            rose_tokens, rose_json, show_thinking, simple_mode, reasoning_effort
         ],
         title="Chat with Mirel",
         description="A chain-of-thought model using Harmony format",
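Adding `simple_mode` to `additional_inputs` only works because the list order mirrors the extra parameters of `generate_response` after `(message, history)`; `gr.ChatInterface` passes these components positionally. A trimmed sketch of that wiring with a reduced set of controls (component set shortened for illustration):

```python
import gradio as gr

def generate_response(message, history, system_prompt, temperature, show_thinking, simple_mode, reasoning_effort):
    # additional_inputs arrive positionally, in the order listed below
    return f"simple_mode={simple_mode}, effort={reasoning_effort}"

with gr.Blocks() as demo:
    system_prompt = gr.Textbox(value="You are Mirel.", label="System prompt")
    temperature = gr.Slider(0.0, 2.0, value=0.7, label="Temperature")
    show_thinking = gr.Checkbox(value=False, label="Show thinking")
    simple_mode = gr.Checkbox(value=True, label="Use simple chat_template (no Harmony)")
    reasoning_effort = gr.Dropdown(["low", "medium", "high"], value="high", label="Reasoning effort")

    gr.ChatInterface(
        fn=generate_response,
        type="messages",
        additional_inputs=[system_prompt, temperature, show_thinking, simple_mode, reasoning_effort],
    )

demo.launch()
```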
 