AbstractPhil commited on
Commit
02fd900
·
1 Parent(s): c709572

tweaks, maybe works?

Browse files
Files changed (1) hide show
  1. app.py +309 -50
app.py CHANGED
@@ -1,64 +1,323 @@
1
- import gradio as gr
2
- from huggingface_hub import InferenceClient
3
-
4
  """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
 
 
 
 
 
8
 
 
 
 
 
9
 
10
- def respond(
11
- message,
12
- history: list[tuple[str, str]],
13
- system_message,
14
- max_tokens,
15
- temperature,
16
- top_p,
17
- ):
18
- messages = [{"role": "system", "content": system_message}]
19
 
20
- for val in history:
21
- if val[0]:
22
- messages.append({"role": "user", "content": val[0]})
23
- if val[1]:
24
- messages.append({"role": "assistant", "content": val[1]})
 
 
 
25
 
26
- messages.append({"role": "user", "content": message})
27
 
28
- response = ""
 
29
 
30
- for message in client.chat_completion(
31
- messages,
32
- max_tokens=max_tokens,
33
- stream=True,
34
- temperature=temperature,
35
- top_p=top_p,
36
- ):
37
- token = message.choices[0].delta.content
38
 
39
- response += token
40
- yield response
 
 
 
41
 
42
 
43
- """
44
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
- """
46
- demo = gr.ChatInterface(
47
- respond,
48
- additional_inputs=[
49
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
50
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
51
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
52
- gr.Slider(
53
- minimum=0.1,
54
- maximum=1.0,
55
- value=0.95,
56
- step=0.05,
57
- label="Top-p (nucleus sampling)",
58
- ),
59
- ],
60
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
 
 
 
 
 
 
 
62
 
63
  if __name__ == "__main__":
64
- demo.launch()
 
 
 
 
1
  """
2
+ Mirel Harmony Inference HF Space (Gradio)
3
+ ZeroGPU-ready + Harmony formatting + optional Rose-guided decoding
4
+ Single-file app.py
5
+
6
+ Env (Spaces → Settings → Variables):
7
+ MODEL_ID : base or merged model (e.g., "openai/gpt-oss-20b" or path to merged)
8
+ ADAPTER_ID : optional PEFT repo/path (e.g., "AbstractPhil/mirel-gpt-oss-20b")
9
+ ADAPTER_SUBFOLDER : optional subfolder inside adapter repo (e.g., "checkpoints/checkpoint-516")
10
+ ZEROGPU : "1" to enable lazy load/unload (memory saver)
11
+ LOAD_4BIT : "1" to attempt 4-bit (bitsandbytes)
12
+ DTYPE : bf16 | fp16 | fp32 (default bf16)
13
+ ATTN_IMPL : eager | flash_attention_2 (default eager)
14
+ SYSTEM_PROMPT : default system message (default: "You are Mirel.")
15
+ MAX_NEW_TOKENS : default max new tokens (int, default 512)
16
+ HF_TOKEN : (optional) if you need private repo access
17
+
18
+ requirements.txt:
19
+ transformers>=4.43.0
20
+ accelerate>=0.33.0
21
+ peft>=0.11.0
22
+ gradio>=4.36.0
23
+ torch>=2.3.0
24
+ bitsandbytes>=0.43.1
25
  """
26
+ from __future__ import annotations
27
+ import os, gc, json, threading, torch
28
+ from dataclasses import dataclass
29
+ from typing import List, Dict, Optional, Any
30
+ import gradio as gr
31
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
32
 
33
+ # -----------------------
34
+ # Config & runtime modes
35
+ # -----------------------
36
+ DTYPE_MAP = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
37
 
38
+ MODEL_ID = os.getenv("MODEL_ID", "openai/gpt-oss-20b")
39
+ ADAPTER_ID = os.getenv("ADAPTER_ID") or None
40
+ ADAPTER_SUBFOLDER = os.getenv("ADAPTER_SUBFOLDER") or None
41
+ ATTN_IMPL = os.getenv("ATTN_IMPL", "eager")
42
+ DTYPE = DTYPE_MAP.get(os.getenv("DTYPE", "bf16").lower(), torch.bfloat16)
43
+ SYSTEM_DEF = os.getenv("SYSTEM_PROMPT", "You are Mirel.")
44
+ MAX_DEF = int(os.getenv("MAX_NEW_TOKENS", "512"))
45
+ ZEROGPU = os.getenv("ZEROGPU", os.getenv("ZERO_GPU", "0")) == "1"
46
+ LOAD_4BIT = os.getenv("LOAD_4BIT", "0") == "1"
47
 
48
+ # Optional: authenticate if HF_TOKEN provided (for private artifacts)
49
+ HF_TOKEN = os.getenv("HF_TOKEN")
50
+ if HF_TOKEN:
51
+ try:
52
+ from huggingface_hub import login
53
+ login(token=HF_TOKEN, add_to_git_credential=True)
54
+ except Exception:
55
+ pass
56
 
57
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
58
 
59
+ # Tokenizer is lightweight; load once
60
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
61
 
62
+ # -----------------------
63
+ # Lazy model loader (ZeroGPU-ready)
64
+ # -----------------------
65
+ _model = None
66
+ _model_lock = threading.Lock()
 
 
 
67
 
68
+ try:
69
+ from peft import PeftModel
70
+ _HAS_PEFT = True
71
+ except Exception:
72
+ _HAS_PEFT = False
73
 
74
 
75
+ def _load_model() -> AutoModelForCausalLM:
76
+ """Load model (and adapter if provided). In ZEROGPU mode, place on cuda if available, else cpu."""
77
+ global _model
78
+ if _model is not None:
79
+ return _model
80
+
81
+ kwargs: Dict[str, Any] = dict(
82
+ torch_dtype=DTYPE,
83
+ device_map=None if ZEROGPU else "auto",
84
+ attn_implementation=ATTN_IMPL,
85
+ trust_remote_code=True,
86
+ )
87
+ if LOAD_4BIT:
88
+ try:
89
+ import bitsandbytes as bnb # noqa: F401
90
+ kwargs.update(load_in_4bit=True)
91
+ except Exception:
92
+ pass
93
+
94
+ model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **kwargs)
95
+
96
+ if ADAPTER_ID:
97
+ if not _HAS_PEFT:
98
+ raise RuntimeError("peft is required when ADAPTER_ID is set.")
99
+ peft_kwargs = dict()
100
+ if ADAPTER_SUBFOLDER:
101
+ peft_kwargs["subfolder"] = ADAPTER_SUBFOLDER
102
+ model = PeftModel.from_pretrained(model, ADAPTER_ID, is_trainable=False, **peft_kwargs)
103
+
104
+ model.eval()
105
+ model.config.use_cache = True
106
+
107
+ # In ZeroGPU we control placement explicitly
108
+ if ZEROGPU:
109
+ if torch.cuda.is_available():
110
+ model = model.to("cuda")
111
+ else:
112
+ model = model.to("cpu")
113
+
114
+ _model = model
115
+ return _model
116
+
117
+
118
+ def _unload_model_if_zerogpu():
119
+ """Aggressive unload to cooperate with ZeroGPU/limited VRAM."""
120
+ global _model
121
+ if not ZEROGPU:
122
+ return
123
+ try:
124
+ if _model is not None:
125
+ _model.to("cpu")
126
+ del _model
127
+ except Exception:
128
+ pass
129
+ _model = None
130
+ gc.collect()
131
+ if torch.cuda.is_available():
132
+ torch.cuda.empty_cache()
133
+
134
+ # -----------------------
135
+ # Harmony formatting
136
+ # -----------------------
137
+
138
+ def to_harmony_prompt(messages: List[Dict[str, str]]) -> str:
139
+ """Prefer tokenizer.chat_template; fallback to minimal Harmony-like format."""
140
+ tmpl = getattr(tokenizer, "chat_template", None)
141
+ if tmpl:
142
+ return tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
143
+ sys_txt = ""
144
+ if messages and messages[0]["role"] == "system":
145
+ sys_txt = "<<SYS>>\n" + messages[0]["content"] + "\n<</SYS>>\n\n"
146
+ messages = messages[1:]
147
+ convo = []
148
+ for m in messages:
149
+ if m["role"] == "user":
150
+ convo.append("<|user|>\n" + m["content"] + "\n<|end|>")
151
+ elif m["role"] == "assistant":
152
+ convo.append("<|assistant|>\n" + m["content"] + "\n<|end|>")
153
+ return sys_txt + "\n".join(convo) + "\n<|assistant|>\n"
154
+
155
+ # -----------------------
156
+ # Optional Rose guidance (logits bias)
157
+ # -----------------------
158
+
159
+ def build_bias_from_tokens(tokenizer, mapping: Dict[str, float]) -> torch.Tensor:
160
+ """Create vocab bias from {token: weight}. Unknown tokens ignored. Positive promotes, negative demotes."""
161
+ vocab_size = len(tokenizer)
162
+ bias = torch.zeros(vocab_size, dtype=torch.float32)
163
+ for tok, w in mapping.items():
164
+ if tok is None:
165
+ continue
166
+ tid = tokenizer.convert_tokens_to_ids(tok)
167
+ if isinstance(tid, list):
168
+ for t in tid:
169
+ if isinstance(t, int) and t >= 0:
170
+ bias[t] += float(w) / max(1, len(tid))
171
+ elif isinstance(tid, int) and tid >= 0:
172
+ bias[tid] += float(w)
173
+ return bias
174
+
175
+ class RoseGuidedLogits(torch.nn.Module):
176
+ def __init__(self, bias_vec: torch.Tensor, alpha: float = 1.0):
177
+ super().__init__()
178
+ self.bias_vec = bias_vec
179
+ self.alpha = float(alpha)
180
+ def forward(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
181
+ return scores + self.alpha * self.bias_vec.to(scores.device)
182
+
183
+ # -----------------------
184
+ # Gradio handlers
185
+ # -----------------------
186
+ @dataclass
187
+ class GenCfg:
188
+ temperature: float
189
+ top_p: float
190
+ top_k: int
191
+ max_new_tokens: int
192
+ do_sample: bool
193
+ seed: Optional[int]
194
+
195
+
196
+ def chat_to_messages(history: List[List[str]], system_prompt: str) -> List[Dict[str, str]]:
197
+ msgs: List[Dict[str, str]] = [{"role": "system", "content": system_prompt or SYSTEM_DEF}]
198
+ for u, a in history:
199
+ if u is not None:
200
+ msgs.append({"role": "user", "content": u})
201
+ if a:
202
+ msgs.append({"role": "assistant", "content": a})
203
+ return msgs
204
+
205
+
206
+ def generate_stream(message: str, history: List[List[str]], system_prompt: str,
207
+ temperature: float, top_p: float, top_k: int, max_new_tokens: int,
208
+ do_sample: bool, seed: int | None,
209
+ rose_enable: bool, rose_alpha: float, rose_tokens: str, rose_json: str):
210
+ cfg = GenCfg(temperature, top_p, top_k, max_new_tokens, do_sample, seed)
211
+ if cfg.seed is not None:
212
+ torch.manual_seed(int(cfg.seed))
213
+
214
+ msgs = chat_to_messages(history, system_prompt)
215
+ msgs.append({"role": "user", "content": message})
216
+ prompt = to_harmony_prompt(msgs)
217
+
218
+ # Lazy load
219
+ global _model
220
+ with _model_lock:
221
+ _model = _load_model()
222
+
223
+ inputs = tokenizer(prompt, return_tensors="pt").to(_model.device)
224
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
225
+
226
+ # Optional logits processor
227
+ logits_processor = None
228
+ if rose_enable:
229
+ token_map: Dict[str, float] = {}
230
+ rose_tokens = (rose_tokens or "").strip()
231
+ if rose_tokens:
232
+ # "token:weight, token2:weight"
233
+ parts = [p.strip() for p in rose_tokens.split(",") if p.strip()]
234
+ for p in parts:
235
+ if ":" in p:
236
+ k, v = p.split(":", 1)
237
+ try:
238
+ token_map[k.strip()] = float(v)
239
+ except Exception:
240
+ pass
241
+ if rose_json:
242
+ try:
243
+ j = json.loads(rose_json)
244
+ if isinstance(j, dict):
245
+ for k, v in j.items():
246
+ try:
247
+ token_map[str(k)] = float(v)
248
+ except Exception:
249
+ pass
250
+ except Exception:
251
+ pass
252
+ if token_map:
253
+ bias = build_bias_from_tokens(tokenizer, token_map).to(_model.device)
254
+ logits_processor = [RoseGuidedLogits(bias, rose_alpha)]
255
+
256
+ gen_kwargs = dict(
257
+ **inputs,
258
+ do_sample=cfg.do_sample,
259
+ temperature=cfg.temperature,
260
+ top_p=cfg.top_p,
261
+ top_k=cfg.top_k if cfg.top_k > 0 else None,
262
+ max_new_tokens=cfg.max_new_tokens,
263
+ pad_token_id=tokenizer.eos_token_id,
264
+ streamer=streamer,
265
+ logits_processor=logits_processor,
266
+ )
267
+
268
+ thread = threading.Thread(target=_model.generate, kwargs=gen_kwargs)
269
+ thread.start()
270
+ partial = ""
271
+ for token in streamer:
272
+ partial += token
273
+ yield partial
274
+
275
+ if ZEROGPU:
276
+ with _model_lock:
277
+ _unload_model_if_zerogpu()
278
+
279
+ # -----------------------
280
+ # UI
281
+ # -----------------------
282
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
283
+ gr.Markdown("""
284
+ # Mirel – Harmony Inference (ZeroGPU‑ready)
285
+ OSS‑20B + optional Rose‑SFT adapter. Harmony chat template is applied automatically.
286
+ """)
287
+
288
+ with gr.Row():
289
+ system_prompt = gr.Textbox(label="System", value=SYSTEM_DEF)
290
+ with gr.Accordion("Generation settings", open=False):
291
+ with gr.Row():
292
+ temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="temperature")
293
+ top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.01, label="top_p")
294
+ top_k = gr.Slider(0, 200, value=0, step=1, label="top_k (0=off)")
295
+ max_new = gr.Slider(16, 2048, value=MAX_DEF, step=8, label="max_new_tokens")
296
+ do_sample = gr.Checkbox(value=True, label="do_sample")
297
+ seed = gr.Number(value=None, label="seed (optional)")
298
+ with gr.Accordion("Rose guidance (optional)", open=False):
299
+ with gr.Row():
300
+ rose_enable = gr.Checkbox(value=False, label="Enable Rose bias at decode")
301
+ rose_alpha = gr.Slider(0.0, 5.0, value=1.0, step=0.05, label="rose alpha (strength)")
302
+ rose_tokens = gr.Textbox(label="token:weight list (comma-separated)", value="")
303
+ rose_json = gr.Textbox(label="JSON {token: weight}", value="")
304
+
305
+ chat = gr.ChatInterface(
306
+ fn=generate_stream,
307
+ chatbot=gr.Chatbot(show_copy_button=True, likeable=True, render_markdown=True),
308
+ additional_inputs=[system_prompt, temperature, top_p, top_k, max_new, do_sample, seed, rose_enable, rose_alpha, rose_tokens, rose_json],
309
+ title="Mirel",
310
+ concurrency_limit=2 if ZEROGPU else 4,
311
+ cache_examples=False,
312
+ )
313
 
314
+ gr.Markdown("""
315
+ **Notes**
316
+ - Set env `ZEROGPU=1` to enable just‑in‑time load and aggressive unload per request.
317
+ - Set `ADAPTER_ID=AbstractPhil/mirel-gpt-oss-20b` and `ADAPTER_SUBFOLDER=checkpoints/checkpoint-516` to use the provided adapter.
318
+ - For large contexts on A100/H100 prefer `DTYPE=bf16` and `ATTN_IMPL=eager` unless FA2 is installed.
319
+ - Rose guidance is optional; it biases logits without changing model weights.
320
+ """)
321
 
322
  if __name__ == "__main__":
323
+ demo.queue(max_size=8 if ZEROGPU else 32, concurrency_count=2 if ZEROGPU else 4).launch(server_name="0.0.0.0", server_port=7860)