"""
Mirel Harmony Inference – HF Space (Gradio)
ZeroGPU-ready, Harmony formatting, MX format support for GPT-OSS-20B
Proper LoRA adapter loading and conversion for MX compatibility
Single file: app.py
"""
from __future__ import annotations
import os, gc, json, torch, warnings
from dataclasses import dataclass
from typing import List, Dict, Optional, Any, Union
from datetime import datetime
import gradio as gr
import spaces # required for ZeroGPU
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
# Suppress warnings about MX format
warnings.filterwarnings("ignore", message=".*microscaling.*")
warnings.filterwarnings("ignore", message=".*mx.*")
# Import Harmony components
try:
from openai_harmony import (
Author,
Conversation,
HarmonyEncodingName,
Message,
Role,
SystemContent,
DeveloperContent,
load_harmony_encoding,
ReasoningEffort
)
HARMONY_AVAILABLE = True
except ImportError:
print("[WARNING] openai_harmony not installed. Install with: pip install openai-harmony")
HARMONY_AVAILABLE = False
# -----------------------
# Config & runtime modes
# -----------------------
# MX (microscaling) models use special dtypes - we need to handle them explicitly
MODEL_ID = os.getenv("MODEL_ID", "openai/gpt-oss-20b")
ADAPTER_ID = os.getenv("ADAPTER_ID", "AbstractPhil/mirel-gpt-oss-20b") # Default to your adapter
ADAPTER_SUBFOLDER = os.getenv("ADAPTER_SUBFOLDER", "checkpoints/checkpoint-516") # Default to the subfolder
ATTN_IMPL = os.getenv("ATTN_IMPL", "eager")
SYSTEM_DEF = os.getenv("SYSTEM_PROMPT", "You are Mirel, a memory-stable symbolic assistant.")
MAX_DEF = int(os.getenv("MAX_NEW_TOKENS", "256"))
ZEROGPU = os.getenv("ZEROGPU", os.getenv("ZERO_GPU", "0")) == "1"
# For GPT-OSS models, we need specific handling
IS_GPT_OSS = "gpt-oss" in MODEL_ID.lower()
USE_MX_FORMAT = os.getenv("USE_MX_FORMAT", "1" if IS_GPT_OSS else "0") == "1"
# Harmony channels for CoT
REQUIRED_CHANNELS = ["analysis", "commentary", "final"]
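# Per the Harmony spec, "analysis" carries raw chain-of-thought, "commentary"
# carries tool-call/meta output, and "final" is the user-facing answer; only
# the "final" channel should be shown to end users by default.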
# HF Auth
HF_TOKEN: Optional[str] = (
os.getenv("HF_TOKEN")
or os.getenv("HUGGING_FACE_HUB_TOKEN")
or os.getenv("HUGGINGFACEHUB_API_TOKEN")
or os.getenv("HF_ACCESS_TOKEN")
)
def _hf_login() -> None:
"""Login to HF Hub using common env secret names."""
if HF_TOKEN:
try:
from huggingface_hub import login, whoami
login(token=HF_TOKEN, add_to_git_credential=True)
try:
who = whoami(token=HF_TOKEN)
print(f"[HF Auth] Logged in as: {who.get('name') or who.get('fullname') or who.get('id', 'unknown')}")
except Exception:
print("[HF Auth] Login successful but couldn't get user info")
except Exception as e:
print(f"[HF Auth] Login failed: {e}")
else:
print("[HF Auth] No token found in environment variables")
# Login before loading any models
_hf_login()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Load Harmony encoding if available
if HARMONY_AVAILABLE:
harmony_encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
else:
harmony_encoding = None
# Stop tokens per Harmony spec: <|return|> (200002), <|call|> (200012)
HARMONY_STOP_IDS = harmony_encoding.stop_tokens_for_assistant_actions() if HARMONY_AVAILABLE else []
# Tokenizer is lightweight; load once
try:
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, token=HF_TOKEN)
print(f"[Model] Successfully loaded tokenizer from {MODEL_ID}")
except Exception as e:
print(f"[Model] Failed to load tokenizer: {e}")
raise
# -----------------------
# PEFT and MX Format Support
# -----------------------
try:
    from peft import PeftModel, PeftConfig, set_peft_model_state_dict
    _HAS_PEFT = True
except Exception:
    _HAS_PEFT = False
    print("[Warning] PEFT not available. Install with: pip install peft")
# Try to import microscaling support if available
try:
import msamp
_HAS_MSAMP = True
print("[Info] Microsoft AMP (msamp) available for MX format support")
except ImportError:
_HAS_MSAMP = False
print("[Info] msamp not available - using fallback MX handling")
# -----------------------
# MX Format Conversion
# -----------------------
def convert_fp32_lora_to_mx_compatible(lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
Convert fp32 LoRA weights to be compatible with MX format base model.
MX models expect specific dtype handling.
"""
converted = {}
for key, tensor in lora_state_dict.items():
if tensor is None:
converted[key] = tensor
continue
# LoRA weights (lora_A, lora_B) need special handling
if 'lora_' in key:
# For MX compatibility, we keep weights in fp32 but ensure proper scaling
# MX format internally handles quantization, we just need clean fp32 inputs
if tensor.dtype != torch.float32:
tensor = tensor.to(torch.float32)
# Ensure weights are in reasonable range for MX quantization
# MX format works best with weights in [-1, 1] range
if 'lora_A' in key:
                # Input projection - expected to stay near its small init scale
std = 1.0 / torch.sqrt(torch.tensor(tensor.shape[1], dtype=torch.float32))
if tensor.std() > std * 10: # If weights are too large
print(f"[MX Convert] Scaling down {key} from std={tensor.std():.4f} to {std:.4f}")
tensor = tensor * (std / tensor.std())
elif 'lora_B' in key:
# Output projection - should be near zero initially
if tensor.abs().max() > 0.1:
print(f"[MX Convert] Scaling down {key} max={tensor.abs().max():.4f}")
tensor = tensor * 0.01
converted[key] = tensor
else:
# Non-LoRA weights (like embeddings) stay as-is
converted[key] = tensor
return converted
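# Illustrative example of the conversion heuristics (hypothetical shapes/values):
#   sd = {"lora_A.weight": torch.randn(16, 4096) * 5.0,   # std far above 1/sqrt(4096)
#         "lora_B.weight": torch.randn(4096, 16) * 2.0}   # abs max far above 0.1
#   sd = convert_fp32_lora_to_mx_compatible(sd)
#   # lora_A is rescaled toward std ~= 1/sqrt(in_features); lora_B toward ~zero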
def prepare_model_for_mx_lora(model, adapter_path: str, subfolder: Optional[str] = None):
"""
Prepare and attach LoRA adapter to MX format model.
Handles the special requirements of GPT-OSS MX models.
"""
if not _HAS_PEFT:
raise RuntimeError("PEFT is required for LoRA adapters. Install with: pip install peft")
    # Log where the adapter is being loaded from
    if subfolder:
        print(f"[LoRA] Loading adapter from {adapter_path} (subfolder: {subfolder})")
    else:
        print(f"[LoRA] Loading adapter from {adapter_path}")
    # Load the LoRA config; not used directly below, but this validates that
    # an adapter_config.json actually exists at the given location
    peft_kwargs = {"token": HF_TOKEN}
    if subfolder:
        peft_kwargs["subfolder"] = subfolder
    PeftConfig.from_pretrained(adapter_path, **peft_kwargs)
# Load the LoRA weights - need to check in the right location
from safetensors.torch import load_file
import os.path as osp
from huggingface_hub import hf_hub_download
try:
# Try to download from HF Hub with subfolder
if subfolder:
# Download the adapter weights file
try:
adapter_weights_path = hf_hub_download(
repo_id=adapter_path,
filename="adapter_model.safetensors",
subfolder=subfolder,
token=HF_TOKEN
)
adapter_weights = load_file(adapter_weights_path)
print(f"[LoRA] Loaded safetensors weights from {subfolder}")
except Exception:
# Try .bin format
adapter_weights_path = hf_hub_download(
repo_id=adapter_path,
filename="adapter_model.bin",
subfolder=subfolder,
token=HF_TOKEN
)
adapter_weights = torch.load(adapter_weights_path, map_location="cpu")
print(f"[LoRA] Loaded bin weights from {subfolder}")
else:
# No subfolder - try local path first, then HF Hub
local_safetensors = osp.join(adapter_path, "adapter_model.safetensors")
local_bin = osp.join(adapter_path, "adapter_model.bin")
if osp.exists(local_safetensors):
adapter_weights = load_file(local_safetensors)
print("[LoRA] Loaded local safetensors weights")
elif osp.exists(local_bin):
adapter_weights = torch.load(local_bin, map_location="cpu")
print("[LoRA] Loaded local bin weights")
else:
# Try downloading from HF Hub
try:
adapter_weights_path = hf_hub_download(
repo_id=adapter_path,
filename="adapter_model.safetensors",
token=HF_TOKEN
)
adapter_weights = load_file(adapter_weights_path)
print("[LoRA] Downloaded safetensors weights from Hub")
except Exception:
adapter_weights_path = hf_hub_download(
repo_id=adapter_path,
filename="adapter_model.bin",
token=HF_TOKEN
)
adapter_weights = torch.load(adapter_weights_path, map_location="cpu")
print("[LoRA] Downloaded bin weights from Hub")
    except Exception as e:
        raise FileNotFoundError(f"Could not load adapter weights: {e}") from e
# Convert weights for MX compatibility
print("[LoRA] Converting fp32 weights for MX format compatibility...")
adapter_weights = convert_fp32_lora_to_mx_compatible(adapter_weights)
# Create PEFT model with special handling for MX
print("[LoRA] Attaching LoRA to base model...")
# For MX models, we need to be careful about dtype
# The base model uses MX format internally, but the interface should be fp32
model = PeftModel.from_pretrained(
model,
adapter_path,
is_trainable=False,
**peft_kwargs # This includes token and subfolder
)
    # Manually apply the converted adapter weights. set_peft_model_state_dict
    # remaps raw checkpoint keys (e.g. inserting the ".default" adapter name),
    # whereas a plain load_state_dict(strict=False) would silently skip any
    # mismatched keys and leave the adapter at its freshly-loaded values.
    set_peft_model_state_dict(model, adapter_weights)
print("[LoRA] Successfully attached LoRA adapter with MX compatibility")
return model
# -----------------------
# Model loading with MX support
# -----------------------
def _build_model_kwargs(device_map: Optional[str]) -> Dict[str, Any]:
"""Build kwargs for model loading with MX format support."""
kw: Dict[str, Any] = dict(
device_map=device_map,
trust_remote_code=True,
low_cpu_mem_usage=True,
token=HF_TOKEN,
)
if IS_GPT_OSS and USE_MX_FORMAT:
# GPT-OSS models use MX format
# Don't specify torch_dtype - let the model use its native MX format
print("[Model] Using MX format for GPT-OSS model")
kw.update({
"attn_implementation": ATTN_IMPL if device_map != "cpu" else "eager",
# MX models handle their own dtype internally
# Don't force a dtype here
})
else:
# Non-MX models
kw.update({
"torch_dtype": torch.float16, # Use fp16 for non-MX models
"attn_implementation": ATTN_IMPL if device_map != "cpu" else "eager",
})
return kw
def _load_model_on(device_map: Optional[str]) -> AutoModelForCausalLM:
"""Load model with proper MX format handling."""
print(f"[Model] Loading base model from {MODEL_ID}...")
# Load config first to check for MX format
config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True, token=HF_TOKEN)
# Check if this is an MX model
    is_mx_model = (
        IS_GPT_OSS
        or (hasattr(config, 'quantization_config') and 'mx' in str(config.quantization_config).lower())
        or (hasattr(config, 'torch_dtype') and 'mx' in str(config.torch_dtype).lower())
    )
if is_mx_model:
print("[Model] Detected MX format model - using special loading")
# For MX models, we need special handling
# The model internally uses MX quantization
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
config=config,
trust_remote_code=True,
device_map=device_map,
low_cpu_mem_usage=True,
token=HF_TOKEN,
# Let the model handle its own dtype
attn_implementation=ATTN_IMPL if device_map != "cpu" else "eager",
)
# Verify the model loaded correctly
print(f"[Model] Model dtype: {next(model.parameters()).dtype}")
print(f"[Model] Model device: {next(model.parameters()).device}")
else:
# Standard model loading
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **_build_model_kwargs(device_map))
# Load and attach LoRA adapter if specified
if ADAPTER_ID:
try:
if is_mx_model:
# Use special MX-compatible LoRA loading with subfolder support
model = prepare_model_for_mx_lora(model, ADAPTER_ID, ADAPTER_SUBFOLDER)
else:
# Standard PEFT loading for non-MX models
if not _HAS_PEFT:
raise RuntimeError("PEFT is required when ADAPTER_ID is set.")
print(f"[Model] Loading adapter from {ADAPTER_ID} (standard mode)...")
peft_kwargs = {"token": HF_TOKEN, "is_trainable": False}
if ADAPTER_SUBFOLDER:
peft_kwargs["subfolder"] = ADAPTER_SUBFOLDER
print(f"[Model] Using subfolder: {ADAPTER_SUBFOLDER}")
model = PeftModel.from_pretrained(
model,
ADAPTER_ID,
**peft_kwargs
)
print("[Model] Successfully loaded with LoRA adapter")
# Optionally merge adapter for better performance
merge_adapter = os.getenv("MERGE_ADAPTER", "0") == "1"
if merge_adapter and hasattr(model, 'merge_and_unload'):
print("[Model] Merging adapter into base model...")
model = model.merge_and_unload()
print("[Model] Adapter merged successfully")
except Exception as e:
print(f"[Error] Failed to load adapter: {e}")
print("[Warning] Continuing with base model only")
model.eval()
# Ensure proper config
if getattr(model.config, "pad_token_id", None) is None:
model.config.pad_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id
model.config.use_cache = True
print(f"[Model] Model loaded successfully - Type: {'MX Format' if is_mx_model else 'Standard'}")
return model
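# Example Space configuration via environment variables (all optional):
#   MODEL_ID=openai/gpt-oss-20b
#   ADAPTER_ID=AbstractPhil/mirel-gpt-oss-20b
#   ADAPTER_SUBFOLDER=checkpoints/checkpoint-516
#   MERGE_ADAPTER=1          # merge the LoRA into the base weights after loading
#   MAX_NEW_TOKENS=256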
# -----------------------
# Harmony formatting
# -----------------------
def create_harmony_prompt(messages: List[Dict[str, str]], reasoning_effort: str = "high") -> Any:
"""Build a Harmony-formatted prompt."""
if HARMONY_AVAILABLE and harmony_encoding is not None:
effort_map = {"low": ReasoningEffort.LOW, "medium": ReasoningEffort.MEDIUM, "high": ReasoningEffort.HIGH}
effort = effort_map.get(str(reasoning_effort).lower(), ReasoningEffort.HIGH)
system_content = (
SystemContent.new()
.with_model_identity("You are ChatGPT, a large language model trained by OpenAI.")
.with_reasoning_effort(effort)
.with_conversation_start_date(datetime.now().strftime("%Y-%m-%d"))
.with_knowledge_cutoff("2024-06")
.with_required_channels(REQUIRED_CHANNELS)
)
sys_text = SYSTEM_DEF
rest: List[Dict[str, str]] = messages or []
if rest and rest[0].get("role") == "system":
sys_text = rest[0].get("content") or SYSTEM_DEF
rest = rest[1:]
harmony_messages = [Message.from_role_and_content(Role.SYSTEM, system_content)]
dev = DeveloperContent.new().with_instructions(sys_text)
harmony_messages.append(Message.from_role_and_content(Role.DEVELOPER, dev))
for m in rest:
            role = m.get("role")
            content = m.get("content", "")
if role == "user":
harmony_messages.append(Message.from_role_and_content(Role.USER, content))
elif role == "assistant":
harmony_messages.append(
Message.from_role_and_content(Role.ASSISTANT, content).with_channel("final")
)
convo = Conversation.from_messages(harmony_messages)
return harmony_encoding.render_conversation_for_completion(convo, Role.ASSISTANT)
# Fallback: tokenizer chat template
if not messages or messages[0].get("role") != "system":
messages = [{"role": "system", "content": SYSTEM_DEF}] + (messages or [])
return tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
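# For reference, the rendered Harmony prompt decodes to roughly this layout
# (simplified sketch; the real system message also encodes identity, reasoning
# effort, dates, and valid channels):
#   <|start|>system<|message|>...<|end|>
#   <|start|>developer<|message|># Instructions\n{system prompt}<|end|>
#   <|start|>user<|message|>{user message}<|end|>
#   <|start|>assistant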
def parse_harmony_response(tokens: List[int]) -> Dict[str, str]:
"""Parse response tokens using Harmony format to extract channels."""
if not HARMONY_AVAILABLE:
text = tokenizer.decode(tokens, skip_special_tokens=False)
return {"final": extract_final_channel_fallback(text), "raw": text}
parsed_messages = harmony_encoding.parse_messages_from_completion_tokens(tokens, Role.ASSISTANT)
    channels = {}
    for msg in parsed_messages:
        channel = getattr(msg, "channel", None) or "final"
        parts = msg.content if isinstance(msg.content, list) else [msg.content]
        text = "".join(getattr(part, "text", str(part)) for part in parts)
        channels[channel] = channels.get(channel, "") + text
if "final" not in channels:
channels["final"] = " ".join(channels.values())
return channels
def extract_final_channel_fallback(text: str) -> str:
"""Extract the <final> channel from decoded Harmony text."""
try:
chunks: Dict[str, str] = {}
pieces = text.split("<|channel|>")
for seg in pieces[1:]:
name_end = seg.find("<|message|>")
if name_end <= 0:
continue
ch = seg[:name_end].strip()
body_start = name_end + len("<|message|>")
next_pos = len(seg)
for delim in ("<|channel|>", "<|end|>", "<|return|>"):
p = seg.find(delim, body_start)
if p != -1:
next_pos = min(next_pos, p)
body = seg[body_start:next_pos]
chunks[ch] = chunks.get(ch, "") + body
final_txt = (chunks.get("final", "").strip())
if final_txt:
return final_txt
if "<|channel|>final<|message|>" in text:
tail = text.split("<|channel|>final<|message|>")[-1]
for delim in ("<|return|>", "<|end|>", "<|channel|>"):
idx = tail.find(delim)
if idx != -1:
tail = tail[:idx]
break
return tail.strip()
except Exception:
pass
return text.strip()
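# Example on hypothetical decoded text:
#   extract_final_channel_fallback(
#       "<|channel|>analysis<|message|>thinking...<|end|>"
#       "<|channel|>final<|message|>Paris.<|return|>"
#   )  # -> "Paris."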
# -----------------------
# Rose guidance
# -----------------------
def build_bias_from_tokens(tokenizer, mapping: Dict[str, float]) -> torch.Tensor:
"""Create vocab bias from {token: weight}."""
vocab_size = len(tokenizer)
bias = torch.zeros(vocab_size, dtype=torch.float32)
for tok, w in mapping.items():
if tok is None:
continue
tid = tokenizer.convert_tokens_to_ids(tok)
if isinstance(tid, list):
for t in tid:
if isinstance(t, int) and t >= 0:
bias[t] += float(w) / max(1, len(tid))
elif isinstance(tid, int) and tid >= 0:
bias[tid] += float(w)
return bias
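# Example (keys must be vocabulary *tokens*; words that tokenize into several
# pieces, or tokens carrying a leading-space marker, won't match as-is):
#   bias = build_bias_from_tokens(tokenizer, {"Paris": 2.0, "London": -1.0})
#   bias.shape  # -> torch.Size([len(tokenizer)])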
class RoseGuidedLogits(torch.nn.Module):
def __init__(self, bias_vec: torch.Tensor, alpha: float = 1.0):
super().__init__()
self.bias_vec = bias_vec
self.alpha = float(alpha)
def forward(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
return scores + self.alpha * self.bias_vec.to(scores.device)
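# Usage: pass an instance via generate(..., logits_processor=[RoseGuidedLogits(bias, alpha)]).
# The bias vector is added to the raw logits at every decoding step, so positive
# weights make their tokens more likely and negative weights suppress them.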
# -----------------------
# Generation
# -----------------------
@spaces.GPU(duration=120)
def zerogpu_generate(full_prompt,
gen_kwargs: Dict[str, Any],
rose_map: Optional[Dict[str, float]],
rose_alpha: float,
rose_score: Optional[float],
seed: Optional[int]) -> Dict[str, str]:
"""Run inference on GPU with MX format support."""
try:
if seed is not None:
torch.manual_seed(int(seed))
# Load model with MX support
model = _load_model_on("auto")
# Setup logits processor for Rose guidance
logits_processor = None
if rose_map:
bias = build_bias_from_tokens(tokenizer, rose_map).to(next(model.parameters()).device)
eff_alpha = float(rose_alpha) * (float(rose_score) if rose_score is not None else 1.0)
logits_processor = [RoseGuidedLogits(bias, eff_alpha)]
# Prepare inputs
device = next(model.parameters()).device
if HARMONY_AVAILABLE and isinstance(full_prompt, list):
input_ids = torch.tensor([full_prompt], dtype=torch.long, device=device)
attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=device)
inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
prompt_len = input_ids.shape[1]
else:
enc = tokenizer(full_prompt, return_tensors="pt")
inputs = enc.to(device)
prompt_len = int(inputs["input_ids"].shape[1])
if "attention_mask" not in inputs:
inputs["attention_mask"] = torch.ones_like(inputs["input_ids"], dtype=torch.long, device=device)
# Generate
eos_ids = HARMONY_STOP_IDS if HARMONY_AVAILABLE else tokenizer.eos_token_id
out_ids = model.generate(
**inputs,
do_sample=bool(gen_kwargs.get("do_sample", True)),
temperature=float(gen_kwargs.get("temperature", 0.7)),
top_p=float(gen_kwargs.get("top_p", 0.9)),
top_k=(int(gen_kwargs.get("top_k")) if gen_kwargs.get("top_k") and int(gen_kwargs.get("top_k")) > 0 else None),
max_new_tokens=int(gen_kwargs.get("max_new_tokens", MAX_DEF)),
pad_token_id=model.config.pad_token_id,
eos_token_id=eos_ids,
logits_processor=logits_processor,
repetition_penalty=float(gen_kwargs.get("repetition_penalty", 1.1)),
no_repeat_ngram_size=int(gen_kwargs.get("no_repeat_ngram_size", 6)),
min_new_tokens=1,
)
# Extract generated tokens
out_list = out_ids[0].tolist()
gen_ids = out_list[prompt_len:]
# Truncate at stop tokens
if HARMONY_AVAILABLE:
for sid in HARMONY_STOP_IDS:
if sid in gen_ids:
gen_ids = gen_ids[:gen_ids.index(sid)]
break
# Parse response
if HARMONY_AVAILABLE:
try:
channels = parse_harmony_response(gen_ids)
except Exception:
decoded = tokenizer.decode(gen_ids, skip_special_tokens=False)
channels = {
"final": extract_final_channel_fallback(decoded),
"raw": decoded
}
else:
decoded = tokenizer.decode(gen_ids, skip_special_tokens=False)
channels = {
"final": extract_final_channel_fallback(decoded),
"raw": decoded
}
return channels
except Exception as e:
import traceback
error_trace = traceback.format_exc()
print(f"[Error] Generation failed:\n{error_trace}")
return {"final": f"[Error] {type(e).__name__}: {str(e)}", "raw": error_trace}
finally:
# Cleanup
        # 'model' may never have been assigned if loading failed early
        try:
            del model
        except NameError:
            pass
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
# -----------------------
# Gradio handlers
# -----------------------
def generate_response(message: str, history: List[Dict[str, str]], system_prompt: str,
                      temperature: float, top_p: float, top_k: int, max_new_tokens: int,
                      do_sample: bool, seed: Optional[int],
                      rose_enable: bool, rose_alpha: float, rose_score: Optional[float],
                      rose_tokens: str, rose_json: str,
                      show_thinking: bool = False,
                      reasoning_effort: str = "high") -> str:
"""Generate response with CoT handling."""
try:
# Build messages
messages = [{"role": "system", "content": system_prompt or SYSTEM_DEF}]
        if history:
            for turn in history:
                # gr.ChatInterface(type="messages") passes history entries as
                # {"role": ..., "content": ...} dicts; older tuple-style
                # [user, assistant] pairs are handled as a fallback.
                if isinstance(turn, dict):
                    role, content = turn.get("role"), turn.get("content")
                    if role in ("user", "assistant") and content:
                        messages.append({"role": role, "content": str(content)})
                elif isinstance(turn, (list, tuple)) and len(turn) >= 2:
                    user_msg, assistant_msg = turn[0], turn[1]
                    if user_msg:
                        messages.append({"role": "user", "content": str(user_msg)})
                    if assistant_msg:
                        messages.append({"role": "assistant", "content": str(assistant_msg)})
messages.append({"role": "user", "content": str(message)})
# Create prompt
if HARMONY_AVAILABLE:
prompt = create_harmony_prompt(messages, reasoning_effort)
else:
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
# Build Rose map
rose_map: Optional[Dict[str, float]] = None
if rose_enable:
rose_map = {}
tok_str = (rose_tokens or "").strip()
if tok_str:
for p in [p.strip() for p in tok_str.split(",") if p.strip()]:
if ":" in p:
k, v = p.split(":", 1)
                        try:
                            rose_map[k.strip()] = float(v)
                        except ValueError:
                            pass
            if rose_json:
                try:
                    j = json.loads(rose_json)
                    if isinstance(j, dict):
                        for k, v in j.items():
                            try:
                                rose_map[str(k)] = float(v)
                            except (TypeError, ValueError):
                                pass
                except json.JSONDecodeError:
                    pass
if not rose_map:
rose_map = None
# Generate
channels = zerogpu_generate(
prompt,
{
"do_sample": bool(do_sample),
"temperature": float(temperature),
"top_p": float(top_p),
"top_k": int(top_k) if top_k > 0 else None,
"max_new_tokens": int(max_new_tokens),
"repetition_penalty": 1.1,
"no_repeat_ngram_size": 6,
},
rose_map,
float(rose_alpha),
float(rose_score) if rose_score is not None else None,
int(seed) if seed is not None else None,
)
# Format response
if show_thinking:
response = "## Chain of Thought:\n\n"
for channel, content in channels.items():
if channel != "final" and content:
response += f"### {channel.capitalize()} Channel:\n{content}\n\n"
response += f"### Final Response:\n{channels.get('final', 'No final response generated')}"
return response
else:
return channels.get("final", "No final response generated")
except Exception as e:
import traceback
return f"[Error] {type(e).__name__}: {str(e)}\n{traceback.format_exc()}"
# -----------------------
# UI
# -----------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown(
f"""
# Mirel – Harmony Chain-of-Thought Inference
**Model**: {MODEL_ID} {'(MX Format)' if USE_MX_FORMAT else ''}
**Adapter**: {ADAPTER_ID or 'None'}
**Status**: {'✅ Harmony Available' if HARMONY_AVAILABLE else '⚠️ Harmony Not Installed'}
The model uses internal thinking channels before providing final responses.
"""
)
with gr.Row():
system_prompt = gr.Textbox(
label="System Prompt",
value=SYSTEM_DEF,
lines=2
)
with gr.Accordion("Generation Settings", open=False):
with gr.Row():
temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.01, label="Top-p")
top_k = gr.Slider(0, 200, value=0, step=1, label="Top-k (0=disabled)")
with gr.Row():
max_new = gr.Slider(16, 4096, value=MAX_DEF, step=16, label="Max new tokens")
do_sample = gr.Checkbox(value=True, label="Do sample")
seed = gr.Number(value=None, label="Seed (optional)", precision=0)
with gr.Row():
reasoning_effort = gr.Radio(
choices=["low", "medium", "high"],
value="high",
label="Reasoning Effort",
info="How much thinking the model should do"
)
show_thinking = gr.Checkbox(
value=False,
label="Show thinking channels",
info="Display all internal reasoning channels"
)
with gr.Accordion("Rose Guidance (Optional)", open=False):
gr.Markdown("Fine-tune generation with token biases")
with gr.Row():
rose_enable = gr.Checkbox(value=False, label="Enable Rose bias")
rose_alpha = gr.Slider(0.0, 5.0, value=1.0, step=0.05, label="Alpha (strength)")
rose_score = gr.Slider(0.0, 1.0, value=1.0, step=0.01, label="Score multiplier")
rose_tokens = gr.Textbox(
label="Token:weight pairs",
placeholder="example:1.5, test:-0.5",
value=""
)
rose_json = gr.Textbox(
label="JSON weights",
placeholder='{"token": 1.0, "another": -0.5}',
value=""
)
# Chat interface
chat = gr.ChatInterface(
fn=generate_response,
type="messages",
additional_inputs=[
system_prompt, temperature, top_p, top_k, max_new,
do_sample, seed, rose_enable, rose_alpha, rose_score,
rose_tokens, rose_json, show_thinking, reasoning_effort
],
title="Chat with Mirel",
description="Chain-of-thought model with MX format support",
examples=[
["Hello! Can you introduce yourself?"],
["What is the capital of France?"],
["Explain quantum computing in simple terms"],
["Solve: If a train travels 120 miles in 2 hours, what is its average speed?"],
],
cache_examples=False,
)
gr.Markdown(
"""
---
### Configuration:
- **MX Format**: Automatically detected for GPT-OSS models
- **LoRA Support**: fp32 LoRA adapters are converted for MX compatibility
- **Merge Adapter**: Set `MERGE_ADAPTER=1` to merge LoRA into base model
- **Auth**: Set `HF_TOKEN` in Space secrets for private model access
"""
)
if __name__ == "__main__":
demo.queue(max_size=8 if ZEROGPU else 32).launch(
server_name="0.0.0.0",
server_port=7860,
share=False
)