AbstractPhil committed on
Commit 1b343d0 · Parent(s): 7d19f11

okay it works better now

Files changed (2)
  1. app.py +37 -16
  2. requirements.txt +2 -1
app.py CHANGED
@@ -47,7 +47,7 @@ ZEROGPU = os.getenv("ZEROGPU", os.getenv("ZERO_GPU", "0")) == "1"
 LOAD_4BIT = os.getenv("LOAD_4BIT", "0") == "1"
 
 # Harmony channels for CoT
-REQUIRED_CHANNELS = ["thinking", "analysis", "final"]
+REQUIRED_CHANNELS = ["analysis", "commentary", "final"]
 
 # HF Auth - properly handle multiple token env var names
 HF_TOKEN: Optional[str] = (
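
Note on this change: the Harmony spec defines the assistant channels as analysis, commentary, and final; there is no "thinking" channel, which is why the old list could never match real output. A completion is tagged roughly like this (illustrative shape, not verbatim model output):

    <|channel|>analysis<|message|>...chain of thought...<|end|>
    <|start|>assistant<|channel|>final<|message|>...user-facing answer...<|return|>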
@@ -84,6 +84,9 @@ if HARMONY_AVAILABLE:
 else:
     harmony_encoding = None
 
+# Stop tokens per Harmony spec: <|return|> (200002), <|call|> (200012)
+HARMONY_STOP_IDS = [200002, 200012]
+
 # Tokenizer is lightweight; load once
 try:
     tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, token=HF_TOKEN)
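
Hardcoding the two IDs works for the gpt-oss tokenizer, but the encoding can supply them directly, which keeps the list correct if the spec ever changes. A minimal sketch, assuming openai_harmony is importable:

    from openai_harmony import HarmonyEncodingName, load_harmony_encoding

    enc = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
    # Returns the token IDs that terminate an assistant turn: <|return|> and <|call|>
    stop_ids = enc.stop_tokens_for_assistant_actions()
    print(stop_ids)  # expected to include 200002 and 200012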
@@ -144,7 +147,7 @@ def _load_model_on(device_map: Optional[str]) -> AutoModelForCausalLM:
 # Harmony formatting
 # -----------------------
 
-def create_harmony_prompt(messages: List[Dict[str, str]], reasoning_effort: str = "high") -> str:
+def create_harmony_prompt(messages: List[Dict[str, str]], reasoning_effort: str = "high") -> Any:
     """Create a proper Harmony-formatted prompt using openai_harmony."""
     if not HARMONY_AVAILABLE:
         # Fallback to tokenizer's chat template
@@ -189,9 +192,7 @@ def create_harmony_prompt(messages: List[Dict[str, str]], reasoning_effort: str
     # Create conversation and render
     convo = Conversation.from_messages(harmony_messages)
     tokens = harmony_encoding.render_conversation_for_completion(convo, Role.ASSISTANT)
-
-    # Convert tokens back to text for the model
-    return tokenizer.decode(tokens)
+    return tokens  # pass tokens directly to the model to avoid decode/re-encode drift
 
 def parse_harmony_response(tokens: List[int]) -> Dict[str, str]:
     """Parse response tokens using Harmony format to extract channels."""
@@ -268,7 +269,7 @@ class RoseGuidedLogits(torch.nn.Module):
         return scores + self.alpha * self.bias_vec.to(scores.device)
 
 @spaces.GPU(duration=120)
-def zerogpu_generate(full_prompt: str,
+def zerogpu_generate(full_prompt,
                      gen_kwargs: Dict[str, Any],
                      rose_map: Optional[Dict[str, float]],
                      rose_alpha: float,
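
Dropping the annotation is the quick fix; since full_prompt is now either a rendered token list or a plain string, a Union annotation would document the contract without changing behavior. A sketch (the remaining Rose parameters are elided here):

    from typing import Any, Dict, List, Union

    def zerogpu_generate(full_prompt: Union[str, List[int]],  # tokens when Harmony renders
                         gen_kwargs: Dict[str, Any]):
        ...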
@@ -289,8 +290,16 @@ def zerogpu_generate(full_prompt: str,
         eff_alpha = float(rose_alpha) * (float(rose_score) if rose_score is not None else 1.0)
         logits_processor = [RoseGuidedLogits(bias, eff_alpha)]
 
-    # Tokenize input
-    inputs = tokenizer(full_prompt, return_tensors="pt").to(next(model.parameters()).device)
+    # Tokenize / prepare inputs
+    device = next(model.parameters()).device
+    if HARMONY_AVAILABLE and isinstance(full_prompt, list):
+        input_ids = torch.tensor([full_prompt], dtype=torch.long, device=device)
+        inputs = {"input_ids": input_ids}
+        prompt_len = input_ids.shape[1]
+    else:
+        enc = tokenizer(full_prompt, return_tensors="pt")
+        inputs = enc.to(device)
+        prompt_len = int(inputs["input_ids"].shape[1])
 
     # Generate
     out_ids = model.generate(
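
One caveat with the hand-built dict: it carries no attention_mask, and with pad_token_id set to the EOS token, transformers may warn about or mis-infer padding. For a single unpadded prompt the mask is all ones anyway, so it is cheap to add; a sketch reusing the diff's model and device, where prompt_token_list is a hypothetical name for the Harmony-rendered list[int]:

    import torch

    input_ids = torch.tensor([prompt_token_list], dtype=torch.long, device=device)
    inputs = {
        "input_ids": input_ids,
        "attention_mask": torch.ones_like(input_ids),  # batch of one, no padding
    }
    out_ids = model.generate(**inputs, max_new_tokens=256)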
@@ -301,21 +310,35 @@
         top_k=(int(gen_kwargs.get("top_k")) if gen_kwargs.get("top_k") and int(gen_kwargs.get("top_k")) > 0 else None),
         max_new_tokens=int(gen_kwargs.get("max_new_tokens", MAX_DEF)),
         pad_token_id=tokenizer.eos_token_id,
-        eos_token_id=tokenizer.eos_token_id,
+        eos_token_id=(HARMONY_STOP_IDS if HARMONY_AVAILABLE else tokenizer.eos_token_id),
         logits_processor=logits_processor,
     )
 
     # Extract generated tokens only
-    prompt_len = int(inputs["input_ids"].shape[1])
-    gen_ids = out_ids[0][prompt_len:].tolist()
+    out_list = out_ids[0].tolist()
+    gen_ids = out_list[prompt_len:]
+    # Truncate at first Harmony stop token if present
+    if HARMONY_AVAILABLE:
+        for sid in HARMONY_STOP_IDS:
+            if sid in gen_ids:
+                gen_ids = gen_ids[:gen_ids.index(sid)]
+                break
 
     # Parse response with Harmony
     if HARMONY_AVAILABLE:
-        channels = parse_harmony_response(gen_ids)
+        try:
+            channels = parse_harmony_response(gen_ids)
+        except Exception:
+            # Fallback to text parsing if Harmony parser fails
+            decoded = tokenizer.decode(gen_ids, skip_special_tokens=False)
+            channels = {
+                "final": extract_final_channel_fallback(decoded),
+                "raw": decoded
+            }
     else:
         # Fallback
         decoded = tokenizer.decode(gen_ids, skip_special_tokens=False)
         channels = {
             "final": extract_final_channel_fallback(decoded),
             "raw": decoded
         }
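
Parsing is the mirror image of rendering. If parse_harmony_response wraps the openai_harmony parser, its core is likely close to this sketch; parse_messages_from_completion_tokens is the documented entry point, but attribute names like msg.channel and part.text are assumptions about the message model:

    from openai_harmony import Role

    def parse_channels(enc, gen_ids):
        # Parse completion tokens into structured messages, then bucket text by channel
        messages = enc.parse_messages_from_completion_tokens(gen_ids, Role.ASSISTANT)
        channels = {}
        for msg in messages:
            name = msg.channel or "final"
            text = "".join(getattr(part, "text", "") for part in msg.content)
            channels[name] = channels.get(name, "") + text
        return channels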
@@ -367,9 +388,9 @@ def generate_response(message: str, history: List[List[str]], system_prompt: str
 
     # Create Harmony-formatted prompt
     if HARMONY_AVAILABLE:
-        prompt = create_harmony_prompt(messages, reasoning_effort)
+        prompt = create_harmony_prompt(messages, reasoning_effort)  # returns token IDs
     else:
-        # Fallback to tokenizer template
+        # Fallback to tokenizer template (string)
         prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
 
     # Build Rose map if enabled
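
The fallback branch is standard transformers usage; for reference, a self-contained sketch of what it produces (message content is illustrative):

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello!"},
    ]
    # tokenize=False yields a formatted string, so this branch hands zerogpu_generate
    # a str, while the Harmony branch hands it a list of token IDs.
    prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)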
 
requirements.txt CHANGED
@@ -4,4 +4,5 @@ accelerate>=0.33.0
 peft>=0.11.0
 torch>=2.4.0  # ZeroGPU-supported (2.3.x is NOT supported)
 bitsandbytes>=0.43.1
-openai_harmony
+openai_harmony
+gradio>=5.42.0