Spaces: Running on Zero
File size: 7,780 Bytes
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, pipeline
import torch
from threading import Thread
import gradio as gr
import spaces
import re
import logging
from peft import PeftModel
# ----------------------------------------------------------------------
# KaTeX delimiter config for Gradio
# ----------------------------------------------------------------------
LATEX_DELIMS = [
{"left": "$$", "right": "$$", "display": True},
{"left": "$", "right": "$", "display": False},
{"left": "\\[", "right": "\\]", "display": True},
{"left": "\\(", "right": "\\)", "display": False},
]
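# Note (illustrative, not part of the original code): for these delimiters to take effect
# in the UI, they would typically be passed to the chatbot component, e.g.
#   chatbot=gr.Chatbot(latex_delimiters=LATEX_DELIMS)
# inside the gr.ChatInterface(...) call below.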
# Configure logging
logging.basicConfig(level=logging.INFO)
# Load the base model
try:
base_model = AutoModelForCausalLM.from_pretrained(
"openai/gpt-oss-20b",
torch_dtype="auto",
device_map="auto",
attn_implementation="kernels-community/vllm-flash-attn3"
)
tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")
# Load the LoRA adapter
try:
model = PeftModel.from_pretrained(base_model, "Tonic/gpt-oss-20b-multilingual-reasoner")
print("โ
LoRA model loaded successfully!")
except Exception as lora_error:
print(f"โ ๏ธ LoRA adapter failed to load: {lora_error}")
print("๐ Falling back to base model...")
model = base_model
except Exception as e:
print(f"โ Error loading model: {e}")
raise e
def format_conversation_history(chat_history):
    """Normalize Gradio chat history into a flat list of {role, content} message dicts."""
messages = []
for item in chat_history:
role = item["role"]
content = item["content"]
if isinstance(content, list):
content = content[0]["text"] if content and "text" in content[0] else str(content)
messages.append({"role": role, "content": content})
return messages
def format_analysis_response(text):
"""Enhanced response formatting with better structure and LaTeX support."""
# Look for analysis section followed by final response
m = re.search(r"analysis(.*?)assistantfinal", text, re.DOTALL | re.IGNORECASE)
if m:
reasoning = m.group(1).strip()
response = text.split("assistantfinal", 1)[-1].strip()
# Clean up the reasoning section
reasoning = re.sub(r'^analysis\s*', '', reasoning, flags=re.IGNORECASE).strip()
# Format with improved structure
formatted = (
f"**๐ค Analysis & Reasoning:**\n\n"
f"*{reasoning}*\n\n"
f"---\n\n"
f"**๐ฌ Final Response:**\n\n{response}"
)
# Ensure LaTeX delimiters are balanced
if formatted.count("$") % 2:
formatted += "$"
return formatted
# Fallback: clean up the text and return as-is
cleaned = re.sub(r'^analysis\s*', '', text, flags=re.IGNORECASE).strip()
if cleaned.count("$") % 2:
cleaned += "$"
return cleaned
@spaces.GPU(duration=60)
def generate_response(input_data, chat_history, max_new_tokens, system_prompt, temperature, top_p, top_k, repetition_penalty):
    """Stream a formatted model response for the given prompt, chat history, and sampling settings."""
if not input_data.strip():
yield "Please enter a prompt."
return
# Log the request
logging.info(f"[User] {input_data}")
logging.info(f"[System] {system_prompt} | Temp={temperature} | Max tokens={max_new_tokens}")
new_message = {"role": "user", "content": input_data}
system_message = [{"role": "system", "content": system_prompt}] if system_prompt else []
processed_history = format_conversation_history(chat_history)
messages = system_message + processed_history + [new_message]
prompt = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
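    # apply_chat_template renders the messages in the model's chat format; for gpt-oss this
    # includes the reasoning ("analysis") and final-answer channels that
    # format_analysis_response later splits apart.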
# Create streamer for proper streaming
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
# Prepare generation kwargs
generation_kwargs = {
"max_new_tokens": max_new_tokens,
"do_sample": True,
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
"repetition_penalty": repetition_penalty,
"pad_token_id": tokenizer.eos_token_id,
"streamer": streamer,
"use_cache": True
}
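    # Passing the streamer in generation_kwargs makes model.generate push decoded tokens to it
    # as they are produced; the generate call itself runs in the background thread started below.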
# Tokenize input using the chat template
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
# Start generation in a separate thread
thread = Thread(target=model.generate, kwargs={**inputs, **generation_kwargs})
thread.start()
# Stream the response with enhanced formatting
collected_text = ""
buffer = ""
yielded_once = False
try:
for chunk in streamer:
if not chunk:
continue
collected_text += chunk
buffer += chunk
# Initial yield to show immediate response
if not yielded_once:
yield chunk
buffer = ""
yielded_once = True
continue
# Yield accumulated text periodically for smooth streaming
if "\n" in buffer or len(buffer) > 150:
# Use enhanced formatting for partial text
partial_formatted = format_analysis_response(collected_text)
yield partial_formatted
buffer = ""
# Final formatting with complete text
final_formatted = format_analysis_response(collected_text)
yield final_formatted
except Exception as e:
logging.exception("Generation streaming failed")
yield f"โ Error during generation: {e}"
demo = gr.ChatInterface(
fn=generate_response,
additional_inputs=[
gr.Slider(label="Max new tokens", minimum=64, maximum=4096, step=1, value=2048),
gr.Textbox(
label="System Prompt",
value="You are a helpful assistant. Reasoning: medium",
lines=4,
placeholder="Change system prompt"
),
gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.7),
gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=50),
gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.0)
],
examples=[
[{"text": "Explain Newton's laws clearly and concisely with mathematical formulas"}],
[{"text": "Write a Python function to calculate the Fibonacci sequence"}],
[{"text": "What are the benefits of open weight AI models? Include analysis."}],
[{"text": "Solve this equation: $x^2 + 5x + 6 = 0$"}],
],
cache_examples=False,
type="messages",
description="""
# 🙋🏻‍♂️Welcome to 🌟Tonic's gpt-oss-20b Multilingual Reasoner Demo!
✨ **Enhanced Features:**
- 🧠 **Advanced Reasoning**: Detailed analysis and step-by-step thinking
- 📐 **LaTeX Support**: Mathematical formulas rendered beautifully (use `$` or `$$`)
- 🎯 **Improved Formatting**: Clear separation of reasoning and final responses
- 📊 **Smart Logging**: Better error handling and request tracking
💡 **Usage Tips:**
- Adjust reasoning level in system prompt (e.g., "Reasoning: high")
- Use LaTeX for math: `$E = mc^2$` or `$$\\int x^2 dx$$`
- Wait a couple of seconds initially for model loading
""",
fill_height=True,
textbox=gr.Textbox(
label="Query Input",
placeholder="Type your prompt (supports LaTeX: $x^2 + y^2 = z^2$)"
),
stop_btn="Stop Generation",
multimodal=False,
theme=gr.themes.Soft()
)
if __name__ == "__main__":
demo.launch(share=True) |