AbstractPhil committed on
Commit ec38870 · 1 Parent(s): 01d8622

refactor smaller

Files changed (1)
  1. app.py +301 -457
app.py CHANGED
@@ -1,532 +1,376 @@
1
  """
2
  Mirel Harmony Inference – HF Space (Gradio)
3
- ZeroGPU-ready, Harmony formatting, optional Rose-guided decoding
4
- Chain-of-thought model with proper channel extraction using openai_harmony
5
- Single file: app.py
6
  """
7
- from __future__ import annotations
8
- import os, gc, json, threading, torch
9
- from dataclasses import dataclass
10
- from typing import List, Dict, Optional, Any
11
- from datetime import datetime
12
  import gradio as gr
13
- import spaces # required for ZeroGPU
14
  from transformers import AutoTokenizer, AutoModelForCausalLM
15
 
16
- # Import Harmony components
17
  try:
18
- from openai_harmony import (
19
- Author,
20
- Conversation,
21
- HarmonyEncodingName,
22
- Message,
23
- Role,
24
- SystemContent,
25
- DeveloperContent,
26
- load_harmony_encoding,
27
- ReasoningEffort
28
- )
29
- HARMONY_AVAILABLE = True
30
  except ImportError:
31
- print("[WARNING] openai_harmony not installed. Install with: pip install openai-harmony")
32
- HARMONY_AVAILABLE = False
33
 
34
  # -----------------------
35
- # Config & runtime modes
36
  # -----------------------
37
- DTYPE_MAP = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
38
-
39
- MODEL_ID = os.getenv("MODEL_ID", "openai/gpt-oss-20b")
40
- ADAPTER_ID = os.getenv("ADAPTER_ID") or None
41
- ADAPTER_SUBFOLDER = os.getenv("ADAPTER_SUBFOLDER") or None
42
- ATTN_IMPL = os.getenv("ATTN_IMPL", "eager")
43
- DTYPE = DTYPE_MAP.get(os.getenv("DTYPE", "bf16").lower(), torch.bfloat16)
44
- SYSTEM_DEF = os.getenv("SYSTEM_PROMPT", "You are Mirel, a memory-stable symbolic assistant.")
45
- MAX_DEF = int(os.getenv("MAX_NEW_TOKENS", "1024"))
46
- ZEROGPU = os.getenv("ZEROGPU", os.getenv("ZERO_GPU", "0")) == "1"
47
- LOAD_4BIT = os.getenv("LOAD_4BIT", "0") == "1"
48
-
49
- # Harmony channels for CoT
50
- REQUIRED_CHANNELS = ["thinking", "analysis", "final"]
51
-
52
- # HF Auth - properly handle multiple token env var names
53
- HF_TOKEN: Optional[str] = (
54
  os.getenv("HF_TOKEN")
55
  or os.getenv("HUGGING_FACE_HUB_TOKEN")
56
  or os.getenv("HUGGINGFACEHUB_API_TOKEN")
57
- or os.getenv("HF_ACCESS_TOKEN")
58
  )
59
 
60
- def _hf_login() -> None:
61
- """Login to HF Hub using common env secret names."""
62
- if HF_TOKEN:
63
- try:
64
- from huggingface_hub import login, whoami
65
- login(token=HF_TOKEN, add_to_git_credential=True)
66
- try:
67
- who = whoami(token=HF_TOKEN)
68
- print(f"[HF Auth] Logged in as: {who.get('name') or who.get('fullname') or who.get('id', 'unknown')}")
69
- except Exception:
70
- print("[HF Auth] Login successful but couldn't get user info")
71
- except Exception as e:
72
- print(f"[HF Auth] Login failed: {e}")
73
- else:
74
- print("[HF Auth] No token found in environment variables")
75
-
76
- # Login before loading any models
77
- _hf_login()
78
-
79
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
80
-
81
- # Load Harmony encoding if available
82
- if HARMONY_AVAILABLE:
83
- harmony_encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
84
- else:
85
- harmony_encoding = None
86
-
87
- # Tokenizer is lightweight; load once
88
- try:
89
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, token=HF_TOKEN)
90
- print(f"[Model] Successfully loaded tokenizer from {MODEL_ID}")
91
- except Exception as e:
92
- print(f"[Model] Failed to load tokenizer: {e}")
93
- raise
94
 
95
  # -----------------------
96
- # Model loading
97
  # -----------------------
98
- try:
99
- from peft import PeftModel
100
- _HAS_PEFT = True
101
- except Exception:
102
- _HAS_PEFT = False
103
-
104
 
105
- def _build_model_kwargs(device_map: Optional[str]) -> Dict[str, Any]:
106
- kw: Dict[str, Any] = dict(
107
- torch_dtype=DTYPE,
108
- device_map=device_map,
109
- attn_implementation=ATTN_IMPL if device_map != "cpu" else "eager",
110
- trust_remote_code=True,
111
- low_cpu_mem_usage=True,
112
- token=HF_TOKEN,
113
- )
114
- if LOAD_4BIT and device_map != "cpu":
115
- try:
116
- import bitsandbytes as _bnb
117
- kw.update(load_in_4bit=True)
118
- if kw["device_map"] is None:
119
- kw["device_map"] = "auto"
120
- except Exception:
121
- pass
122
- return kw
123
 
 
124
 
125
- def _load_model_on(device_map: Optional[str]) -> AutoModelForCausalLM:
126
- print(f"[Model] Loading base model from {MODEL_ID}...")
127
- model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **_build_model_kwargs(device_map))
128
-
129
- if ADAPTER_ID:
130
- if not _HAS_PEFT:
131
- raise RuntimeError("peft is required when ADAPTER_ID is set.")
132
- print(f"[Model] Loading adapter from {ADAPTER_ID}...")
133
- peft_kwargs: Dict[str, Any] = {"token": HF_TOKEN}
134
- if ADAPTER_SUBFOLDER:
135
- peft_kwargs["subfolder"] = ADAPTER_SUBFOLDER
136
- model = PeftModel.from_pretrained(model, ADAPTER_ID, is_trainable=False, **peft_kwargs)
137
-
138
- model.eval()
139
- model.config.use_cache = True
140
- print("[Model] Model loaded successfully")
 
 
 
141
  return model
142
 
143
- # -----------------------
144
- # Harmony formatting
145
- # -----------------------
146
-
147
- def create_harmony_prompt(messages: List[Dict[str, str]], reasoning_effort: str = "high") -> str:
148
- """Create a proper Harmony-formatted prompt using openai_harmony."""
149
- if not HARMONY_AVAILABLE:
150
- # Fallback to tokenizer's chat template
151
- return tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
152
-
153
- # Map reasoning effort
154
- effort_map = {
155
- "low": ReasoningEffort.LOW,
156
- "medium": ReasoningEffort.MEDIUM,
157
- "high": ReasoningEffort.HIGH,
158
- }
159
- effort = effort_map.get(reasoning_effort.lower(), ReasoningEffort.HIGH)
160
-
161
- # Create system message with channels
162
- system_content = (
163
- SystemContent.new()
164
- .with_model_identity(messages[0]["content"] if messages else SYSTEM_DEF)
165
- .with_reasoning_effort(effort)
166
- .with_conversation_start_date(datetime.now().strftime("%Y-%m-%d"))
167
- .with_knowledge_cutoff("2025-01")
168
- .with_required_channels(REQUIRED_CHANNELS)
169
- )
170
-
171
- # Build conversation
172
- harmony_messages = [
173
- Message.from_role_and_content(Role.SYSTEM, system_content)
174
- ]
175
-
176
- # Add user/assistant messages
177
- for msg in messages[1:]: # Skip system message as we already added it
178
- if msg["role"] == "user":
179
- harmony_messages.append(
180
- Message.from_role_and_content(Role.USER, msg["content"])
181
- )
182
- elif msg["role"] == "assistant":
183
- # For assistant messages, we might want to preserve channels if they exist
184
- harmony_messages.append(
185
- Message.from_role_and_content(Role.ASSISTANT, msg["content"])
186
- .with_channel("final") # Default to final channel
187
- )
188
-
189
- # Create conversation and render
190
- convo = Conversation.from_messages(harmony_messages)
191
- tokens = harmony_encoding.render_conversation_for_completion(convo, Role.ASSISTANT)
192
-
193
- # Convert tokens back to text for the model
194
- return tokenizer.decode(tokens)
195
-
196
- def parse_harmony_response(tokens: List[int]) -> Dict[str, str]:
197
- """Parse response tokens using Harmony format to extract channels."""
198
- if not HARMONY_AVAILABLE:
199
- # Fallback: just decode and extract final channel manually
200
- text = tokenizer.decode(tokens, skip_special_tokens=False)
201
- return {"final": extract_final_channel_fallback(text), "raw": text}
202
-
203
- # Parse messages from completion tokens
204
- parsed_messages = harmony_encoding.parse_messages_from_completion_tokens(tokens, Role.ASSISTANT)
205
-
206
- # Extract content by channel
207
- channels = {}
208
- for msg in parsed_messages:
209
- channel = msg.channel if hasattr(msg, 'channel') else "final"
210
- if channel not in channels:
211
- channels[channel] = ""
212
- channels[channel] += msg.content
213
-
214
- # Ensure we have a final channel
215
- if "final" not in channels:
216
- channels["final"] = " ".join(channels.values())
217
-
218
- return channels
219
-
220
- def extract_final_channel_fallback(text: str) -> str:
221
- """Fallback extraction when harmony library isn't available."""
222
- # Look for the final channel marker
223
  final_marker = "<|channel|>final<|message|>"
224
-
225
  if final_marker in text:
226
  parts = text.split(final_marker)
227
  if len(parts) > 1:
228
  final_text = parts[-1]
229
-
230
- # Clean up end markers
231
- end_markers = ["<|return|>", "<|end|>", "<|endoftext|>"]
232
- for marker in end_markers:
233
  if marker in final_text:
234
  final_text = final_text.split(marker)[0]
235
-
236
  return final_text.strip()
237
 
238
- # If no channel markers found, return cleaned text
239
  return text.strip()
240
 
241
  # -----------------------
242
- # Rose guidance
243
  # -----------------------
244
-
245
- def build_bias_from_tokens(tokenizer, mapping: Dict[str, float]) -> torch.Tensor:
246
- """Create vocab bias from {token: weight}."""
247
- vocab_size = len(tokenizer)
248
- bias = torch.zeros(vocab_size, dtype=torch.float32)
249
- for tok, w in mapping.items():
250
- if tok is None:
251
- continue
252
- tid = tokenizer.convert_tokens_to_ids(tok)
253
- if isinstance(tid, list):
254
- for t in tid:
255
- if isinstance(t, int) and t >= 0:
256
- bias[t] += float(w) / max(1, len(tid))
257
- elif isinstance(tid, int) and tid >= 0:
258
- bias[tid] += float(w)
259
- return bias
260
-
261
- class RoseGuidedLogits(torch.nn.Module):
262
- def __init__(self, bias_vec: torch.Tensor, alpha: float = 1.0):
263
- super().__init__()
264
- self.bias_vec = bias_vec
265
- self.alpha = float(alpha)
266
-
267
- def forward(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
268
- return scores + self.alpha * self.bias_vec.to(scores.device)
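For context, a minimal sketch of how a logits processor like this plugs into `generate` (illustrative only, not part of the commit: `model`, `inputs`, and the token weights are placeholders; `LogitsProcessorList` is the standard transformers container):

from transformers import LogitsProcessorList

# Hypothetical token weights; build_bias_from_tokens and RoseGuidedLogits are the helpers above.
bias = build_bias_from_tokens(tokenizer, {"rose": 2.0, "thorn": -1.0})
processors = LogitsProcessorList([RoseGuidedLogits(bias.to(model.device), alpha=1.5)])
output_ids = model.generate(**inputs, logits_processor=processors, max_new_tokens=64)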
269
-
270
- @spaces.GPU(duration=120)
271
- def zerogpu_generate(full_prompt: str,
272
- gen_kwargs: Dict[str, Any],
273
- rose_map: Optional[Dict[str, float]],
274
- rose_alpha: float,
275
- rose_score: Optional[float],
276
- seed: Optional[int]) -> Dict[str, str]:
277
- """Run inference on GPU and return parsed channels."""
278
  try:
279
- if seed is not None:
280
- torch.manual_seed(int(seed))
281
-
282
- # Load model
283
- model = _load_model_on("auto")
284
 
285
- # Setup logits processor for Rose guidance
286
- logits_processor = None
287
- if rose_map:
288
- bias = build_bias_from_tokens(tokenizer, rose_map).to(next(model.parameters()).device)
289
- eff_alpha = float(rose_alpha) * (float(rose_score) if rose_score is not None else 1.0)
290
- logits_processor = [RoseGuidedLogits(bias, eff_alpha)]
291
-
292
- # Tokenize input
293
- inputs = tokenizer(full_prompt, return_tensors="pt").to(next(model.parameters()).device)
294
 
295
  # Generate
296
- out_ids = model.generate(
297
- **inputs,
298
- do_sample=bool(gen_kwargs.get("do_sample", True)),
299
- temperature=float(gen_kwargs.get("temperature", 0.7)),
300
- top_p=float(gen_kwargs.get("top_p", 0.9)),
301
- top_k=(int(gen_kwargs.get("top_k")) if gen_kwargs.get("top_k") and int(gen_kwargs.get("top_k")) > 0 else None),
302
- max_new_tokens=int(gen_kwargs.get("max_new_tokens", MAX_DEF)),
303
- pad_token_id=tokenizer.eos_token_id,
304
- eos_token_id=tokenizer.eos_token_id,
305
- logits_processor=logits_processor,
306
- )
307
 
308
- # Extract generated tokens only
309
- prompt_len = int(inputs["input_ids"].shape[1])
310
- gen_ids = out_ids[0][prompt_len:].tolist()
 
311
 
312
- # Parse response with Harmony
313
- if HARMONY_AVAILABLE:
314
- channels = parse_harmony_response(gen_ids)
315
- else:
316
- # Fallback
317
- decoded = tokenizer.decode(gen_ids, skip_special_tokens=False)
318
- channels = {
319
- "final": extract_final_channel_fallback(decoded),
320
- "raw": decoded
321
- }
322
 
323
- return channels
324
-
325
  except Exception as e:
326
- return {"final": f"[Error] {type(e).__name__}: {str(e)}", "raw": str(e)}
 
 
327
  finally:
328
  # Cleanup
329
- try:
330
- del model
331
- except:
332
- pass
333
- gc.collect()
334
  if torch.cuda.is_available():
335
  torch.cuda.empty_cache()
 
 
 
 
 
336
 
337
  # -----------------------
338
- # Gradio handlers
339
  # -----------------------
340
-
341
- def generate_response(message: str, history: List[List[str]], system_prompt: str,
342
- temperature: float, top_p: float, top_k: int, max_new_tokens: int,
343
- do_sample: bool, seed: Optional[int],
344
- rose_enable: bool, rose_alpha: float, rose_score: Optional[float],
345
- rose_tokens: str, rose_json: str,
346
- show_thinking: bool = False,
347
- reasoning_effort: str = "high") -> str:
348
- """
349
- Generate response with proper CoT handling using Harmony format.
350
- """
 
351
  try:
352
- # Build message list
353
- messages = [{"role": "system", "content": system_prompt or SYSTEM_DEF}]
354
 
355
- # Add history
356
- if history:
357
- for turn in history:
358
- if isinstance(turn, (list, tuple)) and len(turn) >= 2:
359
- user_msg, assistant_msg = turn[0], turn[1]
360
- if user_msg:
361
- messages.append({"role": "user", "content": str(user_msg)})
362
- if assistant_msg:
363
- messages.append({"role": "assistant", "content": str(assistant_msg)})
364
 
365
- # Add current message
366
- messages.append({"role": "user", "content": str(message)})
367
 
368
- # Create Harmony-formatted prompt
369
- if HARMONY_AVAILABLE:
370
- prompt = create_harmony_prompt(messages, reasoning_effort)
371
- else:
372
- # Fallback to tokenizer template
373
- prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
374
-
375
- # Build Rose map if enabled
376
- rose_map: Optional[Dict[str, float]] = None
377
- if rose_enable:
378
- rose_map = {}
379
- tok_str = (rose_tokens or "").strip()
380
- if tok_str:
381
- for p in [p.strip() for p in tok_str.split(",") if p.strip()]:
382
- if ":" in p:
383
- k, v = p.split(":", 1)
384
- try:
385
- rose_map[k.strip()] = float(v)
386
- except:
387
- pass
388
- if rose_json:
389
- try:
390
- j = json.loads(rose_json)
391
- if isinstance(j, dict):
392
- for k, v in j.items():
393
- try:
394
- rose_map[str(k)] = float(v)
395
- except:
396
- pass
397
- except:
398
- pass
399
- if not rose_map:
400
- rose_map = None
401
-
402
- # Generate with model
403
- channels = zerogpu_generate(
404
- prompt,
405
- {
406
- "do_sample": bool(do_sample),
407
- "temperature": float(temperature),
408
- "top_p": float(top_p),
409
- "top_k": int(top_k) if top_k > 0 else None,
410
- "max_new_tokens": int(max_new_tokens),
411
- },
412
- rose_map,
413
- float(rose_alpha),
414
- float(rose_score) if rose_score is not None else None,
415
- int(seed) if seed is not None else None,
416
  )
417
 
418
- # Format response
419
  if show_thinking:
420
- # Show all channels
421
- response = "## Chain of Thought:\n\n"
422
- for channel, content in channels.items():
423
- if channel != "final" and content:
424
- response += f"### {channel.capitalize()} Channel:\n{content}\n\n"
425
- response += f"### Final Response:\n{channels.get('final', 'No final response generated')}"
426
- return response
427
  else:
428
- # Just show the final response
429
- return channels.get("final", "No final response generated")
430
 
431
  except Exception as e:
432
- return f"[Error] {type(e).__name__}: {str(e)}"
 
 
433
 
434
  # -----------------------
435
- # UI
436
  # -----------------------
437
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
438
- gr.Markdown(
439
- """
440
- # Mirel – Harmony Chain-of-Thought Inference
441
-
442
- OSS-20B model using Harmony format with thinking channels.
443
- The model thinks through problems in internal channels before providing a final response.
444
-
445
- **Note:** Install `openai-harmony` for full Harmony support: `pip install openai-harmony`
446
- """
447
- )
448
-
449
- with gr.Row():
450
- system_prompt = gr.Textbox(
451
- label="System Prompt",
452
- value=SYSTEM_DEF,
453
- lines=2
454
- )
455
 
456
- with gr.Accordion("Generation Settings", open=False):
457
- with gr.Row():
458
- temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
459
- top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.01, label="Top-p")
460
- top_k = gr.Slider(0, 200, value=0, step=1, label="Top-k (0=disabled)")
461
- with gr.Row():
462
- max_new = gr.Slider(16, 4096, value=MAX_DEF, step=16, label="Max new tokens")
463
- do_sample = gr.Checkbox(value=True, label="Do sample")
464
- seed = gr.Number(value=None, label="Seed (optional)", precision=0)
465
  with gr.Row():
466
- reasoning_effort = gr.Radio(
467
- choices=["low", "medium", "high"],
468
- value="high",
469
- label="Reasoning Effort",
470
- info="How much thinking the model should do"
471
- )
472
- show_thinking = gr.Checkbox(
473
- value=False,
474
- label="Show thinking channels",
475
- info="Display all internal reasoning channels"
 
 
 
 
476
  )
477
-
478
- with gr.Accordion("Rose Guidance (Optional)", open=False):
479
- gr.Markdown("Fine-tune generation with token biases")
480
- with gr.Row():
481
- rose_enable = gr.Checkbox(value=False, label="Enable Rose bias")
482
- rose_alpha = gr.Slider(0.0, 5.0, value=1.0, step=0.05, label="Alpha (strength)")
483
- rose_score = gr.Slider(0.0, 1.0, value=1.0, step=0.01, label="Score multiplier")
484
- rose_tokens = gr.Textbox(
485
- label="Token:weight pairs",
486
- placeholder="example:1.5, test:-0.5",
487
- value=""
 
488
  )
489
- rose_json = gr.Textbox(
490
- label="JSON weights",
491
- placeholder='{"token": 1.0, "another": -0.5}',
492
- value=""
 
 
 
 
 
 
493
  )
494
-
495
- # Chat interface - using only valid parameters
496
- chat = gr.ChatInterface(
497
- fn=generate_response,
498
- additional_inputs=[
499
- system_prompt, temperature, top_p, top_k, max_new,
500
- do_sample, seed, rose_enable, rose_alpha, rose_score,
501
- rose_tokens, rose_json, show_thinking, reasoning_effort
502
- ],
503
- title="Chat with Mirel",
504
- description="A chain-of-thought model using Harmony format",
505
- examples=[
506
- ["Hello! Can you introduce yourself?"],
507
- ["What is the capital of France?"],
508
- ["Explain quantum computing in simple terms"],
509
- ["Solve: If a train travels 120 miles in 2 hours, what is its average speed?"],
510
- ],
511
- cache_examples=False,
512
- )
513
-
514
- gr.Markdown(
515
- """
516
- ---
517
- ### Configuration:
518
- - **Model**: Set `MODEL_ID` env var (default: openai/gpt-oss-20b)
519
- - **Adapter**: Set `ADAPTER_ID` and optionally `ADAPTER_SUBFOLDER`
520
- - **Auth**: Set `HF_TOKEN` in Space secrets for private model access
521
- - **Harmony**: Install with `pip install openai-harmony` for proper channel support
522
 
523
- The model uses Harmony format with thinking channels (`thinking`, `analysis`, `final`).
524
- """
525
- )
526
 
 
 
 
527
  if __name__ == "__main__":
528
- demo.queue(max_size=8 if ZEROGPU else 32).launch(
529
- server_name="0.0.0.0",
 
 
530
  server_port=7860,
531
- share=False
532
  )
 
1
  """
2
  Mirel Harmony Inference – HF Space (Gradio)
3
+ Simplified version with robust error handling
 
 
4
  """
5
+ import os
6
+ import gc
7
+ import json
8
+ import torch
 
9
  import gradio as gr
10
+ from typing import List, Dict, Optional, Any, Generator
11
  from transformers import AutoTokenizer, AutoModelForCausalLM
12
 
13
+ # Check if spaces is available
14
  try:
15
+ import spaces
16
+ SPACES_AVAILABLE = True
17
  except ImportError:
18
+ SPACES_AVAILABLE = False
19
+ print("[WARNING] spaces not available, running without ZeroGPU")
20
 
21
  # -----------------------
22
+ # Config
23
  # -----------------------
24
+ MODEL_ID = os.getenv("MODEL_ID", "openai/gpt-oss-20b")
25
+ ADAPTER_ID = os.getenv("ADAPTER_ID")
26
+ ADAPTER_SUBFOLDER = os.getenv("ADAPTER_SUBFOLDER")
27
+ SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", "You are Mirel, a helpful assistant.")
28
+ MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "512"))
29
+ DTYPE = os.getenv("DTYPE", "bf16")
30
+ ZEROGPU = os.getenv("ZEROGPU", "0") == "1"
31
+
32
+ # HF Token
33
+ HF_TOKEN = (
34
  os.getenv("HF_TOKEN")
35
  or os.getenv("HUGGING_FACE_HUB_TOKEN")
36
  or os.getenv("HUGGINGFACEHUB_API_TOKEN")
 
37
  )
38
 
39
+ if HF_TOKEN:
40
+ try:
41
+ from huggingface_hub import login
42
+ login(token=HF_TOKEN)
43
+ print("[Auth] Logged in to Hugging Face")
44
+ except Exception as e:
45
+ print(f"[Auth] Failed to login: {e}")
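The removed `_hf_login` helper also verified the token with `huggingface_hub.whoami`; a minimal sketch of that check, if it is worth keeping (not part of this commit):

# Optional: confirm the token actually resolves to an account, as the old _hf_login did.
if HF_TOKEN:
    try:
        from huggingface_hub import whoami
        who = whoami(token=HF_TOKEN)
        print(f"[Auth] Token belongs to: {who.get('name', 'unknown')}")
    except Exception as e:
        print(f"[Auth] Token check failed: {e}")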
46
 
47
  # -----------------------
48
+ # Model Loading
49
  # -----------------------
50
+ print(f"[Model] Loading tokenizer from {MODEL_ID}")
51
+ tokenizer = AutoTokenizer.from_pretrained(
52
+ MODEL_ID,
53
+ trust_remote_code=True,
54
+ token=HF_TOKEN
55
+ )
56
 
57
+ model = None
58
 
59
+ def get_dtype():
60
+ """Get the appropriate dtype for the model."""
61
+ if DTYPE == "bf16" and torch.cuda.is_available():
62
+ return torch.bfloat16
63
+ elif DTYPE == "fp16":
64
+ return torch.float16
65
+ else:
66
+ return torch.float32
67
 
68
+ def load_model():
69
+ """Load the model (called inside GPU context if using ZeroGPU)."""
70
+ global model
71
+ if model is None:
72
+ print(f"[Model] Loading model from {MODEL_ID}")
73
+
74
+ kwargs = {
75
+ "torch_dtype": get_dtype(),
76
+ "device_map": "auto" if torch.cuda.is_available() else "cpu",
77
+ "trust_remote_code": True,
78
+ "token": HF_TOKEN,
79
+ "low_cpu_mem_usage": True,
80
+ }
81
+
82
+ model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **kwargs)
83
+
84
+ # Load adapter if specified
85
+ if ADAPTER_ID:
86
+ try:
87
+ from peft import PeftModel
88
+ print(f"[Model] Loading adapter from {ADAPTER_ID}")
89
+ adapter_kwargs = {"token": HF_TOKEN}
90
+ if ADAPTER_SUBFOLDER:
91
+ adapter_kwargs["subfolder"] = ADAPTER_SUBFOLDER
92
+ model = PeftModel.from_pretrained(
93
+ model,
94
+ ADAPTER_ID,
95
+ is_trainable=False,
96
+ **adapter_kwargs
97
+ )
98
+ except ImportError:
99
+ print("[WARNING] PEFT not installed, skipping adapter")
100
+ except Exception as e:
101
+ print(f"[WARNING] Failed to load adapter: {e}")
102
+
103
+ model.eval()
104
  return model
105
 
106
+ def extract_final_response(text: str) -> str:
107
+ """Extract the final channel from chain-of-thought output."""
108
+ # Look for final channel marker
109
  final_marker = "<|channel|>final<|message|>"
 
110
  if final_marker in text:
111
  parts = text.split(final_marker)
112
  if len(parts) > 1:
113
  final_text = parts[-1]
114
+ # Clean end markers
115
+ for marker in ["<|return|>", "<|end|>", "<|endoftext|>"]:
 
 
116
  if marker in final_text:
117
  final_text = final_text.split(marker)[0]
 
118
  return final_text.strip()
119
 
120
+ # No channel markers, return cleaned text
121
  return text.strip()
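A quick illustrative check of the extraction above on a Harmony-style completion (the sample string is made up):

sample = (
    "<|channel|>analysis<|message|>Reason about the question here.<|end|>"
    "<|channel|>final<|message|>Paris is the capital of France.<|return|>"
)
assert extract_final_response(sample) == "Paris is the capital of France."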
122
 
123
  # -----------------------
124
+ # Generation Function
125
  # -----------------------
126
+ def generate_text(
127
+ prompt: str,
128
+ temperature: float = 0.7,
129
+ top_p: float = 0.9,
130
+ top_k: int = 0,
131
+ max_new_tokens: int = 512,
132
+ do_sample: bool = True,
133
+ ) -> str:
134
+ """Generate text using the model."""
135
  try:
136
+ # Load/get model
137
+ model_instance = load_model()
138
 
139
+ # Tokenize
140
+ inputs = tokenizer(prompt, return_tensors="pt")
141
+ if torch.cuda.is_available():
142
+ inputs = inputs.to("cuda")
143
 
144
  # Generate
145
+ with torch.no_grad():
146
+ outputs = model_instance.generate(
147
+ **inputs,
148
+ max_new_tokens=max_new_tokens,
149
+ temperature=temperature,
150
+ top_p=top_p,
151
+ top_k=top_k if top_k > 0 else None,
152
+ do_sample=do_sample,
153
+ pad_token_id=tokenizer.eos_token_id,
154
+ eos_token_id=tokenizer.eos_token_id,
155
+ )
156
 
157
+ # Decode
158
+ prompt_len = inputs["input_ids"].shape[1]
159
+ generated_ids = outputs[0][prompt_len:]
160
+ response = tokenizer.decode(generated_ids, skip_special_tokens=False)
161
 
162
+ return response
163
 
 
 
164
  except Exception as e:
165
+ error_msg = f"Generation error: {str(e)}"
166
+ print(f"[ERROR] {error_msg}")
167
+ return error_msg
168
  finally:
169
  # Cleanup
170
  if torch.cuda.is_available():
171
  torch.cuda.empty_cache()
172
+ gc.collect()
173
+
174
+ # Add GPU decorator if available
175
+ if SPACES_AVAILABLE and ZEROGPU:
176
+ generate_text = spaces.GPU(duration=120)(generate_text)
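The conditional re-wrapping above is equivalent to a small helper that degrades to a no-op when ZeroGPU cannot be used; a sketch, assuming the same SPACES_AVAILABLE and ZEROGPU flags:

def gpu_if_available(fn, duration: int = 120):
    # Apply spaces.GPU only when the spaces package and ZeroGPU are both enabled.
    if SPACES_AVAILABLE and ZEROGPU:
        return spaces.GPU(duration=duration)(fn)
    return fn

# generate_text = gpu_if_available(generate_text)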
177
 
178
  # -----------------------
179
+ # Chat Function
180
  # -----------------------
181
+ def chat_fn(
182
+ message: str,
183
+ history: List[List[str]],
184
+ system_prompt: str,
185
+ temperature: float,
186
+ top_p: float,
187
+ top_k: int,
188
+ max_new_tokens: int,
189
+ do_sample: bool,
190
+ show_thinking: bool,
191
+ ) -> str:
192
+ """Main chat function for Gradio."""
193
  try:
194
+ # Build conversation
195
+ messages = [{"role": "system", "content": system_prompt or SYSTEM_PROMPT}]
196
 
197
+ for user_msg, assistant_msg in (history or []):
198
+ if user_msg:
199
+ messages.append({"role": "user", "content": user_msg})
200
+ if assistant_msg:
201
+ messages.append({"role": "assistant", "content": assistant_msg})
202
 
203
+ messages.append({"role": "user", "content": message})
 
204
 
205
+ # Apply chat template
206
+ try:
207
+ prompt = tokenizer.apply_chat_template(
208
+ messages,
209
+ add_generation_prompt=True,
210
+ tokenize=False
211
+ )
212
+ except Exception:
213
+ # Fallback to simple format
214
+ prompt = f"{system_prompt}\n\n"
215
+ for msg in messages[1:]:
216
+ role = msg["role"].upper()
217
+ content = msg["content"]
218
+ prompt += f"{role}: {content}\n"
219
+ prompt += "ASSISTANT: "
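For reference, with one prior turn the fallback prompt built above looks roughly like this (illustrative; assumes the default system prompt):

# You are Mirel, a helpful assistant.
#
# USER: Hello!
# ASSISTANT: Hi, how can I help?
# USER: What is the capital of France?
# ASSISTANT: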
220
+
221
+ # Generate response
222
+ full_response = generate_text(
223
+ prompt=prompt,
224
+ temperature=temperature,
225
+ top_p=top_p,
226
+ top_k=int(top_k),
227
+ max_new_tokens=int(max_new_tokens),
228
+ do_sample=do_sample,
229
  )
230
 
231
+ # Process response
232
  if show_thinking:
233
+ # Show full output with channels
234
+ final = extract_final_response(full_response)
235
+ return f"**Full Output:**\n```\n{full_response}\n```\n\n**Final Response:**\n{final}"
236
  else:
237
+ # Just show final response
238
+ return extract_final_response(full_response)
239
 
240
  except Exception as e:
241
+ error_msg = f"Chat error: {str(e)}"
242
+ print(f"[ERROR] {error_msg}")
243
+ return error_msg
244
 
245
  # -----------------------
246
+ # Gradio Interface
247
  # -----------------------
248
+ def create_interface():
249
+ """Create the Gradio interface."""
250
 
251
+ with gr.Blocks(title="Mirel Chat") as demo:
252
+ gr.Markdown(
253
+ """
254
+ # Mirel - Chain-of-Thought Chat
255
+
256
+ Chat with a model that thinks before responding.
257
+ """
258
+ )
259
+
260
  with gr.Row():
261
+ with gr.Column(scale=4):
262
+ chatbot = gr.Chatbot(height=500)
263
+ msg = gr.Textbox(
264
+ label="Message",
265
+ placeholder="Type your message here...",
266
+ lines=2
267
+ )
268
+ with gr.Row():
269
+ submit = gr.Button("Send", variant="primary")
270
+ clear = gr.Button("Clear")
271
+
272
+ with gr.Column(scale=1):
273
+ system_prompt = gr.Textbox(
274
+ label="System Prompt",
275
+ value=SYSTEM_PROMPT,
276
+ lines=3
277
+ )
278
+
279
+ with gr.Accordion("Settings", open=False):
280
+ temperature = gr.Slider(
281
+ minimum=0.1,
282
+ maximum=2.0,
283
+ value=0.7,
284
+ step=0.1,
285
+ label="Temperature"
286
+ )
287
+ top_p = gr.Slider(
288
+ minimum=0.1,
289
+ maximum=1.0,
290
+ value=0.9,
291
+ step=0.1,
292
+ label="Top-p"
293
+ )
294
+ top_k = gr.Slider(
295
+ minimum=0,
296
+ maximum=100,
297
+ value=0,
298
+ step=1,
299
+ label="Top-k (0=disabled)"
300
+ )
301
+ max_new_tokens = gr.Slider(
302
+ minimum=64,
303
+ maximum=2048,
304
+ value=MAX_NEW_TOKENS,
305
+ step=64,
306
+ label="Max Tokens"
307
+ )
308
+ do_sample = gr.Checkbox(
309
+ value=True,
310
+ label="Do Sample"
311
+ )
312
+ show_thinking = gr.Checkbox(
313
+ value=False,
314
+ label="Show Thinking Process"
315
+ )
316
+
317
+ # Event handlers
318
+ def user_submit(message, history):
319
+ return "", history + [[message, None]]
320
+
321
+ def bot_respond(history, system, temp, top_p, top_k, max_tokens, sample, thinking):
322
+ if not history or not history[-1][0]:
323
+ return history
324
+
325
+ user_message = history[-1][0]
326
+ bot_message = chat_fn(
327
+ user_message,
328
+ history[:-1], # Don't include current turn
329
+ system,
330
+ temp,
331
+ top_p,
332
+ top_k,
333
+ max_tokens,
334
+ sample,
335
+ thinking
336
  )
337
+ history[-1][1] = bot_message
338
+ return history
339
+
340
+ msg.submit(
341
+ user_submit,
342
+ [msg, chatbot],
343
+ [msg, chatbot],
344
+ queue=False
345
+ ).then(
346
+ bot_respond,
347
+ [chatbot, system_prompt, temperature, top_p, top_k, max_new_tokens, do_sample, show_thinking],
348
+ chatbot
349
  )
350
+
351
+ submit.click(
352
+ user_submit,
353
+ [msg, chatbot],
354
+ [msg, chatbot],
355
+ queue=False
356
+ ).then(
357
+ bot_respond,
358
+ [chatbot, system_prompt, temperature, top_p, top_k, max_new_tokens, do_sample, show_thinking],
359
+ chatbot
360
  )
 
361
 
362
+ clear.click(lambda: None, None, chatbot, queue=False)
363
+
364
+ return demo
365
 
366
+ # -----------------------
367
+ # Main
368
+ # -----------------------
369
  if __name__ == "__main__":
370
+ demo = create_interface()
371
+ demo.queue(max_size=10)
372
+ demo.launch(
373
+ server_name="0.0.0.0",
374
  server_port=7860,
375
+ share=True
376
  )