Spaces: Running on Zero

AbstractPhil committed · Commit 1b343d0 · Parent(s): 7d19f11

okay it works better now

Files changed: app.py (+37, -16) · requirements.txt (+2, -1)
app.py CHANGED

@@ -47,7 +47,7 @@ ZEROGPU = os.getenv("ZEROGPU", os.getenv("ZERO_GPU", "0")) == "1"
 LOAD_4BIT = os.getenv("LOAD_4BIT", "0") == "1"
 
 # Harmony channels for CoT
-REQUIRED_CHANNELS = ["
+REQUIRED_CHANNELS = ["analysis", "commentary", "final"]
 
 # HF Auth - properly handle multiple token env var names
 HF_TOKEN: Optional[str] = (
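For context: Harmony-formatted completions tag each assistant message with one of these channels, and the app expects all three. Roughly, per the gpt-oss Harmony spec, a completion looks like this (placeholder text, not output from this Space):

    <|start|>assistant<|channel|>analysis<|message|>...chain of thought...<|end|>
    <|start|>assistant<|channel|>final<|message|>...user-facing answer...<|return|>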
@@ -84,6 +84,9 @@ if HARMONY_AVAILABLE:
 else:
     harmony_encoding = None
 
+# Stop tokens per Harmony spec: <|return|> (200002), <|call|> (200012)
+HARMONY_STOP_IDS = [200002, 200012]
+
 # Tokenizer is lightweight; load once
 try:
     tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, token=HF_TOKEN)
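The hardcoded IDs match the o200k_harmony tokenizer, but they can also be derived from the encoding itself. A minimal sketch, assuming openai_harmony is importable (which the HARMONY_AVAILABLE guard already implies):

    # sketch: derive the Harmony stop ids instead of hardcoding them
    from openai_harmony import HarmonyEncodingName, load_harmony_encoding

    enc = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
    stop_ids = enc.stop_tokens_for_assistant_actions()  # <|return|>, <|call|>
    print(stop_ids)  # expected to include 200002 and 200012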
@@ -144,7 +147,7 @@ def _load_model_on(device_map: Optional[str]) -> AutoModelForCausalLM:
 # Harmony formatting
 # -----------------------
 
-def create_harmony_prompt(messages: List[Dict[str, str]], reasoning_effort: str = "high") -> str:
+def create_harmony_prompt(messages: List[Dict[str, str]], reasoning_effort: str = "high") -> Any:
     """Create a proper Harmony-formatted prompt using openai_harmony."""
     if not HARMONY_AVAILABLE:
         # Fallback to tokenizer's chat template
@@ -189,9 +192,7 @@ def create_harmony_prompt(messages: List[Dict[str, str]], reasoning_effort: str
     # Create conversation and render
     convo = Conversation.from_messages(harmony_messages)
     tokens = harmony_encoding.render_conversation_for_completion(convo, Role.ASSISTANT)
-
-    # Convert tokens back to text for the model
-    return tokenizer.decode(tokens)
+    return tokens  # pass tokens directly to the model to avoid decode/re-encode drift
 
 def parse_harmony_response(tokens: List[int]) -> Dict[str, str]:
     """Parse response tokens using Harmony format to extract channels."""
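Returning the rendered token IDs skips the decode/re-encode round trip entirely. A condensed sketch of the openai_harmony flow the function wraps (the real function also builds system/developer messages carrying reasoning_effort; the user turn here is a placeholder):

    # sketch: render a conversation straight to completion tokens
    from openai_harmony import (Conversation, HarmonyEncodingName, Message,
                                Role, load_harmony_encoding)

    enc = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
    convo = Conversation.from_messages([
        Message.from_role_and_content(Role.USER, "placeholder user turn"),
    ])
    tokens = enc.render_conversation_for_completion(convo, Role.ASSISTANT)
    # tokens is a List[int], ready to be used as input_ids for model.generate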
@@ -268,7 +269,7 @@ class RoseGuidedLogits(torch.nn.Module):
         return scores + self.alpha * self.bias_vec.to(scores.device)
 
 @spaces.GPU(duration=120)
-def zerogpu_generate(full_prompt: str,
+def zerogpu_generate(full_prompt,
                      gen_kwargs: Dict[str, Any],
                      rose_map: Optional[Dict[str, float]],
                      rose_alpha: float,
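Dropping the str annotation reflects that full_prompt may now be either a decoded string or a list of Harmony token IDs. A hypothetical alias that would keep the signature honest (not in the commit, which simply drops the annotation):

    from typing import List, Union

    FullPrompt = Union[str, List[int]]  # hypothetical; token IDs on the Harmony path, str on the fallback path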
@@ -289,8 +290,16 @@ def zerogpu_generate(full_prompt: str,
         eff_alpha = float(rose_alpha) * (float(rose_score) if rose_score is not None else 1.0)
         logits_processor = [RoseGuidedLogits(bias, eff_alpha)]
 
-    # Tokenize
-
+    # Tokenize / prepare inputs
+    device = next(model.parameters()).device
+    if HARMONY_AVAILABLE and isinstance(full_prompt, list):
+        input_ids = torch.tensor([full_prompt], dtype=torch.long, device=device)
+        inputs = {"input_ids": input_ids}
+        prompt_len = input_ids.shape[1]
+    else:
+        enc = tokenizer(full_prompt, return_tensors="pt")
+        inputs = enc.to(device)
+        prompt_len = int(inputs["input_ids"].shape[1])
 
     # Generate
     out_ids = model.generate(
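Recording prompt_len in both branches matters because generate() returns the prompt and the completion as one sequence; a toy illustration of the slice performed below:

    # toy example: slicing off the prompt tokens after generate()
    out = [101, 102, 103, 7, 8]  # pretend ids: 3 prompt tokens + 2 new tokens
    prompt_len = 3
    gen_ids = out[prompt_len:]
    assert gen_ids == [7, 8]  # completion tokens only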
@@ -301,21 +310,33 @@ def zerogpu_generate(full_prompt: str,
         top_k=(int(gen_kwargs.get("top_k")) if gen_kwargs.get("top_k") and int(gen_kwargs.get("top_k")) > 0 else None),
         max_new_tokens=int(gen_kwargs.get("max_new_tokens", MAX_DEF)),
         pad_token_id=tokenizer.eos_token_id,
-        eos_token_id=tokenizer.eos_token_id,
+        eos_token_id=(HARMONY_STOP_IDS if HARMONY_AVAILABLE else tokenizer.eos_token_id),
         logits_processor=logits_processor,
     )
 
     # Extract generated tokens only
-
-    gen_ids =
+    out_list = out_ids[0].tolist()
+    gen_ids = out_list[prompt_len:]
+    # Truncate at first Harmony stop token if present
+    if HARMONY_AVAILABLE:
+        for sid in HARMONY_STOP_IDS:
+            if sid in gen_ids:
+                gen_ids = gen_ids[:gen_ids.index(sid)]
+                break
 
     # Parse response with Harmony
     if HARMONY_AVAILABLE:
-        channels = parse_harmony_response(gen_ids)
+        try:
+            channels = parse_harmony_response(gen_ids)
+        except Exception:
+            # Fallback to text parsing if Harmony parser fails
+            decoded = tokenizer.decode(gen_ids, skip_special_tokens=False)
+            channels = {
+                "final": extract_final_channel_fallback(decoded),
+                "raw": decoded
+            }
     else:
         # Fallback
-        decoded = tokenizer.decode(gen_ids, skip_special_tokens=False)
-        channels = {
             "final": extract_final_channel_fallback(decoded),
             "raw": decoded
         }
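parse_harmony_response itself is not shown in this hunk; presumably it wraps the library parser. A sketch of what such a parser might look like, assuming parsed Messages expose a channel attribute and text content items (my reading of openai_harmony's Python bindings, not code from this Space):

    # sketch: bucket completion tokens into channels via openai_harmony
    from openai_harmony import HarmonyEncodingName, Role, load_harmony_encoding

    enc = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)

    def channels_from_tokens(gen_ids):
        msgs = enc.parse_messages_from_completion_tokens(gen_ids, Role.ASSISTANT)
        out = {}
        for m in msgs:
            ch = m.channel or "final"  # analysis / commentary / final
            text = "".join(getattr(c, "text", "") for c in m.content)
            out[ch] = out.get(ch, "") + text
        return out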
@@ -367,9 +388,9 @@ def generate_response(message: str, history: List[List[str]], system_prompt: str
 
     # Create Harmony-formatted prompt
     if HARMONY_AVAILABLE:
-        prompt = create_harmony_prompt(messages, reasoning_effort)
+        prompt = create_harmony_prompt(messages, reasoning_effort)  # returns token IDs
     else:
-        # Fallback to tokenizer template
+        # Fallback to tokenizer template (string)
         prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
 
     # Build Rose map if enabled
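Putting the two changed call sites together: the Harmony path now passes token IDs end to end, while the fallback path passes a chat-template string. Hypothetical wiring (zerogpu_generate's trailing parameters are cut off in this diff, so they are elided here):

    # hypothetical end-to-end use of the new contract
    messages = [{"role": "user", "content": "hello"}]
    prompt = create_harmony_prompt(messages, reasoning_effort="high")
    # token IDs when Harmony is available, otherwise a string
    # channels = zerogpu_generate(prompt, {"max_new_tokens": 256}, None, 0.0, ...)  # remaining args elided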
requirements.txt CHANGED

@@ -4,4 +4,5 @@ accelerate>=0.33.0
 peft>=0.11.0
 torch>=2.4.0 # ZeroGPU-supported (2.3.x is NOT supported)
 bitsandbytes>=0.43.1
-openai_harmony
+openai_harmony
+gradio>=5.42.0