AbstractPhil committed on
Commit d76bf3e · 1 Parent(s): 4fc7c90
Files changed (1)
  1. app.py +457 -301
app.py CHANGED
@@ -1,376 +1,532 @@
 """
 Mirel Harmony Inference – HF Space (Gradio)
-Simplified version with robust error handling
 """
-import os
-import gc
-import json
-import torch
 import gradio as gr
-from typing import List, Dict, Optional, Any, Generator
 from transformers import AutoTokenizer, AutoModelForCausalLM

-# Check if spaces is available
 try:
-    import spaces
-    SPACES_AVAILABLE = True
 except ImportError:
-    SPACES_AVAILABLE = False
-    print("[WARNING] spaces not available, running without ZeroGPU")

 # -----------------------
-# Config
 # -----------------------
-MODEL_ID = os.getenv("MODEL_ID", "openai/gpt-oss-20b")
-ADAPTER_ID = os.getenv("ADAPTER_ID")
-ADAPTER_SUBFOLDER = os.getenv("ADAPTER_SUBFOLDER")
-SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT", "You are Mirel, a helpful assistant.")
-MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", "512"))
-DTYPE = os.getenv("DTYPE", "bf16")
-ZEROGPU = os.getenv("ZEROGPU", "0") == "1"
-
-# HF Token
-HF_TOKEN = (
     os.getenv("HF_TOKEN")
     or os.getenv("HUGGING_FACE_HUB_TOKEN")
     or os.getenv("HUGGINGFACEHUB_API_TOKEN")
 )

-if HF_TOKEN:
-    try:
-        from huggingface_hub import login
-        login(token=HF_TOKEN)
-        print("[Auth] Logged in to Hugging Face")
-    except Exception as e:
-        print(f"[Auth] Failed to login: {e}")

 # -----------------------
-# Model Loading
 # -----------------------
-print(f"[Model] Loading tokenizer from {MODEL_ID}")
-tokenizer = AutoTokenizer.from_pretrained(
-    MODEL_ID,
-    trust_remote_code=True,
-    token=HF_TOKEN
-)

-model = None

-def get_dtype():
-    """Get the appropriate dtype for the model."""
-    if DTYPE == "bf16" and torch.cuda.is_available():
-        return torch.bfloat16
-    elif DTYPE == "fp16":
-        return torch.float16
-    else:
-        return torch.float32

-def load_model():
-    """Load the model (called inside GPU context if using ZeroGPU)."""
-    global model
-    if model is None:
-        print(f"[Model] Loading model from {MODEL_ID}")
-
-        kwargs = {
-            "torch_dtype": get_dtype(),
-            "device_map": "auto" if torch.cuda.is_available() else "cpu",
-            "trust_remote_code": True,
-            "token": HF_TOKEN,
-            "low_cpu_mem_usage": True,
-        }
-
-        model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **kwargs)
-
-        # Load adapter if specified
-        if ADAPTER_ID:
-            try:
-                from peft import PeftModel
-                print(f"[Model] Loading adapter from {ADAPTER_ID}")
-                adapter_kwargs = {"token": HF_TOKEN}
-                if ADAPTER_SUBFOLDER:
-                    adapter_kwargs["subfolder"] = ADAPTER_SUBFOLDER
-                model = PeftModel.from_pretrained(
-                    model,
-                    ADAPTER_ID,
-                    is_trainable=False,
-                    **adapter_kwargs
-                )
-            except ImportError:
-                print("[WARNING] PEFT not installed, skipping adapter")
-            except Exception as e:
-                print(f"[WARNING] Failed to load adapter: {e}")
-
-        model.eval()
     return model

-def extract_final_response(text: str) -> str:
-    """Extract the final channel from chain-of-thought output."""
-    # Look for final channel marker
     final_marker = "<|channel|>final<|message|>"
     if final_marker in text:
         parts = text.split(final_marker)
         if len(parts) > 1:
             final_text = parts[-1]
-            # Clean end markers
-            for marker in ["<|return|>", "<|end|>", "<|endoftext|>"]:
                 if marker in final_text:
                     final_text = final_text.split(marker)[0]
             return final_text.strip()

-    # No channel markers, return cleaned text
     return text.strip()

 # -----------------------
-# Generation Function
 # -----------------------
-def generate_text(
-    prompt: str,
-    temperature: float = 0.7,
-    top_p: float = 0.9,
-    top_k: int = 0,
-    max_new_tokens: int = 512,
-    do_sample: bool = True,
-) -> str:
-    """Generate text using the model."""
     try:
-        # Load/get model
-        model_instance = load_model()

-        # Tokenize
-        inputs = tokenizer(prompt, return_tensors="pt")
-        if torch.cuda.is_available():
-            inputs = inputs.to("cuda")

         # Generate
-        with torch.no_grad():
-            outputs = model_instance.generate(
-                **inputs,
-                max_new_tokens=max_new_tokens,
-                temperature=temperature,
-                top_p=top_p,
-                top_k=top_k if top_k > 0 else None,
-                do_sample=do_sample,
-                pad_token_id=tokenizer.eos_token_id,
-                eos_token_id=tokenizer.eos_token_id,
-            )

-        # Decode
-        prompt_len = inputs["input_ids"].shape[1]
-        generated_ids = outputs[0][prompt_len:]
-        response = tokenizer.decode(generated_ids, skip_special_tokens=False)

-        return response

     except Exception as e:
-        error_msg = f"Generation error: {str(e)}"
-        print(f"[ERROR] {error_msg}")
-        return error_msg
     finally:
         # Cleanup
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
-        gc.collect()
-
-# Add GPU decorator if available
-if SPACES_AVAILABLE and ZEROGPU:
-    generate_text = spaces.GPU(duration=120)(generate_text)

 # -----------------------
-# Chat Function
 # -----------------------
-def chat_fn(
-    message: str,
-    history: List[List[str]],
-    system_prompt: str,
-    temperature: float,
-    top_p: float,
-    top_k: int,
-    max_new_tokens: int,
-    do_sample: bool,
-    show_thinking: bool,
-) -> str:
-    """Main chat function for Gradio."""
     try:
-        # Build conversation
-        messages = [{"role": "system", "content": system_prompt or SYSTEM_PROMPT}]

-        for user_msg, assistant_msg in (history or []):
-            if user_msg:
-                messages.append({"role": "user", "content": user_msg})
-            if assistant_msg:
-                messages.append({"role": "assistant", "content": assistant_msg})

-        messages.append({"role": "user", "content": message})

-        # Apply chat template
-        try:
-            prompt = tokenizer.apply_chat_template(
-                messages,
-                add_generation_prompt=True,
-                tokenize=False
-            )
-        except Exception:
-            # Fallback to simple format
-            prompt = f"{system_prompt}\n\n"
-            for msg in messages[1:]:
-                role = msg["role"].upper()
-                content = msg["content"]
-                prompt += f"{role}: {content}\n"
-            prompt += "ASSISTANT: "
-
-        # Generate response
-        full_response = generate_text(
-            prompt=prompt,
-            temperature=temperature,
-            top_p=top_p,
-            top_k=int(top_k),
-            max_new_tokens=int(max_new_tokens),
-            do_sample=do_sample,
         )

-        # Process response
         if show_thinking:
-            # Show full output with channels
-            final = extract_final_response(full_response)
-            return f"**Full Output:**\n```\n{full_response}\n```\n\n**Final Response:**\n{final}"
         else:
-            # Just show final response
-            return extract_final_response(full_response)

     except Exception as e:
-        error_msg = f"Chat error: {str(e)}"
-        print(f"[ERROR] {error_msg}")
-        return error_msg

 # -----------------------
-# Gradio Interface
 # -----------------------
-def create_interface():
-    """Create the Gradio interface."""
-
-    with gr.Blocks(title="Mirel Chat") as demo:
-        gr.Markdown(
-            """
-            # Mirel - Chain-of-Thought Chat
-
-            Chat with a model that thinks before responding.
-            """
-        )
-
-        with gr.Row():
-            with gr.Column(scale=4):
-                chatbot = gr.Chatbot(height=500)
-                msg = gr.Textbox(
-                    label="Message",
-                    placeholder="Type your message here...",
-                    lines=2
-                )
-                with gr.Row():
-                    submit = gr.Button("Send", variant="primary")
-                    clear = gr.Button("Clear")
-
-            with gr.Column(scale=1):
-                system_prompt = gr.Textbox(
-                    label="System Prompt",
-                    value=SYSTEM_PROMPT,
-                    lines=3
-                )
-
-                with gr.Accordion("Settings", open=False):
-                    temperature = gr.Slider(
-                        minimum=0.1,
-                        maximum=2.0,
-                        value=0.7,
-                        step=0.1,
-                        label="Temperature"
-                    )
-                    top_p = gr.Slider(
-                        minimum=0.1,
-                        maximum=1.0,
-                        value=0.9,
-                        step=0.1,
-                        label="Top-p"
-                    )
-                    top_k = gr.Slider(
-                        minimum=0,
-                        maximum=100,
-                        value=0,
-                        step=1,
-                        label="Top-k (0=disabled)"
-                    )
-                    max_new_tokens = gr.Slider(
-                        minimum=64,
-                        maximum=2048,
-                        value=MAX_NEW_TOKENS,
-                        step=64,
-                        label="Max Tokens"
-                    )
-                    do_sample = gr.Checkbox(
-                        value=True,
-                        label="Do Sample"
-                    )
-                    show_thinking = gr.Checkbox(
-                        value=False,
-                        label="Show Thinking Process"
-                    )

-        # Event handlers
-        def user_submit(message, history):
-            return "", history + [[message, None]]

-        def bot_respond(history, system, temp, top_p, top_k, max_tokens, sample, thinking):
-            if not history or not history[-1][0]:
-                return history
-
-            user_message = history[-1][0]
-            bot_message = chat_fn(
-                user_message,
-                history[:-1],  # Don't include current turn
-                system,
-                temp,
-                top_p,
-                top_k,
-                max_tokens,
-                sample,
-                thinking
             )
-            history[-1][1] = bot_message
-            return history
-
-        msg.submit(
-            user_submit,
-            [msg, chatbot],
-            [msg, chatbot],
-            queue=False
-        ).then(
-            bot_respond,
-            [chatbot, system_prompt, temperature, top_p, top_k, max_new_tokens, do_sample, show_thinking],
-            chatbot
         )
-
-        submit.click(
-            user_submit,
-            [msg, chatbot],
-            [msg, chatbot],
-            queue=False
-        ).then(
-            bot_respond,
-            [chatbot, system_prompt, temperature, top_p, top_k, max_new_tokens, do_sample, show_thinking],
-            chatbot
         )

-        clear.click(lambda: None, None, chatbot, queue=False)
-
-        return demo

-# -----------------------
-# Main
-# -----------------------
 if __name__ == "__main__":
-    demo = create_interface()
-    demo.queue(max_size=10)
-    demo.launch(
-        server_name="0.0.0.0",
         server_port=7860,
-        share=False
     )
 """
 Mirel Harmony Inference – HF Space (Gradio)
+ZeroGPU-ready, Harmony formatting, optional Rose-guided decoding
+Chain-of-thought model with proper channel extraction using openai_harmony
+Single file: app.py
 """
+from __future__ import annotations
+import os, gc, json, threading, torch
+from dataclasses import dataclass
+from typing import List, Dict, Optional, Any
+from datetime import datetime
 import gradio as gr
+import spaces  # required for ZeroGPU
 from transformers import AutoTokenizer, AutoModelForCausalLM

+# Import Harmony components
 try:
+    from openai_harmony import (
+        Author,
+        Conversation,
+        HarmonyEncodingName,
+        Message,
+        Role,
+        SystemContent,
+        DeveloperContent,
+        load_harmony_encoding,
+        ReasoningEffort
+    )
+    HARMONY_AVAILABLE = True
 except ImportError:
+    print("[WARNING] openai_harmony not installed. Install with: pip install openai-harmony")
+    HARMONY_AVAILABLE = False

 # -----------------------
+# Config & runtime modes
 # -----------------------
+DTYPE_MAP = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}
+
+MODEL_ID = os.getenv("MODEL_ID", "openai/gpt-oss-20b")
+ADAPTER_ID = os.getenv("ADAPTER_ID") or None
+ADAPTER_SUBFOLDER = os.getenv("ADAPTER_SUBFOLDER") or None
+ATTN_IMPL = os.getenv("ATTN_IMPL", "eager")
+DTYPE = DTYPE_MAP.get(os.getenv("DTYPE", "bf16").lower(), torch.bfloat16)
+SYSTEM_DEF = os.getenv("SYSTEM_PROMPT", "You are Mirel, a memory-stable symbolic assistant.")
+MAX_DEF = int(os.getenv("MAX_NEW_TOKENS", "1024"))
+ZEROGPU = os.getenv("ZEROGPU", os.getenv("ZERO_GPU", "0")) == "1"
+LOAD_4BIT = os.getenv("LOAD_4BIT", "0") == "1"
+
+# Harmony channels for CoT
+REQUIRED_CHANNELS = ["thinking", "analysis", "final"]
+
+# HF Auth - properly handle multiple token env var names
+HF_TOKEN: Optional[str] = (
     os.getenv("HF_TOKEN")
     or os.getenv("HUGGING_FACE_HUB_TOKEN")
     or os.getenv("HUGGINGFACEHUB_API_TOKEN")
+    or os.getenv("HF_ACCESS_TOKEN")
 )

+def _hf_login() -> None:
+    """Login to HF Hub using common env secret names."""
+    if HF_TOKEN:
+        try:
+            from huggingface_hub import login, whoami
+            login(token=HF_TOKEN, add_to_git_credential=True)
+            try:
+                who = whoami(token=HF_TOKEN)
+                print(f"[HF Auth] Logged in as: {who.get('name') or who.get('fullname') or who.get('id', 'unknown')}")
+            except Exception:
+                print("[HF Auth] Login successful but couldn't get user info")
+        except Exception as e:
+            print(f"[HF Auth] Login failed: {e}")
+    else:
+        print("[HF Auth] No token found in environment variables")
+
+# Login before loading any models
+_hf_login()
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+# Load Harmony encoding if available
+if HARMONY_AVAILABLE:
+    harmony_encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
+else:
+    harmony_encoding = None
+
+# Tokenizer is lightweight; load once
+try:
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, token=HF_TOKEN)
+    print(f"[Model] Successfully loaded tokenizer from {MODEL_ID}")
+except Exception as e:
+    print(f"[Model] Failed to load tokenizer: {e}")
+    raise

 # -----------------------
+# Model loading
 # -----------------------
+try:
+    from peft import PeftModel
+    _HAS_PEFT = True
+except Exception:
+    _HAS_PEFT = False


+def _build_model_kwargs(device_map: Optional[str]) -> Dict[str, Any]:
+    kw: Dict[str, Any] = dict(
+        torch_dtype=DTYPE,
+        device_map=device_map,
+        attn_implementation=ATTN_IMPL if device_map != "cpu" else "eager",
+        trust_remote_code=True,
+        low_cpu_mem_usage=True,
+        token=HF_TOKEN,
+    )
+    if LOAD_4BIT and device_map != "cpu":
+        try:
+            import bitsandbytes as _bnb
+            kw.update(load_in_4bit=True)
+            if kw["device_map"] is None:
+                kw["device_map"] = "auto"
+        except Exception:
+            pass
+    return kw

+
+def _load_model_on(device_map: Optional[str]) -> AutoModelForCausalLM:
+    print(f"[Model] Loading base model from {MODEL_ID}...")
+    model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **_build_model_kwargs(device_map))
+
+    if ADAPTER_ID:
+        if not _HAS_PEFT:
+            raise RuntimeError("peft is required when ADAPTER_ID is set.")
+        print(f"[Model] Loading adapter from {ADAPTER_ID}...")
+        peft_kwargs: Dict[str, Any] = {"token": HF_TOKEN}
+        if ADAPTER_SUBFOLDER:
+            peft_kwargs["subfolder"] = ADAPTER_SUBFOLDER
+        model = PeftModel.from_pretrained(model, ADAPTER_ID, is_trainable=False, **peft_kwargs)
+
+    model.eval()
+    model.config.use_cache = True
+    print("[Model] Model loaded successfully")
     return model

+# -----------------------
+# Harmony formatting
+# -----------------------
+
+def create_harmony_prompt(messages: List[Dict[str, str]], reasoning_effort: str = "high") -> str:
+    """Create a proper Harmony-formatted prompt using openai_harmony."""
+    if not HARMONY_AVAILABLE:
+        # Fallback to tokenizer's chat template
+        return tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+
+    # Map reasoning effort
+    effort_map = {
+        "low": ReasoningEffort.LOW,
+        "medium": ReasoningEffort.MEDIUM,
+        "high": ReasoningEffort.HIGH,
+    }
+    effort = effort_map.get(reasoning_effort.lower(), ReasoningEffort.HIGH)
+
+    # Create system message with channels
+    system_content = (
+        SystemContent.new()
+        .with_model_identity(messages[0]["content"] if messages else SYSTEM_DEF)
+        .with_reasoning_effort(effort)
+        .with_conversation_start_date(datetime.now().strftime("%Y-%m-%d"))
+        .with_knowledge_cutoff("2025-01")
+        .with_required_channels(REQUIRED_CHANNELS)
+    )
+
+    # Build conversation
+    harmony_messages = [
+        Message.from_role_and_content(Role.SYSTEM, system_content)
+    ]
+
+    # Add user/assistant messages
+    for msg in messages[1:]:  # Skip system message as we already added it
+        if msg["role"] == "user":
+            harmony_messages.append(
+                Message.from_role_and_content(Role.USER, msg["content"])
+            )
+        elif msg["role"] == "assistant":
+            # For assistant messages, we might want to preserve channels if they exist
+            harmony_messages.append(
+                Message.from_role_and_content(Role.ASSISTANT, msg["content"])
+                .with_channel("final")  # Default to final channel
+            )
+
+    # Create conversation and render
+    convo = Conversation.from_messages(harmony_messages)
+    tokens = harmony_encoding.render_conversation_for_completion(convo, Role.ASSISTANT)
+
+    # Convert tokens back to text for the model
+    return tokenizer.decode(tokens)
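+# Added note (not in the original commit): decoded back to text, the rendered Harmony
+# prompt has roughly the shape sketched below; the exact special tokens come from
+# openai_harmony, so treat this as an illustrative sketch rather than the canonical form.
+#   <|start|>system<|message|>...identity, reasoning effort, required channels...<|end|>
+#   <|start|>user<|message|>Hello!<|end|>
+#   <|start|>assistant
+# The model then continues with channel-tagged segments such as
+# <|channel|>analysis<|message|>...<|end|> before <|channel|>final<|message|>...<|return|>.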
+
+def parse_harmony_response(tokens: List[int]) -> Dict[str, str]:
+    """Parse response tokens using Harmony format to extract channels."""
+    if not HARMONY_AVAILABLE:
+        # Fallback: just decode and extract final channel manually
+        text = tokenizer.decode(tokens, skip_special_tokens=False)
+        return {"final": extract_final_channel_fallback(text), "raw": text}
+
+    # Parse messages from completion tokens
+    parsed_messages = harmony_encoding.parse_messages_from_completion_tokens(tokens, Role.ASSISTANT)
+
+    # Extract content by channel
+    channels = {}
+    for msg in parsed_messages:
+        channel = msg.channel if hasattr(msg, 'channel') else "final"
+        if channel not in channels:
+            channels[channel] = ""
+        channels[channel] += msg.content
+
+    # Ensure we have a final channel
+    if "final" not in channels:
+        channels["final"] = " ".join(channels.values())
+
+    return channels
+
+def extract_final_channel_fallback(text: str) -> str:
+    """Fallback extraction when harmony library isn't available."""
+    # Look for the final channel marker
     final_marker = "<|channel|>final<|message|>"
+
     if final_marker in text:
         parts = text.split(final_marker)
         if len(parts) > 1:
             final_text = parts[-1]
+
+            # Clean up end markers
+            end_markers = ["<|return|>", "<|end|>", "<|endoftext|>"]
+            for marker in end_markers:
                 if marker in final_text:
                     final_text = final_text.split(marker)[0]
+
             return final_text.strip()

+    # If no channel markers found, return cleaned text
     return text.strip()
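+# Added example (not in the original commit): given raw model text such as
+#   "<|channel|>analysis<|message|>reasoning...<|end|><|channel|>final<|message|>Paris.<|return|>"
+# extract_final_channel_fallback() returns "Paris.", i.e. everything after the final-channel
+# marker truncated at the first end marker; without any markers the whole text is returned.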

 # -----------------------
+# Rose guidance
 # -----------------------
+
+def build_bias_from_tokens(tokenizer, mapping: Dict[str, float]) -> torch.Tensor:
+    """Create vocab bias from {token: weight}."""
+    vocab_size = len(tokenizer)
+    bias = torch.zeros(vocab_size, dtype=torch.float32)
+    for tok, w in mapping.items():
+        if tok is None:
+            continue
+        tid = tokenizer.convert_tokens_to_ids(tok)
+        if isinstance(tid, list):
+            for t in tid:
+                if isinstance(t, int) and t >= 0:
+                    bias[t] += float(w) / max(1, len(tid))
+        elif isinstance(tid, int) and tid >= 0:
+            bias[tid] += float(w)
+    return bias
+
+class RoseGuidedLogits(torch.nn.Module):
+    def __init__(self, bias_vec: torch.Tensor, alpha: float = 1.0):
+        super().__init__()
+        self.bias_vec = bias_vec
+        self.alpha = float(alpha)
+
+    def forward(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
+        return scores + self.alpha * self.bias_vec.to(scores.device)
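+
+# Added usage sketch (not in the original commit; the token strings are placeholders):
+# a Rose map such as {"sunrise": 1.5, "static": -0.75} becomes a vocab-sized bias that is
+# added to the logits at every decoding step, with alpha scaling the whole vector.
+#   _bias = build_bias_from_tokens(tokenizer, {"sunrise": 1.5, "static": -0.75})
+#   _processors = [RoseGuidedLogits(_bias, alpha=0.8)]
+#   # passed to model.generate(..., logits_processor=_processors) in zerogpu_generate below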
+
+@spaces.GPU(duration=120)
+def zerogpu_generate(full_prompt: str,
+                     gen_kwargs: Dict[str, Any],
+                     rose_map: Optional[Dict[str, float]],
+                     rose_alpha: float,
+                     rose_score: Optional[float],
+                     seed: Optional[int]) -> Dict[str, str]:
+    """Run inference on GPU and return parsed channels."""
     try:
+        if seed is not None:
+            torch.manual_seed(int(seed))
+
+        # Load model
+        model = _load_model_on("auto")

+        # Setup logits processor for Rose guidance
+        logits_processor = None
+        if rose_map:
+            bias = build_bias_from_tokens(tokenizer, rose_map).to(next(model.parameters()).device)
+            eff_alpha = float(rose_alpha) * (float(rose_score) if rose_score is not None else 1.0)
+            logits_processor = [RoseGuidedLogits(bias, eff_alpha)]
+
+        # Tokenize input
+        inputs = tokenizer(full_prompt, return_tensors="pt").to(next(model.parameters()).device)

         # Generate
+        out_ids = model.generate(
+            **inputs,
+            do_sample=bool(gen_kwargs.get("do_sample", True)),
+            temperature=float(gen_kwargs.get("temperature", 0.7)),
+            top_p=float(gen_kwargs.get("top_p", 0.9)),
+            top_k=(int(gen_kwargs.get("top_k")) if gen_kwargs.get("top_k") and int(gen_kwargs.get("top_k")) > 0 else None),
+            max_new_tokens=int(gen_kwargs.get("max_new_tokens", MAX_DEF)),
+            pad_token_id=tokenizer.eos_token_id,
+            eos_token_id=tokenizer.eos_token_id,
+            logits_processor=logits_processor,
+        )

+        # Extract generated tokens only
+        prompt_len = int(inputs["input_ids"].shape[1])
+        gen_ids = out_ids[0][prompt_len:].tolist()

+        # Parse response with Harmony
+        if HARMONY_AVAILABLE:
+            channels = parse_harmony_response(gen_ids)
+        else:
+            # Fallback
+            decoded = tokenizer.decode(gen_ids, skip_special_tokens=False)
+            channels = {
+                "final": extract_final_channel_fallback(decoded),
+                "raw": decoded
+            }

+        return channels
+
     except Exception as e:
+        return {"final": f"[Error] {type(e).__name__}: {str(e)}", "raw": str(e)}
     finally:
         # Cleanup
+        try:
+            del model
+        except Exception:  # model may be unbound if loading failed
+            pass
+        gc.collect()
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
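+        # Added note (not in the original commit): under ZeroGPU, @spaces.GPU only attaches a
+        # GPU for the duration of this call (120 s requested here), so the model is loaded
+        # inside the function and released in this finally block instead of being cached at
+        # module import time.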
 
 
 
 
 

 # -----------------------
+# Gradio handlers
 # -----------------------
+
+def generate_response(message: str, history: List[List[str]], system_prompt: str,
+                      temperature: float, top_p: float, top_k: int, max_new_tokens: int,
+                      do_sample: bool, seed: Optional[int],
+                      rose_enable: bool, rose_alpha: float, rose_score: Optional[float],
+                      rose_tokens: str, rose_json: str,
+                      show_thinking: bool = False,
+                      reasoning_effort: str = "high") -> str:
+    """
+    Generate response with proper CoT handling using Harmony format.
+    """
     try:
+        # Build message list
+        messages = [{"role": "system", "content": system_prompt or SYSTEM_DEF}]

+        # Add history
+        if history:
+            for turn in history:
+                if isinstance(turn, (list, tuple)) and len(turn) >= 2:
+                    user_msg, assistant_msg = turn[0], turn[1]
+                    if user_msg:
+                        messages.append({"role": "user", "content": str(user_msg)})
+                    if assistant_msg:
+                        messages.append({"role": "assistant", "content": str(assistant_msg)})

+        # Add current message
+        messages.append({"role": "user", "content": str(message)})

+        # Create Harmony-formatted prompt
+        if HARMONY_AVAILABLE:
+            prompt = create_harmony_prompt(messages, reasoning_effort)
+        else:
+            # Fallback to tokenizer template
+            prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
+
+        # Build Rose map if enabled
+        rose_map: Optional[Dict[str, float]] = None
+        if rose_enable:
+            rose_map = {}
+            tok_str = (rose_tokens or "").strip()
+            if tok_str:
+                for p in [p.strip() for p in tok_str.split(",") if p.strip()]:
+                    if ":" in p:
+                        k, v = p.split(":", 1)
+                        try:
+                            rose_map[k.strip()] = float(v)
+                        except (TypeError, ValueError):
+                            pass
+            if rose_json:
+                try:
+                    j = json.loads(rose_json)
+                    if isinstance(j, dict):
+                        for k, v in j.items():
+                            try:
+                                rose_map[str(k)] = float(v)
+                            except (TypeError, ValueError):
+                                pass
+                except Exception:
+                    pass
+            if not rose_map:
+                rose_map = None
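+            # Added example (not in the original commit; token names are placeholders): with
+            # rose_enable=True, rose_tokens="calm:1.2, chaos:-0.8" and rose_json='{"focus": 0.5}'
+            # merge into rose_map={"calm": 1.2, "chaos": -0.8, "focus": 0.5}; invalid entries
+            # are skipped, and an empty map falls back to None (no Rose guidance).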
+
+        # Generate with model
+        channels = zerogpu_generate(
+            prompt,
+            {
+                "do_sample": bool(do_sample),
+                "temperature": float(temperature),
+                "top_p": float(top_p),
+                "top_k": int(top_k) if top_k > 0 else None,
+                "max_new_tokens": int(max_new_tokens),
+            },
+            rose_map,
+            float(rose_alpha),
+            float(rose_score) if rose_score is not None else None,
+            int(seed) if seed is not None else None,
         )

+        # Format response
         if show_thinking:
+            # Show all channels
+            response = "## Chain of Thought:\n\n"
+            for channel, content in channels.items():
+                if channel != "final" and content:
+                    response += f"### {channel.capitalize()} Channel:\n{content}\n\n"
+            response += f"### Final Response:\n{channels.get('final', 'No final response generated')}"
+            return response
         else:
+            # Just show the final response
+            return channels.get("final", "No final response generated")

     except Exception as e:
+        return f"[Error] {type(e).__name__}: {str(e)}"

 # -----------------------
+# UI
 # -----------------------
+with gr.Blocks(theme=gr.themes.Soft()) as demo:
+    gr.Markdown(
+        """
+        # Mirel – Harmony Chain-of-Thought Inference
+
+        OSS-20B model using Harmony format with thinking channels.
+        The model thinks through problems in internal channels before providing a final response.
+
+        **Note:** Install `openai-harmony` for full Harmony support: `pip install openai-harmony`
+        """
+    )
+
+    with gr.Row():
+        system_prompt = gr.Textbox(
+            label="System Prompt",
+            value=SYSTEM_DEF,
+            lines=2
+        )
+
+    with gr.Accordion("Generation Settings", open=False):
+        with gr.Row():
+            temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
+            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.01, label="Top-p")
+            top_k = gr.Slider(0, 200, value=0, step=1, label="Top-k (0=disabled)")
+        with gr.Row():
+            max_new = gr.Slider(16, 4096, value=MAX_DEF, step=16, label="Max new tokens")
+            do_sample = gr.Checkbox(value=True, label="Do sample")
+            seed = gr.Number(value=None, label="Seed (optional)", precision=0)
+        with gr.Row():
+            reasoning_effort = gr.Radio(
+                choices=["low", "medium", "high"],
+                value="high",
+                label="Reasoning Effort",
+                info="How much thinking the model should do"
             )
+            show_thinking = gr.Checkbox(
+                value=False,
+                label="Show thinking channels",
+                info="Display all internal reasoning channels"
+            )
+
+    with gr.Accordion("Rose Guidance (Optional)", open=False):
+        gr.Markdown("Fine-tune generation with token biases")
+        with gr.Row():
+            rose_enable = gr.Checkbox(value=False, label="Enable Rose bias")
+            rose_alpha = gr.Slider(0.0, 5.0, value=1.0, step=0.05, label="Alpha (strength)")
+            rose_score = gr.Slider(0.0, 1.0, value=1.0, step=0.01, label="Score multiplier")
+        rose_tokens = gr.Textbox(
+            label="Token:weight pairs",
+            placeholder="example:1.5, test:-0.5",
+            value=""
         )
+        rose_json = gr.Textbox(
+            label="JSON weights",
+            placeholder='{"token": 1.0, "another": -0.5}',
+            value=""
         )
+
+    # Chat interface - using only valid parameters
+    chat = gr.ChatInterface(
+        fn=generate_response,
+        additional_inputs=[
+            system_prompt, temperature, top_p, top_k, max_new,
+            do_sample, seed, rose_enable, rose_alpha, rose_score,
+            rose_tokens, rose_json, show_thinking, reasoning_effort
+        ],
+        title="Chat with Mirel",
+        description="A chain-of-thought model using Harmony format",
+        examples=[
+            ["Hello! Can you introduce yourself?"],
+            ["What is the capital of France?"],
+            ["Explain quantum computing in simple terms"],
+            ["Solve: If a train travels 120 miles in 2 hours, what is its average speed?"],
+        ],
+        cache_examples=False,
+    )
+
+    gr.Markdown(
+        """
+        ---
+        ### Configuration:
+        - **Model**: Set `MODEL_ID` env var (default: openai/gpt-oss-20b)
+        - **Adapter**: Set `ADAPTER_ID` and optionally `ADAPTER_SUBFOLDER`
+        - **Auth**: Set `HF_TOKEN` in Space secrets for private model access
+        - **Harmony**: Install with `pip install openai-harmony` for proper channel support

+        The model uses Harmony format with thinking channels (`thinking`, `analysis`, `final`).
+        """
+    )

 if __name__ == "__main__":
+    demo.queue(max_size=8 if ZEROGPU else 32).launch(
+        server_name="0.0.0.0",
         server_port=7860,
+        share=True
  )
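+# Added local-run sketch (not in the original commit; the token value, adapter repo, and
+# subfolder are placeholders, and a CUDA GPU is assumed when ZeroGPU is not used):
+#   HF_TOKEN=hf_xxx MODEL_ID=openai/gpt-oss-20b DTYPE=bf16 MAX_NEW_TOKENS=1024 python app.py
+#   ADAPTER_ID=<peft-repo> ADAPTER_SUBFOLDER=<subfolder> python app.py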