Spaces:
Running
on
Zero
Running
on
Zero
AbstractPhil
commited on
Commit
·
02fd900
1
Parent(s):
c709572
tweaks, maybe works?
Browse files
app.py
CHANGED
@@ -1,64 +1,323 @@
|
|
1 |
-
import gradio as gr
|
2 |
-
from huggingface_hub import InferenceClient
|
3 |
-
|
4 |
"""
|
5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
"""
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
8 |
|
|
|
|
|
|
|
|
|
9 |
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
)
|
18 |
-
|
19 |
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
|
|
|
|
|
|
25 |
|
26 |
-
|
27 |
|
28 |
-
|
|
|
29 |
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
top_p=top_p,
|
36 |
-
):
|
37 |
-
token = message.choices[0].delta.content
|
38 |
|
39 |
-
|
40 |
-
|
|
|
|
|
|
|
41 |
|
42 |
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
|
63 |
if __name__ == "__main__":
|
64 |
-
demo.launch()
|
|
|
|
|
|
|
|
|
1 |
"""
|
2 |
+
Mirel Harmony Inference – HF Space (Gradio)
|
3 |
+
ZeroGPU-ready + Harmony formatting + optional Rose-guided decoding
|
4 |
+
Single-file app.py
|
5 |
+
|
6 |
+
Env (Spaces → Settings → Variables):
|
7 |
+
MODEL_ID : base or merged model (e.g., "openai/gpt-oss-20b" or path to merged)
|
8 |
+
ADAPTER_ID : optional PEFT repo/path (e.g., "AbstractPhil/mirel-gpt-oss-20b")
|
9 |
+
ADAPTER_SUBFOLDER : optional subfolder inside adapter repo (e.g., "checkpoints/checkpoint-516")
|
10 |
+
ZEROGPU : "1" to enable lazy load/unload (memory saver)
|
11 |
+
LOAD_4BIT : "1" to attempt 4-bit (bitsandbytes)
|
12 |
+
DTYPE : bf16 | fp16 | fp32 (default bf16)
|
13 |
+
ATTN_IMPL : eager | flash_attention_2 (default eager)
|
14 |
+
SYSTEM_PROMPT : default system message (default: "You are Mirel.")
|
15 |
+
MAX_NEW_TOKENS : default max new tokens (int, default 512)
|
16 |
+
HF_TOKEN : (optional) if you need private repo access
|
17 |
+
|
18 |
+
requirements.txt:
|
19 |
+
transformers>=4.43.0
|
20 |
+
accelerate>=0.33.0
|
21 |
+
peft>=0.11.0
|
22 |
+
gradio>=4.36.0
|
23 |
+
torch>=2.3.0
|
24 |
+
bitsandbytes>=0.43.1
|
25 |
"""
|
26 |
+
from __future__ import annotations
|
27 |
+
import os, gc, json, threading, torch
|
28 |
+
from dataclasses import dataclass
|
29 |
+
from typing import List, Dict, Optional, Any
|
30 |
+
import gradio as gr
|
31 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
|
32 |
|
33 |
+
# -----------------------
|
34 |
+
# Config & runtime modes
|
35 |
+
# -----------------------
|
36 |
+
DTYPE_MAP = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
|
37 |
|
38 |
+
MODEL_ID = os.getenv("MODEL_ID", "openai/gpt-oss-20b")
|
39 |
+
ADAPTER_ID = os.getenv("ADAPTER_ID") or None
|
40 |
+
ADAPTER_SUBFOLDER = os.getenv("ADAPTER_SUBFOLDER") or None
|
41 |
+
ATTN_IMPL = os.getenv("ATTN_IMPL", "eager")
|
42 |
+
DTYPE = DTYPE_MAP.get(os.getenv("DTYPE", "bf16").lower(), torch.bfloat16)
|
43 |
+
SYSTEM_DEF = os.getenv("SYSTEM_PROMPT", "You are Mirel.")
|
44 |
+
MAX_DEF = int(os.getenv("MAX_NEW_TOKENS", "512"))
|
45 |
+
ZEROGPU = os.getenv("ZEROGPU", os.getenv("ZERO_GPU", "0")) == "1"
|
46 |
+
LOAD_4BIT = os.getenv("LOAD_4BIT", "0") == "1"
|
47 |
|
48 |
+
# Optional: authenticate if HF_TOKEN provided (for private artifacts)
|
49 |
+
HF_TOKEN = os.getenv("HF_TOKEN")
|
50 |
+
if HF_TOKEN:
|
51 |
+
try:
|
52 |
+
from huggingface_hub import login
|
53 |
+
login(token=HF_TOKEN, add_to_git_credential=True)
|
54 |
+
except Exception:
|
55 |
+
pass
|
56 |
|
57 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
58 |
|
59 |
+
# Tokenizer is lightweight; load once
|
60 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
|
61 |
|
62 |
+
# -----------------------
|
63 |
+
# Lazy model loader (ZeroGPU-ready)
|
64 |
+
# -----------------------
|
65 |
+
_model = None
|
66 |
+
_model_lock = threading.Lock()
|
|
|
|
|
|
|
67 |
|
68 |
+
try:
|
69 |
+
from peft import PeftModel
|
70 |
+
_HAS_PEFT = True
|
71 |
+
except Exception:
|
72 |
+
_HAS_PEFT = False
|
73 |
|
74 |
|
75 |
+
def _load_model() -> AutoModelForCausalLM:
|
76 |
+
"""Load model (and adapter if provided). In ZEROGPU mode, place on cuda if available, else cpu."""
|
77 |
+
global _model
|
78 |
+
if _model is not None:
|
79 |
+
return _model
|
80 |
+
|
81 |
+
kwargs: Dict[str, Any] = dict(
|
82 |
+
torch_dtype=DTYPE,
|
83 |
+
device_map=None if ZEROGPU else "auto",
|
84 |
+
attn_implementation=ATTN_IMPL,
|
85 |
+
trust_remote_code=True,
|
86 |
+
)
|
87 |
+
if LOAD_4BIT:
|
88 |
+
try:
|
89 |
+
import bitsandbytes as bnb # noqa: F401
|
90 |
+
kwargs.update(load_in_4bit=True)
|
91 |
+
except Exception:
|
92 |
+
pass
|
93 |
+
|
94 |
+
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **kwargs)
|
95 |
+
|
96 |
+
if ADAPTER_ID:
|
97 |
+
if not _HAS_PEFT:
|
98 |
+
raise RuntimeError("peft is required when ADAPTER_ID is set.")
|
99 |
+
peft_kwargs = dict()
|
100 |
+
if ADAPTER_SUBFOLDER:
|
101 |
+
peft_kwargs["subfolder"] = ADAPTER_SUBFOLDER
|
102 |
+
model = PeftModel.from_pretrained(model, ADAPTER_ID, is_trainable=False, **peft_kwargs)
|
103 |
+
|
104 |
+
model.eval()
|
105 |
+
model.config.use_cache = True
|
106 |
+
|
107 |
+
# In ZeroGPU we control placement explicitly
|
108 |
+
if ZEROGPU:
|
109 |
+
if torch.cuda.is_available():
|
110 |
+
model = model.to("cuda")
|
111 |
+
else:
|
112 |
+
model = model.to("cpu")
|
113 |
+
|
114 |
+
_model = model
|
115 |
+
return _model
|
116 |
+
|
117 |
+
|
118 |
+
def _unload_model_if_zerogpu():
|
119 |
+
"""Aggressive unload to cooperate with ZeroGPU/limited VRAM."""
|
120 |
+
global _model
|
121 |
+
if not ZEROGPU:
|
122 |
+
return
|
123 |
+
try:
|
124 |
+
if _model is not None:
|
125 |
+
_model.to("cpu")
|
126 |
+
del _model
|
127 |
+
except Exception:
|
128 |
+
pass
|
129 |
+
_model = None
|
130 |
+
gc.collect()
|
131 |
+
if torch.cuda.is_available():
|
132 |
+
torch.cuda.empty_cache()
|
133 |
+
|
134 |
+
# -----------------------
|
135 |
+
# Harmony formatting
|
136 |
+
# -----------------------
|
137 |
+
|
138 |
+
def to_harmony_prompt(messages: List[Dict[str, str]]) -> str:
|
139 |
+
"""Prefer tokenizer.chat_template; fallback to minimal Harmony-like format."""
|
140 |
+
tmpl = getattr(tokenizer, "chat_template", None)
|
141 |
+
if tmpl:
|
142 |
+
return tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
|
143 |
+
sys_txt = ""
|
144 |
+
if messages and messages[0]["role"] == "system":
|
145 |
+
sys_txt = "<<SYS>>\n" + messages[0]["content"] + "\n<</SYS>>\n\n"
|
146 |
+
messages = messages[1:]
|
147 |
+
convo = []
|
148 |
+
for m in messages:
|
149 |
+
if m["role"] == "user":
|
150 |
+
convo.append("<|user|>\n" + m["content"] + "\n<|end|>")
|
151 |
+
elif m["role"] == "assistant":
|
152 |
+
convo.append("<|assistant|>\n" + m["content"] + "\n<|end|>")
|
153 |
+
return sys_txt + "\n".join(convo) + "\n<|assistant|>\n"
|
154 |
+
|
155 |
+
# -----------------------
|
156 |
+
# Optional Rose guidance (logits bias)
|
157 |
+
# -----------------------
|
158 |
+
|
159 |
+
def build_bias_from_tokens(tokenizer, mapping: Dict[str, float]) -> torch.Tensor:
|
160 |
+
"""Create vocab bias from {token: weight}. Unknown tokens ignored. Positive promotes, negative demotes."""
|
161 |
+
vocab_size = len(tokenizer)
|
162 |
+
bias = torch.zeros(vocab_size, dtype=torch.float32)
|
163 |
+
for tok, w in mapping.items():
|
164 |
+
if tok is None:
|
165 |
+
continue
|
166 |
+
tid = tokenizer.convert_tokens_to_ids(tok)
|
167 |
+
if isinstance(tid, list):
|
168 |
+
for t in tid:
|
169 |
+
if isinstance(t, int) and t >= 0:
|
170 |
+
bias[t] += float(w) / max(1, len(tid))
|
171 |
+
elif isinstance(tid, int) and tid >= 0:
|
172 |
+
bias[tid] += float(w)
|
173 |
+
return bias
|
174 |
+
|
175 |
+
class RoseGuidedLogits(torch.nn.Module):
|
176 |
+
def __init__(self, bias_vec: torch.Tensor, alpha: float = 1.0):
|
177 |
+
super().__init__()
|
178 |
+
self.bias_vec = bias_vec
|
179 |
+
self.alpha = float(alpha)
|
180 |
+
def forward(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
|
181 |
+
return scores + self.alpha * self.bias_vec.to(scores.device)
|
182 |
+
|
183 |
+
# -----------------------
|
184 |
+
# Gradio handlers
|
185 |
+
# -----------------------
|
186 |
+
@dataclass
|
187 |
+
class GenCfg:
|
188 |
+
temperature: float
|
189 |
+
top_p: float
|
190 |
+
top_k: int
|
191 |
+
max_new_tokens: int
|
192 |
+
do_sample: bool
|
193 |
+
seed: Optional[int]
|
194 |
+
|
195 |
+
|
196 |
+
def chat_to_messages(history: List[List[str]], system_prompt: str) -> List[Dict[str, str]]:
|
197 |
+
msgs: List[Dict[str, str]] = [{"role": "system", "content": system_prompt or SYSTEM_DEF}]
|
198 |
+
for u, a in history:
|
199 |
+
if u is not None:
|
200 |
+
msgs.append({"role": "user", "content": u})
|
201 |
+
if a:
|
202 |
+
msgs.append({"role": "assistant", "content": a})
|
203 |
+
return msgs
|
204 |
+
|
205 |
+
|
206 |
+
def generate_stream(message: str, history: List[List[str]], system_prompt: str,
|
207 |
+
temperature: float, top_p: float, top_k: int, max_new_tokens: int,
|
208 |
+
do_sample: bool, seed: int | None,
|
209 |
+
rose_enable: bool, rose_alpha: float, rose_tokens: str, rose_json: str):
|
210 |
+
cfg = GenCfg(temperature, top_p, top_k, max_new_tokens, do_sample, seed)
|
211 |
+
if cfg.seed is not None:
|
212 |
+
torch.manual_seed(int(cfg.seed))
|
213 |
+
|
214 |
+
msgs = chat_to_messages(history, system_prompt)
|
215 |
+
msgs.append({"role": "user", "content": message})
|
216 |
+
prompt = to_harmony_prompt(msgs)
|
217 |
+
|
218 |
+
# Lazy load
|
219 |
+
global _model
|
220 |
+
with _model_lock:
|
221 |
+
_model = _load_model()
|
222 |
+
|
223 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(_model.device)
|
224 |
+
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
|
225 |
+
|
226 |
+
# Optional logits processor
|
227 |
+
logits_processor = None
|
228 |
+
if rose_enable:
|
229 |
+
token_map: Dict[str, float] = {}
|
230 |
+
rose_tokens = (rose_tokens or "").strip()
|
231 |
+
if rose_tokens:
|
232 |
+
# "token:weight, token2:weight"
|
233 |
+
parts = [p.strip() for p in rose_tokens.split(",") if p.strip()]
|
234 |
+
for p in parts:
|
235 |
+
if ":" in p:
|
236 |
+
k, v = p.split(":", 1)
|
237 |
+
try:
|
238 |
+
token_map[k.strip()] = float(v)
|
239 |
+
except Exception:
|
240 |
+
pass
|
241 |
+
if rose_json:
|
242 |
+
try:
|
243 |
+
j = json.loads(rose_json)
|
244 |
+
if isinstance(j, dict):
|
245 |
+
for k, v in j.items():
|
246 |
+
try:
|
247 |
+
token_map[str(k)] = float(v)
|
248 |
+
except Exception:
|
249 |
+
pass
|
250 |
+
except Exception:
|
251 |
+
pass
|
252 |
+
if token_map:
|
253 |
+
bias = build_bias_from_tokens(tokenizer, token_map).to(_model.device)
|
254 |
+
logits_processor = [RoseGuidedLogits(bias, rose_alpha)]
|
255 |
+
|
256 |
+
gen_kwargs = dict(
|
257 |
+
**inputs,
|
258 |
+
do_sample=cfg.do_sample,
|
259 |
+
temperature=cfg.temperature,
|
260 |
+
top_p=cfg.top_p,
|
261 |
+
top_k=cfg.top_k if cfg.top_k > 0 else None,
|
262 |
+
max_new_tokens=cfg.max_new_tokens,
|
263 |
+
pad_token_id=tokenizer.eos_token_id,
|
264 |
+
streamer=streamer,
|
265 |
+
logits_processor=logits_processor,
|
266 |
+
)
|
267 |
+
|
268 |
+
thread = threading.Thread(target=_model.generate, kwargs=gen_kwargs)
|
269 |
+
thread.start()
|
270 |
+
partial = ""
|
271 |
+
for token in streamer:
|
272 |
+
partial += token
|
273 |
+
yield partial
|
274 |
+
|
275 |
+
if ZEROGPU:
|
276 |
+
with _model_lock:
|
277 |
+
_unload_model_if_zerogpu()
|
278 |
+
|
279 |
+
# -----------------------
|
280 |
+
# UI
|
281 |
+
# -----------------------
|
282 |
+
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
283 |
+
gr.Markdown("""
|
284 |
+
# Mirel – Harmony Inference (ZeroGPU‑ready)
|
285 |
+
OSS‑20B + optional Rose‑SFT adapter. Harmony chat template is applied automatically.
|
286 |
+
""")
|
287 |
+
|
288 |
+
with gr.Row():
|
289 |
+
system_prompt = gr.Textbox(label="System", value=SYSTEM_DEF)
|
290 |
+
with gr.Accordion("Generation settings", open=False):
|
291 |
+
with gr.Row():
|
292 |
+
temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="temperature")
|
293 |
+
top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.01, label="top_p")
|
294 |
+
top_k = gr.Slider(0, 200, value=0, step=1, label="top_k (0=off)")
|
295 |
+
max_new = gr.Slider(16, 2048, value=MAX_DEF, step=8, label="max_new_tokens")
|
296 |
+
do_sample = gr.Checkbox(value=True, label="do_sample")
|
297 |
+
seed = gr.Number(value=None, label="seed (optional)")
|
298 |
+
with gr.Accordion("Rose guidance (optional)", open=False):
|
299 |
+
with gr.Row():
|
300 |
+
rose_enable = gr.Checkbox(value=False, label="Enable Rose bias at decode")
|
301 |
+
rose_alpha = gr.Slider(0.0, 5.0, value=1.0, step=0.05, label="rose alpha (strength)")
|
302 |
+
rose_tokens = gr.Textbox(label="token:weight list (comma-separated)", value="")
|
303 |
+
rose_json = gr.Textbox(label="JSON {token: weight}", value="")
|
304 |
+
|
305 |
+
chat = gr.ChatInterface(
|
306 |
+
fn=generate_stream,
|
307 |
+
chatbot=gr.Chatbot(show_copy_button=True, likeable=True, render_markdown=True),
|
308 |
+
additional_inputs=[system_prompt, temperature, top_p, top_k, max_new, do_sample, seed, rose_enable, rose_alpha, rose_tokens, rose_json],
|
309 |
+
title="Mirel",
|
310 |
+
concurrency_limit=2 if ZEROGPU else 4,
|
311 |
+
cache_examples=False,
|
312 |
+
)
|
313 |
|
314 |
+
gr.Markdown("""
|
315 |
+
**Notes**
|
316 |
+
- Set env `ZEROGPU=1` to enable just‑in‑time load and aggressive unload per request.
|
317 |
+
- Set `ADAPTER_ID=AbstractPhil/mirel-gpt-oss-20b` and `ADAPTER_SUBFOLDER=checkpoints/checkpoint-516` to use the provided adapter.
|
318 |
+
- For large contexts on A100/H100 prefer `DTYPE=bf16` and `ATTN_IMPL=eager` unless FA2 is installed.
|
319 |
+
- Rose guidance is optional; it biases logits without changing model weights.
|
320 |
+
""")
|
321 |
|
322 |
if __name__ == "__main__":
|
323 |
+
demo.queue(max_size=8 if ZEROGPU else 32, concurrency_count=2 if ZEROGPU else 4).launch(server_name="0.0.0.0", server_port=7860)
|