from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
import gradio as gr
import spaces
import re
import logging
from peft import PeftModel

# ----------------------------------------------------------------------
# KaTeX delimiter config for Gradio
# ----------------------------------------------------------------------
LATEX_DELIMS = [
    {"left": "$$", "right": "$$", "display": True},
    {"left": "$", "right": "$", "display": False},
    {"left": "\\[", "right": "\\]", "display": True},
    {"left": "\\(", "right": "\\)", "display": False},
]

# Configure logging
logging.basicConfig(level=logging.INFO)

# Load the base model
try:
    base_model = AutoModelForCausalLM.from_pretrained(
        "openai/gpt-oss-20b",
        torch_dtype="auto",
        device_map="auto",
        attn_implementation="kernels-community/vllm-flash-attn3",
    )
    tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")

    # Load the LoRA adapter
    try:
        model = PeftModel.from_pretrained(base_model, "Tonic/gpt-oss-20b-multilingual-reasoner")
        print("✅ LoRA model loaded successfully!")
    except Exception as lora_error:
        print(f"⚠️ LoRA adapter failed to load: {lora_error}")
        print("🔄 Falling back to base model...")
        model = base_model
except Exception as e:
    print(f"❌ Error loading model: {e}")
    raise


def format_conversation_history(chat_history):
    messages = []
    for item in chat_history:
        role = item["role"]
        content = item["content"]
        if isinstance(content, list):
            content = content[0]["text"] if content and "text" in content[0] else str(content)
        messages.append({"role": role, "content": content})
    return messages


def format_analysis_response(text):
    """Enhanced response formatting with better structure and LaTeX support."""
    # Look for the analysis section followed by the final response
    m = re.search(r"analysis(.*?)assistantfinal", text, re.DOTALL | re.IGNORECASE)
    if m:
        reasoning = m.group(1).strip()
        response = text.split("assistantfinal", 1)[-1].strip()
        # Clean up the reasoning section
        reasoning = re.sub(r'^analysis\s*', '', reasoning, flags=re.IGNORECASE).strip()
        # Format with improved structure
        formatted = (
            f"**🤔 Analysis & Reasoning:**\n\n"
            f"*{reasoning}*\n\n"
            f"---\n\n"
            f"**💬 Final Response:**\n\n{response}"
        )
        # Ensure LaTeX delimiters are balanced
        if formatted.count("$") % 2:
            formatted += "$"
        return formatted
    # Fallback: clean up the text and return it as-is
    cleaned = re.sub(r'^analysis\s*', '', text, flags=re.IGNORECASE).strip()
    if cleaned.count("$") % 2:
        cleaned += "$"
    return cleaned


@spaces.GPU(duration=60)
def generate_response(input_data, chat_history, max_new_tokens, system_prompt,
                      temperature, top_p, top_k, repetition_penalty):
    if not input_data.strip():
        yield "Please enter a prompt."
        return

    # Log the request
    logging.info(f"[User] {input_data}")
    logging.info(f"[System] {system_prompt} | Temp={temperature} | Max tokens={max_new_tokens}")

    new_message = {"role": "user", "content": input_data}
    system_message = [{"role": "system", "content": system_prompt}] if system_prompt else []
    processed_history = format_conversation_history(chat_history)
    messages = system_message + processed_history + [new_message]
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    # Create streamer for proper streaming
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Prepare generation kwargs
    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
        "pad_token_id": tokenizer.eos_token_id,
        "streamer": streamer,
        "use_cache": True,
    }

    # Tokenize input using the chat template
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Start generation in a separate thread
    thread = Thread(target=model.generate, kwargs={**inputs, **generation_kwargs})
    thread.start()

    # Stream the response with enhanced formatting
    collected_text = ""
    buffer = ""
    yielded_once = False
    try:
        for chunk in streamer:
            if not chunk:
                continue
            collected_text += chunk
            buffer += chunk

            # Initial yield to show an immediate response
            if not yielded_once:
                yield chunk
                buffer = ""
                yielded_once = True
                continue

            # Yield accumulated text periodically for smooth streaming
            if "\n" in buffer or len(buffer) > 150:
                # Use enhanced formatting for partial text
                partial_formatted = format_analysis_response(collected_text)
                yield partial_formatted
                buffer = ""

        # Final formatting with the complete text
        final_formatted = format_analysis_response(collected_text)
        yield final_formatted
    except Exception as e:
        logging.exception("Generation streaming failed")
        yield f"❌ Error during generation: {e}"


demo = gr.ChatInterface(
    fn=generate_response,
    additional_inputs=[
        gr.Slider(label="Max new tokens", minimum=64, maximum=4096, step=1, value=2048),
        gr.Textbox(
            label="System Prompt",
            value="You are a helpful assistant. Reasoning: medium",
            lines=4,
            placeholder="Change system prompt"
        ),
        gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.7),
        gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
        gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=50),
        gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.0)
    ],
    examples=[
        ["Explain Newton's laws clearly and concisely with mathematical formulas"],
        ["Write a Python function to calculate the Fibonacci sequence"],
        ["What are the benefits of open weight AI models? Include analysis."],
        ["Solve this equation: $x^2 + 5x + 6 = 0$"],
    ],
    cache_examples=False,
    type="messages",
    # Wire up the KaTeX delimiters defined above so LaTeX renders in the chat
    chatbot=gr.Chatbot(type="messages", latex_delimiters=LATEX_DELIMS),
    description="""
# 🙋🏻‍♂️ Welcome to 🌟 Tonic's gpt-oss-20b Multilingual Reasoner Demo!

✨ **Enhanced Features:**
- 🧠 **Advanced Reasoning**: Detailed analysis and step-by-step thinking
- 📊 **LaTeX Support**: Mathematical formulas rendered beautifully (use `$` or `$$`)
- 🎯 **Improved Formatting**: Clear separation of reasoning and final responses
- 📝 **Smart Logging**: Better error handling and request tracking

💡 **Usage Tips:**
- Adjust the reasoning level in the system prompt (e.g., "Reasoning: high")
- Use LaTeX for math: `$E = mc^2$` or `$$\\int x^2 dx$$`
- Wait a couple of seconds initially for model loading
""",
    fill_height=True,
    textbox=gr.Textbox(
        label="Query Input",
        placeholder="Type your prompt (supports LaTeX: $x^2 + y^2 = z^2$)"
    ),
    stop_btn="Stop Generation",
    multimodal=False,
    theme=gr.themes.Soft()
)

if __name__ == "__main__":
    demo.launch(share=True)
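
# ----------------------------------------------------------------------
# Optional usage sketch: calling this demo from another script.
# A minimal, hedged example using the `gradio_client` package; the Space
# id, the "/chat" endpoint name, and the argument order (mirroring the
# `additional_inputs` above) are assumptions — check the Space's
# "Use via API" page for the exact signature before relying on it.
# ----------------------------------------------------------------------
# from gradio_client import Client
#
# client = Client("Tonic/gpt-oss-20b-demo")  # hypothetical Space id
# reply = client.predict(
#     "Explain Newton's laws clearly and concisely",     # message
#     2048,                                              # max_new_tokens
#     "You are a helpful assistant. Reasoning: medium",  # system_prompt
#     0.7,                                               # temperature
#     0.9,                                               # top_p
#     50,                                                # top_k
#     1.0,                                               # repetition_penalty
#     api_name="/chat",
# )
# print(reply)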