from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
import gradio as gr
import spaces
import re
import logging
from peft import PeftModel
# ----------------------------------------------------------------------
# KaTeX delimiter config for Gradio
# ----------------------------------------------------------------------
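# Passed to gr.Chatbot(latex_delimiters=...) below so inline ($...$) and display ($$...$$) math render with KaTeX.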
LATEX_DELIMS = [
    {"left": "$$", "right": "$$", "display": True},
    {"left": "$", "right": "$", "display": False},
    {"left": "\\[", "right": "\\]", "display": True},
    {"left": "\\(", "right": "\\)", "display": False},
]
# Configure logging
logging.basicConfig(level=logging.INFO)
# Load the base model
try:
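    # Note: attn_implementation points at a prebuilt FlashAttention-3 kernel hosted on the
    # Hugging Face Hub (kernels-community/vllm-flash-attn3); loading it requires the
    # `kernels` package and a compatible GPU, otherwise drop the argument.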
    base_model = AutoModelForCausalLM.from_pretrained(
        "openai/gpt-oss-20b",
        torch_dtype="auto",
        device_map="auto",
        attn_implementation="kernels-community/vllm-flash-attn3"
    )
    tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")
    # Load the LoRA adapter
    try:
        model = PeftModel.from_pretrained(base_model, "Tonic/gpt-oss-20b-multilingual-reasoner")
        print("✅ LoRA model loaded successfully!")
    except Exception as lora_error:
        print(f"⚠️ LoRA adapter failed to load: {lora_error}")
        print("🔄 Falling back to base model...")
        model = base_model
except Exception as e:
    print(f"❌ Error loading model: {e}")
    raise
def format_conversation_history(chat_history):
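    """Normalize Gradio 'messages'-style history into plain role/content dicts.

    Gradio may store message content as a list of blocks; only the text of the
    first block is kept here, anything else falls back to its string form.
    """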
    messages = []
    for item in chat_history:
        role = item["role"]
        content = item["content"]
        if isinstance(content, list):
            content = content[0]["text"] if content and "text" in content[0] else str(content)
        messages.append({"role": role, "content": content})
    return messages
def format_analysis_response(text):
    """Enhanced response formatting with better structure and LaTeX support."""
    # Look for the analysis section followed by the final response
    m = re.search(r"analysis(.*?)assistantfinal", text, re.DOTALL | re.IGNORECASE)
    if m:
        reasoning = m.group(1).strip()
        response = text.split("assistantfinal", 1)[-1].strip()
        # Clean up the reasoning section
        reasoning = re.sub(r'^analysis\s*', '', reasoning, flags=re.IGNORECASE).strip()
        # Format with improved structure
        formatted = (
            f"**🤔 Analysis & Reasoning:**\n\n"
            f"*{reasoning}*\n\n"
            f"---\n\n"
            f"**💬 Final Response:**\n\n{response}"
        )
        # Ensure LaTeX delimiters are balanced
        if formatted.count("$") % 2:
            formatted += "$"
        return formatted
    # Fallback: clean up the text and return as-is
    cleaned = re.sub(r'^analysis\s*', '', text, flags=re.IGNORECASE).strip()
    if cleaned.count("$") % 2:
        cleaned += "$"
    return cleaned
@spaces.GPU
def generate_response(input_data, chat_history, max_new_tokens, system_prompt, temperature, top_p, top_k, repetition_penalty):
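    """Stream a formatted response for the given prompt and chat history.

    Runs model.generate in a background thread and yields progressively
    formatted text as chunks arrive from the TextIteratorStreamer.
    """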
    if not input_data.strip():
        yield "Please enter a prompt."
        return
    # Log the request
    logging.info(f"[User] {input_data}")
    logging.info(f"[System] {system_prompt} | Temp={temperature} | Max tokens={max_new_tokens}")
    new_message = {"role": "user", "content": input_data}
    system_message = [{"role": "system", "content": system_prompt}] if system_prompt else []
    processed_history = format_conversation_history(chat_history)
    messages = system_message + processed_history + [new_message]
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    # Create streamer for proper streaming
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # Prepare generation kwargs
    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
"pad_token_id": tokenizer.eos_token_id, | |
"streamer": streamer, | |
"use_cache": True | |
} | |
    # Tokenize the templated prompt
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Start generation in a separate thread
    thread = Thread(target=model.generate, kwargs={**inputs, **generation_kwargs})
    thread.start()
    # Stream the response with enhanced formatting
    collected_text = ""
    buffer = ""
    yielded_once = False
    try:
        for chunk in streamer:
            if not chunk:
                continue
            collected_text += chunk
            buffer += chunk
            # Initial yield to show immediate response
            if not yielded_once:
                yield chunk
                buffer = ""
                yielded_once = True
                continue
            # Yield accumulated text periodically for smooth streaming
            if "\n" in buffer or len(buffer) > 150:
                # Use enhanced formatting for partial text
                partial_formatted = format_analysis_response(collected_text)
                yield partial_formatted
                buffer = ""
        # Wait for the generation thread to finish before the final yield
        thread.join()
        # Final formatting with complete text
        final_formatted = format_analysis_response(collected_text)
        yield final_formatted
    except Exception as e:
        logging.exception("Generation streaming failed")
        yield f"❌ Error during generation: {e}"
demo = gr.ChatInterface(
    fn=generate_response,
    additional_inputs=[
        gr.Slider(label="Max new tokens", minimum=64, maximum=4096, step=1, value=2048),
        gr.Textbox(
            label="System Prompt",
            value="You are a helpful assistant. Reasoning: medium",
            lines=4,
            placeholder="Change system prompt"
        ),
        gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.7),
        gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
        gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=50),
        gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.0)
    ],
    examples=[
        [{"text": "Explain Newton's laws clearly and concisely with mathematical formulas"}],
        [{"text": "Write a Python function to calculate the Fibonacci sequence"}],
        [{"text": "What are the benefits of open weight AI models? Include analysis."}],
        [{"text": "Solve this equation: $x^2 + 5x + 6 = 0$"}],
    ],
    cache_examples=False,
    type="messages",
description=""" | |
# ๐๐ปโโ๏ธWelcome to ๐Tonic's gpt-oss-20b Multilingual Reasoner Demo ! | |
โจ **Enhanced Features:** | |
- ๐ง **Advanced Reasoning**: Detailed analysis and step-by-step thinking | |
- ๐ **LaTeX Support**: Mathematical formulas rendered beautifully (use `$` or `$$`) | |
- ๐ฏ **Improved Formatting**: Clear separation of reasoning and final responses | |
- ๐ **Smart Logging**: Better error handling and request tracking | |
๐ก **Usage Tips:** | |
- Adjust reasoning level in system prompt (e.g., "Reasoning: high") | |
- Use LaTeX for math: `$E = mc^2$` or `$$\\int x^2 dx$$` | |
- Wait a couple of seconds initially for model loading | |
""", | |
    fill_height=True,
    textbox=gr.Textbox(
        label="Query Input",
        placeholder="Type your prompt (supports LaTeX: $x^2 + y^2 = z^2$)"
    ),
    chatbot=gr.Chatbot(type="messages", latex_delimiters=LATEX_DELIMS),
    stop_btn="Stop Generation",
    multimodal=False,
    theme=gr.themes.Soft()
)
if __name__ == "__main__":
    demo.launch(share=True)