""" | |
Mirel Harmony Inference – HF Space (Gradio) | |
ZeroGPU-ready, Harmony formatting, MX format support for GPT-OSS-20B | |
Proper LoRA adapter loading and conversion for MX compatibility | |
Single file: app.py | |
""" | |
from __future__ import annotations
import os, gc, json, threading, torch, warnings
from dataclasses import dataclass
from typing import List, Dict, Optional, Any, Union
from datetime import datetime
import gradio as gr
import spaces  # required for ZeroGPU
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import numpy as np

# Suppress warnings about MX format
warnings.filterwarnings("ignore", message=".*microscaling.*")
warnings.filterwarnings("ignore", message=".*mx.*")
# Import Harmony components
try:
    from openai_harmony import (
        Author,
        Conversation,
        HarmonyEncodingName,
        Message,
        Role,
        SystemContent,
        DeveloperContent,
        load_harmony_encoding,
        ReasoningEffort,
    )
    HARMONY_AVAILABLE = True
except ImportError:
    print("[WARNING] openai_harmony not installed. Install with: pip install openai-harmony")
    HARMONY_AVAILABLE = False
# -----------------------
# Config & runtime modes
# -----------------------
# MX format uses special dtypes - we need to handle this properly
MODEL_ID = os.getenv("MODEL_ID", "openai/gpt-oss-20b")
ADAPTER_ID = os.getenv("ADAPTER_ID", "AbstractPhil/mirel-gpt-oss-20b")  # Default to your adapter
ADAPTER_SUBFOLDER = os.getenv("ADAPTER_SUBFOLDER", "checkpoints/checkpoint-516")  # Default to the subfolder
ATTN_IMPL = os.getenv("ATTN_IMPL", "eager")
SYSTEM_DEF = os.getenv("SYSTEM_PROMPT", "You are Mirel, a memory-stable symbolic assistant.")
MAX_DEF = int(os.getenv("MAX_NEW_TOKENS", "256"))
ZEROGPU = os.getenv("ZEROGPU", os.getenv("ZERO_GPU", "0")) == "1"

# For GPT-OSS models, we need specific handling
IS_GPT_OSS = "gpt-oss" in MODEL_ID.lower()
USE_MX_FORMAT = os.getenv("USE_MX_FORMAT", "1" if IS_GPT_OSS else "0") == "1"
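
# Example Space configuration via environment variables (illustrative; these simply mirror
# the defaults above and the MERGE_ADAPTER flag consumed during model loading below):
#   MODEL_ID=openai/gpt-oss-20b
#   ADAPTER_ID=AbstractPhil/mirel-gpt-oss-20b
#   ADAPTER_SUBFOLDER=checkpoints/checkpoint-516
#   MERGE_ADAPTER=1          # merge the LoRA into the base weights after loading
#   MAX_NEW_TOKENS=256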

# Harmony channels for CoT
REQUIRED_CHANNELS = ["analysis", "commentary", "final"]

# HF Auth
HF_TOKEN: Optional[str] = (
    os.getenv("HF_TOKEN")
    or os.getenv("HUGGING_FACE_HUB_TOKEN")
    or os.getenv("HUGGINGFACEHUB_API_TOKEN")
    or os.getenv("HF_ACCESS_TOKEN")
)
def _hf_login() -> None:
    """Login to HF Hub using common env secret names."""
    if HF_TOKEN:
        try:
            from huggingface_hub import login, whoami
            login(token=HF_TOKEN, add_to_git_credential=True)
            try:
                who = whoami(token=HF_TOKEN)
                print(f"[HF Auth] Logged in as: {who.get('name') or who.get('fullname') or who.get('id', 'unknown')}")
            except Exception:
                print("[HF Auth] Login successful but couldn't get user info")
        except Exception as e:
            print(f"[HF Auth] Login failed: {e}")
    else:
        print("[HF Auth] No token found in environment variables")

# Login before loading any models
_hf_login()
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Load Harmony encoding if available
if HARMONY_AVAILABLE:
    harmony_encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
else:
    harmony_encoding = None

# Stop tokens per Harmony spec: <|return|> (200002), <|call|> (200012)
HARMONY_STOP_IDS = harmony_encoding.stop_tokens_for_assistant_actions() if HARMONY_AVAILABLE else []

# Tokenizer is lightweight; load once
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, token=HF_TOKEN)
    print(f"[Model] Successfully loaded tokenizer from {MODEL_ID}")
except Exception as e:
    print(f"[Model] Failed to load tokenizer: {e}")
    raise
# -----------------------
# PEFT and MX Format Support
# -----------------------
try:
    from peft import PeftModel, PeftConfig, LoraConfig, get_peft_model
    _HAS_PEFT = True
except Exception:
    _HAS_PEFT = False
    print("[Warning] PEFT not available. Install with: pip install peft")

# Try to import microscaling support if available
try:
    import msamp
    _HAS_MSAMP = True
    print("[Info] Microsoft AMP (msamp) available for MX format support")
except ImportError:
    _HAS_MSAMP = False
    print("[Info] msamp not available - using fallback MX handling")
# -----------------------
# MX Format Conversion
# -----------------------
def convert_fp32_lora_to_mx_compatible(lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Convert fp32 LoRA weights to be compatible with an MX format base model.
    MX models expect specific dtype handling.
    """
    converted = {}
    for key, tensor in lora_state_dict.items():
        if tensor is None:
            converted[key] = tensor
            continue
        # LoRA weights (lora_A, lora_B) need special handling
        if 'lora_' in key:
            # For MX compatibility, keep weights in fp32 but ensure proper scaling.
            # MX format handles quantization internally; we just need clean fp32 inputs.
            if tensor.dtype != torch.float32:
                tensor = tensor.to(torch.float32)
            # Keep weights in a reasonable range for MX quantization
            # (MX format works best with weights in [-1, 1]).
            if 'lora_A' in key:
                # Input projection - expected to have small values
                std = 1.0 / torch.sqrt(torch.tensor(tensor.shape[1], dtype=torch.float32))
                if tensor.std() > std * 10:  # If weights are too large
                    print(f"[MX Convert] Scaling down {key} from std={tensor.std():.4f} to {std:.4f}")
                    tensor = tensor * (std / tensor.std())
            elif 'lora_B' in key:
                # Output projection - should be near zero initially
                if tensor.abs().max() > 0.1:
                    print(f"[MX Convert] Scaling down {key} max={tensor.abs().max():.4f}")
                    tensor = tensor * 0.01
            converted[key] = tensor
        else:
            # Non-LoRA weights (like embeddings) stay as-is
            converted[key] = tensor
    return converted
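
# Worked example of the lora_A guard above (hypothetical shapes, not real checkpoint values):
# for a (16, 4096) lora_A matrix, std = 1/sqrt(4096) ≈ 0.0156, so any tensor whose measured
# std exceeds 10 * 0.0156 ≈ 0.156 (say, 0.3) is rescaled back down to roughly 0.0156.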
def prepare_model_for_mx_lora(model, adapter_path: str, subfolder: Optional[str] = None):
    """
    Prepare and attach a LoRA adapter to an MX format model.
    Handles the special requirements of GPT-OSS MX models.
    """
    if not _HAS_PEFT:
        raise RuntimeError("PEFT is required for LoRA adapters. Install with: pip install peft")

    if subfolder:
        print(f"[LoRA] Loading adapter from {adapter_path} (subfolder: {subfolder})")
    else:
        print(f"[LoRA] Loading adapter from {adapter_path}")

    # Load the LoRA config with subfolder support
    peft_kwargs = {"token": HF_TOKEN}
    if subfolder:
        peft_kwargs["subfolder"] = subfolder
    peft_config = PeftConfig.from_pretrained(adapter_path, **peft_kwargs)

    # Load the LoRA weights - need to check the right location
    from safetensors.torch import load_file
    import os.path as osp
    from huggingface_hub import hf_hub_download

    try:
        if subfolder:
            # Download the adapter weights file from the Hub subfolder
            try:
                adapter_weights_path = hf_hub_download(
                    repo_id=adapter_path,
                    filename="adapter_model.safetensors",
                    subfolder=subfolder,
                    token=HF_TOKEN,
                )
                adapter_weights = load_file(adapter_weights_path)
                print(f"[LoRA] Loaded safetensors weights from {subfolder}")
            except Exception:
                # Fall back to the .bin format
                adapter_weights_path = hf_hub_download(
                    repo_id=adapter_path,
                    filename="adapter_model.bin",
                    subfolder=subfolder,
                    token=HF_TOKEN,
                )
                adapter_weights = torch.load(adapter_weights_path, map_location="cpu")
                print(f"[LoRA] Loaded bin weights from {subfolder}")
        else:
            # No subfolder - try a local path first, then the HF Hub
            local_safetensors = osp.join(adapter_path, "adapter_model.safetensors")
            local_bin = osp.join(adapter_path, "adapter_model.bin")
            if osp.exists(local_safetensors):
                adapter_weights = load_file(local_safetensors)
                print("[LoRA] Loaded local safetensors weights")
            elif osp.exists(local_bin):
                adapter_weights = torch.load(local_bin, map_location="cpu")
                print("[LoRA] Loaded local bin weights")
            else:
                # Try downloading from the HF Hub
                try:
                    adapter_weights_path = hf_hub_download(
                        repo_id=adapter_path,
                        filename="adapter_model.safetensors",
                        token=HF_TOKEN,
                    )
                    adapter_weights = load_file(adapter_weights_path)
                    print("[LoRA] Downloaded safetensors weights from Hub")
                except Exception:
                    adapter_weights_path = hf_hub_download(
                        repo_id=adapter_path,
                        filename="adapter_model.bin",
                        token=HF_TOKEN,
                    )
                    adapter_weights = torch.load(adapter_weights_path, map_location="cpu")
                    print("[LoRA] Downloaded bin weights from Hub")
    except Exception as e:
        raise FileNotFoundError(f"Could not load adapter weights: {e}")

    # Convert weights for MX compatibility
    print("[LoRA] Converting fp32 weights for MX format compatibility...")
    adapter_weights = convert_fp32_lora_to_mx_compatible(adapter_weights)

    # Create the PEFT model with special handling for MX.
    # The base model uses MX format internally, but the interface should be fp32.
    print("[LoRA] Attaching LoRA to base model...")
    model = PeftModel.from_pretrained(
        model,
        adapter_path,
        is_trainable=False,
        **peft_kwargs,  # includes token and subfolder
    )

    # Manually update the adapter weights with the converted versions
    model.load_state_dict(adapter_weights, strict=False)
    print("[LoRA] Successfully attached LoRA adapter with MX compatibility")
    return model
# -----------------------
# Model loading with MX support
# -----------------------
def _build_model_kwargs(device_map: Optional[str]) -> Dict[str, Any]:
    """Build kwargs for model loading with MX format support."""
    kw: Dict[str, Any] = dict(
        device_map=device_map,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        token=HF_TOKEN,
    )
    if IS_GPT_OSS and USE_MX_FORMAT:
        # GPT-OSS models use MX format.
        # Don't specify torch_dtype - let the model use its native MX format.
        print("[Model] Using MX format for GPT-OSS model")
        kw.update({
            "attn_implementation": ATTN_IMPL if device_map != "cpu" else "eager",
            # MX models handle their own dtype internally - don't force one here
        })
    else:
        # Non-MX models
        kw.update({
            "torch_dtype": torch.float16,  # Use fp16 for non-MX models
            "attn_implementation": ATTN_IMPL if device_map != "cpu" else "eager",
        })
    return kw
def _load_model_on(device_map: Optional[str]) -> AutoModelForCausalLM:
    """Load the model with proper MX format handling."""
    print(f"[Model] Loading base model from {MODEL_ID}...")

    # Load config first to check for MX format
    config = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True, token=HF_TOKEN)

    # Check if this is an MX model
    is_mx_model = (
        IS_GPT_OSS
        or (hasattr(config, 'quantization_config') and 'mx' in str(config.quantization_config).lower())
        or (hasattr(config, 'torch_dtype') and 'mx' in str(config.torch_dtype).lower())
    )

    if is_mx_model:
        print("[Model] Detected MX format model - using special loading")
        # MX models quantize internally; let the model handle its own dtype
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            config=config,
            trust_remote_code=True,
            device_map=device_map,
            low_cpu_mem_usage=True,
            token=HF_TOKEN,
            attn_implementation=ATTN_IMPL if device_map != "cpu" else "eager",
        )
        # Verify the model loaded correctly
        print(f"[Model] Model dtype: {next(model.parameters()).dtype}")
        print(f"[Model] Model device: {next(model.parameters()).device}")
    else:
        # Standard model loading
        model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **_build_model_kwargs(device_map))

    # Load and attach the LoRA adapter if specified
    if ADAPTER_ID:
        try:
            if is_mx_model:
                # Use MX-compatible LoRA loading with subfolder support
                model = prepare_model_for_mx_lora(model, ADAPTER_ID, ADAPTER_SUBFOLDER)
            else:
                # Standard PEFT loading for non-MX models
                if not _HAS_PEFT:
                    raise RuntimeError("PEFT is required when ADAPTER_ID is set.")
                print(f"[Model] Loading adapter from {ADAPTER_ID} (standard mode)...")
                peft_kwargs = {"token": HF_TOKEN, "is_trainable": False}
                if ADAPTER_SUBFOLDER:
                    peft_kwargs["subfolder"] = ADAPTER_SUBFOLDER
                    print(f"[Model] Using subfolder: {ADAPTER_SUBFOLDER}")
                model = PeftModel.from_pretrained(
                    model,
                    ADAPTER_ID,
                    **peft_kwargs,
                )
            print("[Model] Successfully loaded with LoRA adapter")

            # Optionally merge the adapter for better inference performance
            merge_adapter = os.getenv("MERGE_ADAPTER", "0") == "1"
            if merge_adapter and hasattr(model, 'merge_and_unload'):
                print("[Model] Merging adapter into base model...")
                model = model.merge_and_unload()
                print("[Model] Adapter merged successfully")
        except Exception as e:
            print(f"[Error] Failed to load adapter: {e}")
            print("[Warning] Continuing with base model only")

    model.eval()

    # Ensure proper config
    if getattr(model.config, "pad_token_id", None) is None:
        model.config.pad_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id
    model.config.use_cache = True

    print(f"[Model] Model loaded successfully - Type: {'MX Format' if is_mx_model else 'Standard'}")
    return model
# -----------------------
# Harmony formatting
# -----------------------
def create_harmony_prompt(messages: List[Dict[str, str]], reasoning_effort: str = "high") -> Any:
    """Build a Harmony-formatted prompt."""
    if HARMONY_AVAILABLE and harmony_encoding is not None:
        effort_map = {"low": ReasoningEffort.LOW, "medium": ReasoningEffort.MEDIUM, "high": ReasoningEffort.HIGH}
        effort = effort_map.get(str(reasoning_effort).lower(), ReasoningEffort.HIGH)
        system_content = (
            SystemContent.new()
            .with_model_identity("You are ChatGPT, a large language model trained by OpenAI.")
            .with_reasoning_effort(effort)
            .with_conversation_start_date(datetime.now().strftime("%Y-%m-%d"))
            .with_knowledge_cutoff("2024-06")
            .with_required_channels(REQUIRED_CHANNELS)
        )
        sys_text = SYSTEM_DEF
        rest: List[Dict[str, str]] = messages or []
        if rest and rest[0].get("role") == "system":
            sys_text = rest[0].get("content") or SYSTEM_DEF
            rest = rest[1:]
        harmony_messages = [Message.from_role_and_content(Role.SYSTEM, system_content)]
        dev = DeveloperContent.new().with_instructions(sys_text)
        harmony_messages.append(Message.from_role_and_content(Role.DEVELOPER, dev))
        for m in rest:
            role = m.get("role")
            content = m.get("content", "")
            if role == "user":
                harmony_messages.append(Message.from_role_and_content(Role.USER, content))
            elif role == "assistant":
                harmony_messages.append(
                    Message.from_role_and_content(Role.ASSISTANT, content).with_channel("final")
                )
        convo = Conversation.from_messages(harmony_messages)
        return harmony_encoding.render_conversation_for_completion(convo, Role.ASSISTANT)

    # Fallback: tokenizer chat template
    if not messages or messages[0].get("role") != "system":
        messages = [{"role": "system", "content": SYSTEM_DEF}] + (messages or [])
    return tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
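
# Note: with openai_harmony installed the function above returns Harmony token ids, which
# zerogpu_generate() feeds straight to model.generate(); the fallback path returns a plain
# chat-template string that gets tokenized instead. Illustrative call (hypothetical message):
#   prompt = create_harmony_prompt([{"role": "user", "content": "Hello"}], reasoning_effort="low")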
def parse_harmony_response(tokens: List[int]) -> Dict[str, str]:
    """Parse response tokens using the Harmony format to extract channels."""
    if not HARMONY_AVAILABLE:
        text = tokenizer.decode(tokens, skip_special_tokens=False)
        return {"final": extract_final_channel_fallback(text), "raw": text}

    parsed_messages = harmony_encoding.parse_messages_from_completion_tokens(tokens, Role.ASSISTANT)
    channels = {}
    for msg in parsed_messages:
        channel = msg.channel if hasattr(msg, 'channel') else "final"
        if channel not in channels:
            channels[channel] = ""
        parts = msg.content if isinstance(msg.content, list) else [msg.content]
        channels[channel] += "".join(getattr(part, "text", str(part)) for part in parts)
    if "final" not in channels:
        channels["final"] = " ".join(channels.values())
    return channels
def extract_final_channel_fallback(text: str) -> str:
    """Extract the final channel from decoded Harmony text."""
    try:
        chunks: Dict[str, str] = {}
        pieces = text.split("<|channel|>")
        for seg in pieces[1:]:
            name_end = seg.find("<|message|>")
            if name_end <= 0:
                continue
            ch = seg[:name_end].strip()
            body_start = name_end + len("<|message|>")
            next_pos = len(seg)
            for delim in ("<|channel|>", "<|end|>", "<|return|>"):
                p = seg.find(delim, body_start)
                if p != -1:
                    next_pos = min(next_pos, p)
            body = seg[body_start:next_pos]
            chunks[ch] = chunks.get(ch, "") + body
        final_txt = chunks.get("final", "").strip()
        if final_txt:
            return final_txt
        if "<|channel|>final<|message|>" in text:
            tail = text.split("<|channel|>final<|message|>")[-1]
            for delim in ("<|return|>", "<|end|>", "<|channel|>"):
                idx = tail.find(delim)
                if idx != -1:
                    tail = tail[:idx]
                    break
            return tail.strip()
    except Exception:
        pass
    return text.strip()
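
# For reference, a decoded Harmony completion roughly follows this layout (illustrative):
#   <|channel|>analysis<|message|>...reasoning...<|end|><|channel|>final<|message|>...answer...<|return|>
# The fallback above slices on <|channel|>, <|message|>, <|end|> and <|return|> to recover "final".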
# -----------------------
# Rose guidance
# -----------------------
def build_bias_from_tokens(tokenizer, mapping: Dict[str, float]) -> torch.Tensor:
    """Create a vocab bias vector from {token: weight}."""
    vocab_size = len(tokenizer)
    bias = torch.zeros(vocab_size, dtype=torch.float32)
    for tok, w in mapping.items():
        if tok is None:
            continue
        tid = tokenizer.convert_tokens_to_ids(tok)
        if isinstance(tid, list):
            for t in tid:
                if isinstance(t, int) and t >= 0:
                    bias[t] += float(w) / max(1, len(tid))
        elif isinstance(tid, int) and tid >= 0:
            bias[tid] += float(w)
    return bias
class RoseGuidedLogits(torch.nn.Module):
    """Logits processor that adds a scaled vocabulary bias at every decoding step."""

    def __init__(self, bias_vec: torch.Tensor, alpha: float = 1.0):
        super().__init__()
        self.bias_vec = bias_vec
        self.alpha = float(alpha)

    def forward(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        return scores + self.alpha * self.bias_vec.to(scores.device)
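
# Illustrative usage (assumes the listed strings exist as tokens in the tokenizer vocabulary):
#   bias = build_bias_from_tokens(tokenizer, {"rose": 1.5, "thorn": -0.5})
#   processors = [RoseGuidedLogits(bias, alpha=1.0)]  # passed to model.generate(logits_processor=...)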
# -----------------------
# Generation
# -----------------------
@spaces.GPU  # ZeroGPU: request a GPU for the duration of this call
def zerogpu_generate(full_prompt,
                     gen_kwargs: Dict[str, Any],
                     rose_map: Optional[Dict[str, float]],
                     rose_alpha: float,
                     rose_score: Optional[float],
                     seed: Optional[int]) -> Dict[str, str]:
    """Run inference on GPU with MX format support."""
    model = None
    try:
        if seed is not None:
            torch.manual_seed(int(seed))

        # Load model with MX support
        model = _load_model_on("auto")

        # Setup logits processor for Rose guidance
        logits_processor = None
        if rose_map:
            bias = build_bias_from_tokens(tokenizer, rose_map).to(next(model.parameters()).device)
            eff_alpha = float(rose_alpha) * (float(rose_score) if rose_score is not None else 1.0)
            logits_processor = [RoseGuidedLogits(bias, eff_alpha)]

        # Prepare inputs: Harmony prompts arrive as token ids, the fallback as a string
        device = next(model.parameters()).device
        if HARMONY_AVAILABLE and isinstance(full_prompt, list):
            input_ids = torch.tensor([full_prompt], dtype=torch.long, device=device)
            attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=device)
            inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
            prompt_len = input_ids.shape[1]
        else:
            enc = tokenizer(full_prompt, return_tensors="pt")
            inputs = enc.to(device)
            prompt_len = int(inputs["input_ids"].shape[1])
            if "attention_mask" not in inputs:
                inputs["attention_mask"] = torch.ones_like(inputs["input_ids"], dtype=torch.long, device=device)

        # Generate
        eos_ids = HARMONY_STOP_IDS if HARMONY_AVAILABLE else tokenizer.eos_token_id
        out_ids = model.generate(
            **inputs,
            do_sample=bool(gen_kwargs.get("do_sample", True)),
            temperature=float(gen_kwargs.get("temperature", 0.7)),
            top_p=float(gen_kwargs.get("top_p", 0.9)),
            top_k=(int(gen_kwargs.get("top_k")) if gen_kwargs.get("top_k") and int(gen_kwargs.get("top_k")) > 0 else None),
            max_new_tokens=int(gen_kwargs.get("max_new_tokens", MAX_DEF)),
            pad_token_id=model.config.pad_token_id,
            eos_token_id=eos_ids,
            logits_processor=logits_processor,
            repetition_penalty=float(gen_kwargs.get("repetition_penalty", 1.1)),
            no_repeat_ngram_size=int(gen_kwargs.get("no_repeat_ngram_size", 6)),
            min_new_tokens=1,
        )

        # Extract generated tokens
        out_list = out_ids[0].tolist()
        gen_ids = out_list[prompt_len:]

        # Truncate at the first stop token
        if HARMONY_AVAILABLE:
            for sid in HARMONY_STOP_IDS:
                if sid in gen_ids:
                    gen_ids = gen_ids[:gen_ids.index(sid)]
                    break

        # Parse the response into channels
        if HARMONY_AVAILABLE:
            try:
                channels = parse_harmony_response(gen_ids)
            except Exception:
                decoded = tokenizer.decode(gen_ids, skip_special_tokens=False)
                channels = {"final": extract_final_channel_fallback(decoded), "raw": decoded}
        else:
            decoded = tokenizer.decode(gen_ids, skip_special_tokens=False)
            channels = {"final": extract_final_channel_fallback(decoded), "raw": decoded}
        return channels

    except Exception as e:
        import traceback
        error_trace = traceback.format_exc()
        print(f"[Error] Generation failed:\n{error_trace}")
        return {"final": f"[Error] {type(e).__name__}: {str(e)}", "raw": error_trace}
    finally:
        # Cleanup: release the model and GPU memory between requests
        if model is not None:
            del model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# -----------------------
# Gradio handlers
# -----------------------
def generate_response(message: str, history: List[Any], system_prompt: str,
                      temperature: float, top_p: float, top_k: int, max_new_tokens: int,
                      do_sample: bool, seed: Optional[int],
                      rose_enable: bool, rose_alpha: float, rose_score: Optional[float],
                      rose_tokens: str, rose_json: str,
                      show_thinking: bool = False,
                      reasoning_effort: str = "high") -> str:
    """Generate a response with CoT handling."""
    try:
        # Build messages; history arrives as openai-style dicts (ChatInterface type="messages")
        # or as [user, assistant] pairs, so handle both.
        messages = [{"role": "system", "content": system_prompt or SYSTEM_DEF}]
        if history:
            for turn in history:
                if isinstance(turn, dict) and turn.get("role") in ("user", "assistant"):
                    messages.append({"role": turn["role"], "content": str(turn.get("content", ""))})
                elif isinstance(turn, (list, tuple)) and len(turn) >= 2:
                    user_msg, assistant_msg = turn[0], turn[1]
                    if user_msg:
                        messages.append({"role": "user", "content": str(user_msg)})
                    if assistant_msg:
                        messages.append({"role": "assistant", "content": str(assistant_msg)})
        messages.append({"role": "user", "content": str(message)})

        # Create prompt
        if HARMONY_AVAILABLE:
            prompt = create_harmony_prompt(messages, reasoning_effort)
        else:
            prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)

        # Build the Rose map from comma-separated token:weight pairs and/or JSON
        rose_map: Optional[Dict[str, float]] = None
        if rose_enable:
            rose_map = {}
            tok_str = (rose_tokens or "").strip()
            if tok_str:
                for p in [p.strip() for p in tok_str.split(",") if p.strip()]:
                    if ":" in p:
                        k, v = p.split(":", 1)
                        try:
                            rose_map[k.strip()] = float(v)
                        except ValueError:
                            pass
            if rose_json:
                try:
                    j = json.loads(rose_json)
                    if isinstance(j, dict):
                        for k, v in j.items():
                            try:
                                rose_map[str(k)] = float(v)
                            except (TypeError, ValueError):
                                pass
                except Exception:
                    pass
            if not rose_map:
                rose_map = None

        # Generate
        channels = zerogpu_generate(
            prompt,
            {
                "do_sample": bool(do_sample),
                "temperature": float(temperature),
                "top_p": float(top_p),
                "top_k": int(top_k) if top_k > 0 else None,
                "max_new_tokens": int(max_new_tokens),
                "repetition_penalty": 1.1,
                "no_repeat_ngram_size": 6,
            },
            rose_map,
            float(rose_alpha),
            float(rose_score) if rose_score is not None else None,
            int(seed) if seed is not None else None,
        )

        # Format response
        if show_thinking:
            response = "## Chain of Thought:\n\n"
            for channel, content in channels.items():
                if channel != "final" and content:
                    response += f"### {channel.capitalize()} Channel:\n{content}\n\n"
            response += f"### Final Response:\n{channels.get('final', 'No final response generated')}"
            return response
        else:
            return channels.get("final", "No final response generated")
    except Exception as e:
        import traceback
        return f"[Error] {type(e).__name__}: {str(e)}\n{traceback.format_exc()}"
# -----------------------
# UI
# -----------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        f"""
        # Mirel – Harmony Chain-of-Thought Inference
        **Model**: {MODEL_ID} {'(MX Format)' if USE_MX_FORMAT else ''}
        **Adapter**: {ADAPTER_ID or 'None'}
        **Status**: {'✅ Harmony Available' if HARMONY_AVAILABLE else '⚠️ Harmony Not Installed'}

        The model uses internal thinking channels before providing final responses.
        """
    )

    with gr.Row():
        system_prompt = gr.Textbox(
            label="System Prompt",
            value=SYSTEM_DEF,
            lines=2
        )

    with gr.Accordion("Generation Settings", open=False):
        with gr.Row():
            temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.01, label="Top-p")
            top_k = gr.Slider(0, 200, value=0, step=1, label="Top-k (0=disabled)")
        with gr.Row():
            max_new = gr.Slider(16, 4096, value=MAX_DEF, step=16, label="Max new tokens")
            do_sample = gr.Checkbox(value=True, label="Do sample")
            seed = gr.Number(value=None, label="Seed (optional)", precision=0)
        with gr.Row():
            reasoning_effort = gr.Radio(
                choices=["low", "medium", "high"],
                value="high",
                label="Reasoning Effort",
                info="How much thinking the model should do"
            )
            show_thinking = gr.Checkbox(
                value=False,
                label="Show thinking channels",
                info="Display all internal reasoning channels"
            )

    with gr.Accordion("Rose Guidance (Optional)", open=False):
        gr.Markdown("Fine-tune generation with token biases")
        with gr.Row():
            rose_enable = gr.Checkbox(value=False, label="Enable Rose bias")
            rose_alpha = gr.Slider(0.0, 5.0, value=1.0, step=0.05, label="Alpha (strength)")
            rose_score = gr.Slider(0.0, 1.0, value=1.0, step=0.01, label="Score multiplier")
        rose_tokens = gr.Textbox(
            label="Token:weight pairs",
            placeholder="example:1.5, test:-0.5",
            value=""
        )
        rose_json = gr.Textbox(
            label="JSON weights",
            placeholder='{"token": 1.0, "another": -0.5}',
            value=""
        )

    # Chat interface
    chat = gr.ChatInterface(
        fn=generate_response,
        type="messages",
        additional_inputs=[
            system_prompt, temperature, top_p, top_k, max_new,
            do_sample, seed, rose_enable, rose_alpha, rose_score,
            rose_tokens, rose_json, show_thinking, reasoning_effort
        ],
        title="Chat with Mirel",
        description="Chain-of-thought model with MX format support",
        examples=[
            ["Hello! Can you introduce yourself?"],
            ["What is the capital of France?"],
            ["Explain quantum computing in simple terms"],
            ["Solve: If a train travels 120 miles in 2 hours, what is its average speed?"],
        ],
        cache_examples=False,
    )

    gr.Markdown(
        """
        ---
        ### Configuration:
        - **MX Format**: Automatically detected for GPT-OSS models
        - **LoRA Support**: fp32 LoRA adapters are converted for MX compatibility
        - **Merge Adapter**: Set `MERGE_ADAPTER=1` to merge the LoRA into the base model
        - **Auth**: Set `HF_TOKEN` in Space secrets for private model access
        """
    )
if __name__ == "__main__":
    demo.queue(max_size=8 if ZEROGPU else 32).launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )
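
# Packages imported by this file, as a rough requirements.txt sketch (versions unpinned,
# msamp optional; treat this list as an assumption, not a tested manifest):
#   gradio, spaces, torch, numpy, transformers, peft, safetensors, huggingface_hub, openai-harmony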