AbstractPhil committed
Commit 6eb225b · Parent(s): 3efceb8
Files changed (1): app.py (+27 -10)
app.py CHANGED
@@ -11,7 +11,7 @@ from typing import List, Dict, Optional, Any
 from datetime import datetime
 import gradio as gr
 import spaces  # required for ZeroGPU
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
 
 # Import Harmony components
 try:
@@ -47,7 +47,7 @@ ZEROGPU = os.getenv("ZEROGPU", os.getenv("ZERO_GPU", "0")) == "1"
 LOAD_4BIT = os.getenv("LOAD_4BIT", "0") == "1"
 
 # Harmony channels for CoT
-REQUIRED_CHANNELS = ["analysis", "commentary", "final"]
+REQUIRED_CHANNELS = ["analysis", "final"]
 
 # HF Auth - properly handle multiple token env var names
 HF_TOKEN: Optional[str] = (
@@ -286,6 +286,12 @@ class RoseGuidedLogits(torch.nn.Module):
     def forward(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         return scores + self.alpha * self.bias_vec.to(scores.device)
 
+class StopOnTokens(StoppingCriteria):
+    def __init__(self, stop_ids: List[int]):
+        self.stop_ids = set(int(s) for s in (stop_ids or []))
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
+        return int(input_ids[0, -1]) in self.stop_ids
+
 @spaces.GPU(duration=120)
 def zerogpu_generate(full_prompt,
                      gen_kwargs: Dict[str, Any],
@@ -310,18 +316,30 @@ def zerogpu_generate(full_prompt,
 
     # Tokenize / prepare inputs
    device = next(model.parameters()).device
-    if HARMONY_AVAILABLE and isinstance(full_prompt, list):
-        input_ids = torch.tensor([full_prompt], dtype=torch.long, device=device)
+    if HARMONY_AVAILABLE and not isinstance(full_prompt, str):
+        # Accept list/tuple or any iterable of ints from openai_harmony
+        try:
+            token_list = list(full_prompt)
+        except TypeError:
+            token_list = list(getattr(full_prompt, "ids", getattr(full_prompt, "token_ids", [])))
+        if not token_list:
+            raise ValueError("Harmony prompt produced no tokens")
+        input_ids = torch.tensor([token_list], dtype=torch.long, device=device)
         attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=device)
         inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
         prompt_len = input_ids.shape[1]
     else:
         enc = tokenizer(full_prompt, return_tensors="pt")
-        inputs = enc.to(device)
+        inputs = {k: v.to(device) for k, v in enc.items()}
         prompt_len = int(inputs["input_ids"].shape[1])
-        # Guarantee attention_mask exists; avoids pad==eos ambiguity warnings
         if "attention_mask" not in inputs:
             inputs["attention_mask"] = torch.ones_like(inputs["input_ids"], dtype=torch.long, device=device)
+
+    # Prepare stopping
+    sc = None
+    if HARMONY_AVAILABLE and HARMONY_STOP_IDS:
+        sc = StoppingCriteriaList([StopOnTokens(HARMONY_STOP_IDS)])
+
     # Generate
     # Build EOS list: use ONLY Harmony assistant-action stops (per OpenAI docs)
     eos_ids = HARMONY_STOP_IDS if HARMONY_AVAILABLE else tokenizer.eos_token_id
@@ -334,11 +352,10 @@ def zerogpu_generate(full_prompt,
         top_k=(int(gen_kwargs.get("top_k")) if gen_kwargs.get("top_k") and int(gen_kwargs.get("top_k")) > 0 else None),
         max_new_tokens=int(gen_kwargs.get("max_new_tokens", MAX_DEF)),
         pad_token_id=model.config.pad_token_id,
-        eos_token_id=eos_ids,
         logits_processor=logits_processor,
-        repetition_penalty=float(gen_kwargs.get("repetition_penalty", 1.1)),
-        no_repeat_ngram_size=int(gen_kwargs.get("no_repeat_ngram_size", 6)),
-        min_new_tokens=1,
+        repetition_penalty=float(gen_kwargs.get("repetition_penalty", 1.2)),
+        no_repeat_ngram_size=int(gen_kwargs.get("no_repeat_ngram_size", 8)),
+        stopping_criteria=sc,
     )
 
     # Extract generated tokens only
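For context, the core pattern this commit moves to is a custom StoppingCriteria checked on the last sampled token and passed to model.generate via stopping_criteria, replacing the eos_token_id list. Below is a minimal, self-contained sketch of that pattern; the model name ("sshleifer/tiny-gpt2") and STOP_IDS are placeholders chosen so the snippet runs quickly, not values from app.py, where the stop ids come from Harmony (HARMONY_STOP_IDS).

# Minimal sketch of the stopping_criteria pattern; model and stop ids are placeholders.
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    StoppingCriteria,
    StoppingCriteriaList,
)

class StopOnTokens(StoppingCriteria):
    """Stop generation as soon as the last emitted token is one of stop_ids."""
    def __init__(self, stop_ids):
        self.stop_ids = set(int(s) for s in (stop_ids or []))

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # input_ids holds prompt + generated tokens; check only the newest one
        return int(input_ids[0, -1]) in self.stop_ids

tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")  # placeholder test model
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

STOP_IDS = [tokenizer.eos_token_id]  # stand-in for HARMONY_STOP_IDS
inputs = tokenizer("Hello", return_tensors="pt")

out = model.generate(
    **inputs,
    max_new_tokens=32,
    do_sample=False,
    pad_token_id=tokenizer.eos_token_id,
    # stopping_criteria replaces eos_token_id, mirroring the commit
    stopping_criteria=StoppingCriteriaList([StopOnTokens(STOP_IDS)]),
)
print(tokenizer.decode(out[0, inputs["input_ids"].shape[1]:]))

The check runs once per decoding step on the last sampled token, so the stop ids stay decoupled from the model's configured EOS token; any token id set (such as Harmony's assistant-action stop tokens) can end generation.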