import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import time
import spaces
import re


MODELS = {
    "20B": "openai/gpt-oss-20b",
    "120B": "openai/gpt-oss-120b",
}
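

# `@spaces.GPU` runs the decorated function on an on-demand ZeroGPU worker, so
# all CUDA work (including the model load) must happen inside it rather than
# at import time.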
@spaces.GPU
def generate_response(model_id, conversation, user_message, max_length=512, temperature=0.7):
    """Generate a response using ZeroGPU - all CUDA operations happen here."""
    print(f"🚀 Loading {model_id}...")
    start_time = time.time()
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    if tokenizer.pad_token is None:
        # Many causal LMs ship without a pad token; fall back to EOS so padding works.
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True
    )
    load_time = time.time() - start_time
    print(f"✅ Model loaded in {load_time:.2f}s")

    messages = []
    system_prompt = (
        "You are GPT, a helpful, harmless, and honest AI assistant. "
        "You provide clear, accurate, and concise responses to user questions. "
        "You are knowledgeable across many domains and always aim to be respectful and helpful."
    )
    messages.append({"role": "system", "content": system_prompt})

    # Replay the prior turns, then append the new user message.
    for msg in conversation:
        messages.append(msg)
    messages.append({"role": "user", "content": user_message})

    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
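
    # The rendered prompt (for gpt-oss, the tokenizer's chat template applies
    # the "harmony" format) is tokenized and moved to whichever device the
    # model was placed on by device_map="auto".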
    inputs = tokenizer(prompt, return_tensors="pt")
    device = next(model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}

    generation_start = time.time()
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_length,
            temperature=temperature,
            do_sample=True,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
    generation_time = time.time() - generation_start
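    # `generate` returns prompt + completion; slice off the prompt tokens so
    # only the newly generated text is decoded.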
    response = tokenizer.decode(
        outputs[0][inputs['input_ids'].shape[-1]:],
        skip_special_tokens=True
    ).strip()
    print(f"Generation time: {generation_time:.2f}s")
    return response, load_time, generation_time
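

# Example of a direct call (hypothetical; in the app this is invoked from
# chat_submit, and @spaces.GPU attaches a GPU for the duration of the call):
#   text, t_load, t_gen = generate_response(MODELS["20B"], [], "Hello!", 128, 0.7)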


def format_response_with_thinking(response):
    """Render a <think>...</think> span in the response as a collapsible block."""
    if '<think>' in response and '</think>' in response:
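        # Split the text into: everything before <think>, the reasoning inside
        # it, and the visible answer after it (re.DOTALL lets parts span newlines).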
        pattern = r'(.*?)(<think>(.*?)</think>)(.*)'
        match = re.search(pattern, response, re.DOTALL)
        if match:
            before_thinking = match.group(1).strip()
            thinking_content = match.group(3).strip()
            after_thinking = match.group(4).strip()

            html = f"{before_thinking}\n"
            html += '<div class="thinking-container">'
            html += ('<button class="thinking-toggle" '
                     'onclick="this.nextElementSibling.classList.toggle(\'hidden\'); '
                     'this.textContent = this.textContent === \'Show reasoning\' '
                     '? \'Hide reasoning\' : \'Show reasoning\'">Show reasoning</button>')
            html += f'<div class="thinking-content hidden">{thinking_content}</div>'
            html += '</div>\n'
            html += after_thinking
            return html

    # No parseable think block: return the response unchanged.
    return response
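

# Example: format_response_with_thinking("<think>Recall the capital.</think>Paris.")
# returns "Paris." preceded by a "Show reasoning" toggle whose hidden panel
# contains "Recall the capital.".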


def chat_submit(message, history, conversation_state, model_name, max_length, temperature):
    """Process a new message and update the chat history."""
    if not message.strip():
        return "", history, conversation_state

    # Fall back to the 20B model if the dropdown value is unrecognized.
    model_id = MODELS.get(model_name, MODELS["20B"])
    try:
        print(f"Processing message: {message}")
        print(f"Selected model: {model_name} ({model_id})")

        response, load_time, generation_time = generate_response(
            model_id, conversation_state, message, max_length, temperature
        )

        conversation_state.append({"role": "user", "content": message})
        conversation_state.append({"role": "assistant", "content": response})
        formatted_response = format_response_with_thinking(response)
        history.append((message, formatted_response))
        print(f"Response added to history. Current length: {len(history)}")

        return "", history, conversation_state
    except Exception as e:
        import traceback
        print(f"Error in chat_submit: {str(e)}")
        print(traceback.format_exc())
        error_message = f"Error: {str(e)}"
        history.append((message, error_message))
        return "", history, conversation_state


css = """
.message {
    padding: 10px;
    margin: 5px;
    border-radius: 10px;
}
.thinking-container {
    margin: 10px 0;
}
.thinking-toggle {
    background-color: #f1f1f1;
    border: 1px solid #ddd;
    border-radius: 4px;
    padding: 5px 10px;
    cursor: pointer;
    font-size: 0.9em;
    margin-bottom: 5px;
    color: #555;
}
.thinking-content {
    background-color: #f9f9f9;
    border-left: 3px solid #ccc;
    padding: 10px;
    margin-top: 5px;
    font-size: 0.95em;
    color: #555;
    font-family: monospace;
    white-space: pre-wrap;
    overflow-x: auto;
}
.hidden {
    display: none;
}
"""


with gr.Blocks(title="GPT-OSS Playground Chat", css=css) as demo:
    gr.Markdown("# 🚀 GPT-OSS Playground Chat")
    gr.Markdown("*Powered by Hugging Face ZeroGPU*")

    # Raw message dicts for the model, kept separate from the chatbot's
    # formatted display history.
    conversation_state = gr.State([])

    chatbot = gr.Chatbot(height=500, label="Athena", render_markdown=True)

    with gr.Row():
        user_input = gr.Textbox(label="Your message", scale=8, autofocus=True, placeholder="Type your message here...")
        send_btn = gr.Button(value="Send", scale=1, variant="primary")

    clear_btn = gr.Button("Clear Conversation")

    gr.Markdown("### ⚙️ Model & Generation Settings")
    with gr.Row():
        model_choice = gr.Dropdown(
            label="📱 Model",
            choices=list(MODELS.keys()),
            value="20B",
            info="Select which Athena model to use"
        )
        max_length = gr.Slider(
            32, 8192, value=512,
            label="📏 Max Tokens",
            info="Maximum number of tokens to generate"
        )
        temperature = gr.Slider(
            0.1, 2.0, value=0.7,
            label="🎨 Creativity",
            info="Higher values = more creative responses"
        )

    def clear_conversation():
        """Clear both the visible chat and the raw conversation state."""
        return [], []

    # Wire the same handler to Enter in the textbox and to the Send button.
    user_input.submit(
        chat_submit,
        inputs=[user_input, chatbot, conversation_state, model_choice, max_length, temperature],
        outputs=[user_input, chatbot, conversation_state]
    )
    send_btn.click(
        chat_submit,
        inputs=[user_input, chatbot, conversation_state, model_choice, max_length, temperature],
        outputs=[user_input, chatbot, conversation_state]
    )
    clear_btn.click(clear_conversation, outputs=[chatbot, conversation_state])

    gr.Examples(
        examples=[
            "What is artificial intelligence?",
            "Can you explain quantum computing?",
            "Write a short poem about technology",
            "What are some ethical concerns about AI?"
        ],
        inputs=[user_input]
    )


if __name__ == "__main__":
    demo.launch(debug=True)