AbstractPhil committed · Commit 6eb225b · 1 Parent(s): 3efceb8 · "yes"
app.py CHANGED
@@ -11,7 +11,7 @@ from typing import List, Dict, Optional, Any
 from datetime import datetime
 import gradio as gr
 import spaces  # required for ZeroGPU
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
 
 # Import Harmony components
 try:
@@ -47,7 +47,7 @@ ZEROGPU = os.getenv("ZEROGPU", os.getenv("ZERO_GPU", "0")) == "1"
 LOAD_4BIT = os.getenv("LOAD_4BIT", "0") == "1"
 
 # Harmony channels for CoT
-REQUIRED_CHANNELS = ["analysis", "
+REQUIRED_CHANNELS = ["analysis", "final"]
 
 # HF Auth - properly handle multiple token env var names
 HF_TOKEN: Optional[str] = (
@@ -286,6 +286,12 @@ class RoseGuidedLogits(torch.nn.Module):
     def forward(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         return scores + self.alpha * self.bias_vec.to(scores.device)
 
+class StopOnTokens(StoppingCriteria):
+    def __init__(self, stop_ids: List[int]):
+        self.stop_ids = set(int(s) for s in (stop_ids or []))
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
+        return int(input_ids[0, -1]) in self.stop_ids
+
 @spaces.GPU(duration=120)
 def zerogpu_generate(full_prompt,
                      gen_kwargs: Dict[str, Any],
@@ -310,18 +316,30 @@ def zerogpu_generate(full_prompt,
 
     # Tokenize / prepare inputs
     device = next(model.parameters()).device
-    if HARMONY_AVAILABLE and isinstance(full_prompt,
-
+    if HARMONY_AVAILABLE and not isinstance(full_prompt, str):
+        # Accept list/tuple or any iterable of ints from openai_harmony
+        try:
+            token_list = list(full_prompt)
+        except TypeError:
+            token_list = list(getattr(full_prompt, "ids", getattr(full_prompt, "token_ids", [])))
+        if not token_list:
+            raise ValueError("Harmony prompt produced no tokens")
+        input_ids = torch.tensor([token_list], dtype=torch.long, device=device)
         attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=device)
         inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
         prompt_len = input_ids.shape[1]
     else:
         enc = tokenizer(full_prompt, return_tensors="pt")
-        inputs =
+        inputs = {k: v.to(device) for k, v in enc.items()}
         prompt_len = int(inputs["input_ids"].shape[1])
-        # Guarantee attention_mask exists; avoids pad==eos ambiguity warnings
         if "attention_mask" not in inputs:
             inputs["attention_mask"] = torch.ones_like(inputs["input_ids"], dtype=torch.long, device=device)
+
+    # Prepare stopping
+    sc = None
+    if HARMONY_AVAILABLE and HARMONY_STOP_IDS:
+        sc = StoppingCriteriaList([StopOnTokens(HARMONY_STOP_IDS)])
+
     # Generate
     # Build EOS list: use ONLY Harmony assistant-action stops (per OpenAI docs)
     eos_ids = HARMONY_STOP_IDS if HARMONY_AVAILABLE else tokenizer.eos_token_id
@@ -334,11 +352,10 @@ def zerogpu_generate(full_prompt,
         top_k=(int(gen_kwargs.get("top_k")) if gen_kwargs.get("top_k") and int(gen_kwargs.get("top_k")) > 0 else None),
         max_new_tokens=int(gen_kwargs.get("max_new_tokens", MAX_DEF)),
         pad_token_id=model.config.pad_token_id,
-        eos_token_id=eos_ids,
         logits_processor=logits_processor,
-        repetition_penalty=float(gen_kwargs.get("repetition_penalty", 1.
-        no_repeat_ngram_size=int(gen_kwargs.get("no_repeat_ngram_size",
-
+        repetition_penalty=float(gen_kwargs.get("repetition_penalty", 1.2)),
+        no_repeat_ngram_size=int(gen_kwargs.get("no_repeat_ngram_size", 8)),
+        stopping_criteria=sc,
     )
 
     # Extract generated tokens only
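For reference, a minimal sketch of what the newly added stopping path does. Only the StopOnTokens class below mirrors the code added in this commit; the generate() wiring is shown as comments and assumes model, inputs, and HARMONY_STOP_IDS exist as they do elsewhere in app.py, and the openai_harmony call named at the end is an assumption inferred from the "assistant-action stops" comment in the diff.

# Minimal sketch of the StopOnTokens behaviour added in this commit.
import torch
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnTokens(StoppingCriteria):
    # Same logic as the class added in this commit: stop as soon as the
    # last sampled token id is one of the configured stop ids.
    def __init__(self, stop_ids):
        self.stop_ids = set(int(s) for s in (stop_ids or []))

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
        return int(input_ids[0, -1]) in self.stop_ids

# Behaviour check with dummy ids; 200002 stands in for a Harmony stop token.
crit = StopOnTokens([200002])
print(crit(torch.tensor([[1, 2, 200002]]), scores=None))  # True  -> generation halts
print(crit(torch.tensor([[1, 2, 3]]), scores=None))       # False -> generation continues

# Wiring into generate(), as the commit does (sketch; objects assumed loaded):
# sc = StoppingCriteriaList([StopOnTokens(HARMONY_STOP_IDS)])
# out = model.generate(**inputs, stopping_criteria=sc, max_new_tokens=64)
#
# HARMONY_STOP_IDS is defined elsewhere in app.py; per the "assistant-action
# stops" comment it presumably comes from openai_harmony, e.g. (assumption):
# load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS).stop_tokens_for_assistant_actions()

Using an explicit StoppingCriteria alongside (or instead of) eos_token_id keeps decoding from running past a Harmony stop token even when the model's pad/eos configuration is ambiguous, which matches the intent of the changes in this hunk.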