from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
import gradio as gr
import spaces
import re
import logging
from peft import PeftModel

# ----------------------------------------------------------------------
# KaTeX delimiter config for Gradio
# ----------------------------------------------------------------------

LATEX_DELIMS = [
    {"left": "$$",  "right": "$$",  "display": True},
    {"left": "$",   "right": "$",   "display": False},
    {"left": "\\[", "right": "\\]", "display": True},
    {"left": "\\(", "right": "\\)", "display": False},
]
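# These delimiters are handed to the chat display below (gr.Chatbot latex_delimiters) so that
# inline ($...$, \( ... \)) and block ($$...$$, \[ ... \]) math is rendered via KaTeX.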

# Configure logging
logging.basicConfig(level=logging.INFO)

# Load the base model
try:
    base_model = AutoModelForCausalLM.from_pretrained(
        "openai/gpt-oss-20b",
        torch_dtype="auto",
        device_map="auto",
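        # The attn_implementation string below is assumed to load a prebuilt FlashAttention-3
        # kernel from the Hub (via the `kernels` package); drop it to use the default attention.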
        attn_implementation="kernels-community/vllm-flash-attn3"
    )
    tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")
    
    # Load the LoRA adapter
    try:
        model = PeftModel.from_pretrained(base_model, "Tonic/gpt-oss-20b-multilingual-reasoner")
        print("โœ… LoRA model loaded successfully!")
    except Exception as lora_error:
        print(f"โš ๏ธ LoRA adapter failed to load: {lora_error}")
        print("๐Ÿ”„ Falling back to base model...")
        model = base_model
        
except Exception as e:
    print(f"โŒ Error loading model: {e}")
    raise e

def format_conversation_history(chat_history):
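    """Normalize Gradio 'messages'-style history into plain {"role", "content"} dicts.

    Depending on the input components, Gradio may pass a message's content either as a
    plain string or as a list of parts such as [{"text": ...}]; both are flattened here.
    """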
    messages = []
    for item in chat_history:
        role = item["role"]
        content = item["content"]
        if isinstance(content, list):
            content = content[0]["text"] if content and "text" in content[0] else str(content)
        messages.append({"role": role, "content": content})
    return messages

def format_analysis_response(text):
    """Enhanced response formatting with better structure and LaTeX support."""
    # Look for analysis section followed by final response
    m = re.search(r"analysis(.*?)assistantfinal", text, re.DOTALL | re.IGNORECASE)
    if m:
        reasoning = m.group(1).strip()
        response = text.split("assistantfinal", 1)[-1].strip()
        
        # Clean up the reasoning section
        reasoning = re.sub(r'^analysis\s*', '', reasoning, flags=re.IGNORECASE).strip()
        
        # Format with improved structure
        formatted = (
            f"**🤔 Analysis & Reasoning:**\n\n"
            f"*{reasoning}*\n\n"
            f"---\n\n"
            f"**💬 Final Response:**\n\n{response}"
        )
        
        # Ensure LaTeX delimiters are balanced
        if formatted.count("$") % 2:
            formatted += "$"
            
        return formatted
    
    # Fallback: clean up the text and return as-is
    cleaned = re.sub(r'^analysis\s*', '', text, flags=re.IGNORECASE).strip()
    if cleaned.count("$") % 2:
        cleaned += "$"
    return cleaned

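# On ZeroGPU Spaces, @spaces.GPU requests a GPU slice per call; `duration` caps it (60 s here).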
@spaces.GPU(duration=60)
def generate_response(input_data, chat_history, max_new_tokens, system_prompt, temperature, top_p, top_k, repetition_penalty):
    if not input_data.strip():
        yield "Please enter a prompt."
        return
        
    # Log the request
    logging.info(f"[User] {input_data}")
    logging.info(f"[System] {system_prompt} | Temp={temperature} | Max tokens={max_new_tokens}")
    
    new_message = {"role": "user", "content": input_data}
    system_message = [{"role": "system", "content": system_prompt}] if system_prompt else []
    processed_history = format_conversation_history(chat_history)
    messages = system_message + processed_history + [new_message]
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    
    # Create streamer for proper streaming
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    
    # Prepare generation kwargs
    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
        "pad_token_id": tokenizer.eos_token_id,
        "streamer": streamer,
        "use_cache": True
    }
    
    # Tokenize the templated prompt and move the tensors to the model's device
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    
    # Start generation in a separate thread
    thread = Thread(target=model.generate, kwargs={**inputs, **generation_kwargs})
    thread.start()
    
    # Stream the response with enhanced formatting
    collected_text = ""
    buffer = ""
    yielded_once = False
    
    try:
        for chunk in streamer:
            if not chunk:
                continue
                
            collected_text += chunk
            buffer += chunk
            
            # Initial yield to show immediate response
            if not yielded_once:
                yield chunk
                buffer = ""
                yielded_once = True
                continue
            
            # Yield accumulated text periodically for smooth streaming
            if "\n" in buffer or len(buffer) > 150:
                # Use enhanced formatting for partial text
                partial_formatted = format_analysis_response(collected_text)
                yield partial_formatted
                buffer = ""
        
        # Final formatting with complete text
        final_formatted = format_analysis_response(collected_text)
        yield final_formatted
        
    except Exception as e:
        logging.exception("Generation streaming failed")
        yield f"โŒ Error during generation: {e}"

demo = gr.ChatInterface(
    fn=generate_response,
    additional_inputs=[
        gr.Slider(label="Max new tokens", minimum=64, maximum=4096, step=1, value=2048),
        gr.Textbox(
            label="System Prompt",
            value="You are a helpful assistant. Reasoning: medium",
            lines=4,
            placeholder="Change system prompt"
        ),
        gr.Slider(label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.7),
        gr.Slider(label="Top-p", minimum=0.05, maximum=1.0, step=0.05, value=0.9),
        gr.Slider(label="Top-k", minimum=1, maximum=100, step=1, value=50),
        gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.0)
    ],
    examples=[
        [{"text": "Explain Newton's laws clearly and concisely with mathematical formulas"}],
        [{"text": "Write a Python function to calculate the Fibonacci sequence"}],
        [{"text": "What are the benefits of open weight AI models? Include analysis."}],
        [{"text": "Solve this equation: $x^2 + 5x + 6 = 0$"}],
    ],
    cache_examples=False,
    type="messages",
    description="""

# 🙋🏻‍♂️ Welcome to 🌟Tonic's gpt-oss-20b Multilingual Reasoner Demo!

✨ **Enhanced Features:**
- 🧠 **Advanced Reasoning**: Detailed analysis and step-by-step thinking
- 📊 **LaTeX Support**: Mathematical formulas rendered beautifully (use `$` or `$$`)
- 🎯 **Improved Formatting**: Clear separation of reasoning and final responses
- 📝 **Smart Logging**: Better error handling and request tracking

💡 **Usage Tips:**
- Adjust reasoning level in system prompt (e.g., "Reasoning: high")
- Use LaTeX for math: `$E = mc^2$` or `$$\\int x^2 dx$$`
- Wait a couple of seconds initially for model loading
    """,
    fill_height=True,
    # Pass a preconfigured chatbot so the KaTeX delimiters defined above are actually applied
    chatbot=gr.Chatbot(latex_delimiters=LATEX_DELIMS, type="messages"),
    textbox=gr.Textbox(
        label="Query Input",
        placeholder="Type your prompt (supports LaTeX: $x^2 + y^2 = z^2$)"
    ),
    stop_btn="Stop Generation",
    multimodal=False,
    theme=gr.themes.Soft()
)

if __name__ == "__main__":
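    # share=True only has an effect for local runs; hosted Spaces ignore it.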
    demo.launch(share=True)