# Spaces: Running on Zero
import os

import torch
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer

# -------------------------------------------------
# Model setup (loaded once at startup)
# -------------------------------------------------
model_name = "gr0010/CustomThinker-0-8B"

# Load the tokenizer and model once at import time; every request reuses them.
print("Loading model and tokenizer...")
# NOTE: trust_remote_code=True executes Python code shipped in the model repo;
# only enable it for repositories you trust.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# Place the bf16 weights on CUDA directly at startup (there is no CPU staging
# step). On ZeroGPU this is the documented pattern: a physical GPU is only
# attached while a @spaces.GPU-decorated function is running.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="cuda",
    trust_remote_code=True,
)
print("Model loaded successfully!")
# -------------------------------------------------
# Core generation and parsing logic with Zero GPU
# -------------------------------------------------
# Id of the "</think>" token in the Qwen3 vocabulary. Used only as a fallback
# when the loaded tokenizer cannot resolve the token string itself.
_FALLBACK_END_THINK_TOKEN_ID = 151668


def _end_think_token_id() -> int:
    """Return the id of the ``</think>`` token for the loaded tokenizer.

    Falls back to the Qwen3 id when the lookup does not produce a usable id
    (e.g. it yields ``None`` or the unknown-token id).
    """
    token_id = tokenizer.convert_tokens_to_ids("</think>")
    if isinstance(token_id, int) and token_id != tokenizer.unk_token_id:
        return token_id
    return _FALLBACK_END_THINK_TOKEN_ID


def _strip_think_tags(text: str) -> str:
    """Remove literal ``<think>``/``</think>`` markers and trim whitespace."""
    return text.replace("<think>", "").replace("</think>", "").strip()


def _split_thinking(output_token_ids: list) -> tuple:
    """Split generated token ids into ``(thinking, answer)`` strings.

    The split point is just past the LAST ``</think>`` token. If no
    ``</think>`` token is present (for example when generation stopped
    mid-thought), the whole output is returned as the answer and thinking
    is ``""``.
    """
    end_think_id = _end_think_token_id()
    if end_think_id not in output_token_ids:
        answer = tokenizer.decode(output_token_ids, skip_special_tokens=True)
        return "", _strip_think_tags(answer)

    # Index just past the last "</think>", as in the Qwen3 reference parser.
    split_at = len(output_token_ids) - output_token_ids[::-1].index(end_think_id)
    thinking = tokenizer.decode(output_token_ids[:split_at], skip_special_tokens=True)
    answer = tokenizer.decode(output_token_ids[split_at:], skip_special_tokens=True)
    return _strip_think_tags(thinking), answer.strip()


@spaces.GPU(duration=120)  # Request a ZeroGPU device for up to 120 seconds per call
def generate_and_parse(messages: list, temperature: float = 0.6,
                       top_p: float = 0.95, top_k: int = 20,
                       min_p: float = 0.0, max_new_tokens: int = 32768):
    """Generate a reply for ``messages`` and split it into thinking and answer.

    Runs on a ZeroGPU device allocated by the ``@spaces.GPU`` decorator.

    Args:
        messages: Chat history as a list of ``{"role": ..., "content": ...}``
            dicts, optionally starting with a system message.
        temperature, top_p, top_k, min_p: Sampling parameters forwarded to
            ``model.generate``; sampling is always enabled.
        max_new_tokens: Upper bound on generated tokens. NOTE(review): very
            large values may not finish within the 120 s GPU budget.

    Returns:
        A ``(thinking, answer)`` tuple of strings. ``thinking`` is ``""`` when
        the output contains no ``</think>`` token.
    """
    # enable_thinking=True is passed to the chat template (Qwen3 thinking mode).
    prompt_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=True,
    )

    # --- CONSOLE DEBUG OUTPUT ---
    print("\n" + "=" * 50)
    print("--- RAW PROMPT SENT TO MODEL ---")
    print(prompt_text[:500] + "..." if len(prompt_text) > 500 else prompt_text)
    print("=" * 50 + "\n")

    model_inputs = tokenizer([prompt_text], return_tensors="pt").to("cuda")
    with torch.no_grad():
        generated_ids = model.generate(
            **model_inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Keep only the newly generated tokens (drop the echoed prompt).
    prompt_length = model_inputs.input_ids.shape[1]
    output_token_ids = generated_ids[0][prompt_length:].tolist()
    return _split_thinking(output_token_ids)
# -------------------------------------------------
# Gradio UI Logic
# -------------------------------------------------
# Custom CSS for better styling.
# Styles the `.model-info` header banner rendered by the gr.HTML component
# (gradient background, white text, bold underlined links); the string is
# passed to gr.Blocks(css=...).
custom_css = """
.model-info {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
padding: 1rem;
border-radius: 10px;
margin-bottom: 1rem;
color: white;
}
.model-info a {
color: #fff;
text-decoration: underline;
font-weight: bold;
}
"""
with gr.Blocks(theme=gr.themes.Soft(), fill_height=True, css=custom_css) as demo:
    # Separate states for display and model context.
    display_history_state = gr.State([])  # Chatbot display history (may contain HTML)
    model_history_state = gr.State([])  # Clean plain-text history sent to the model
    is_generating_state = gr.State(False)  # True while a reply is being generated

    # Model info and CTA section
    gr.HTML("""
    <div class="model-info">
        <h1 style="margin: 0; font-size: 2em;">CustomThinker - Prompt How It Thinks</h1>
        <p style="margin: 0.5rem 0;">
            Powered by <a href="https://huggingface.co/gr0010/CustomThinker-0-8B" target="_blank">CustomThinker-0-8B</a>
        </p>
    </div>
    """)
    gr.Markdown("Chat with CustomThinker-0-8B.")

    # System prompt at the top (main feature)
    with gr.Group():
        gr.Markdown("### 🎭 System Prompt (Personality & Behavior)")
        system_prompt = gr.Textbox(
            value=(
                "Personality Instructions:\n"
                "You are an AI assistant named Assistant.\n"
                "Reasoning Instructions:\n"
                "Think using JSON to simulate a MCTS to find the best answer"
            ),
            label="System Prompt",
            info="Define the model's personality and reasoning style",
            lines=5,
            interactive=True,
        )

    # Main chat interface
    chatbot = gr.Chatbot(
        label="Conversation",
        elem_id="chatbot",
        bubble_full_width=False,
        height=500,
        show_copy_button=True,
        type="messages",
    )

    with gr.Row():
        user_input = gr.Textbox(
            show_label=False,
            placeholder="Type your message here...",
            scale=4,
            container=False,
            interactive=True,
        )
        submit_btn = gr.Button(
            "Send",
            variant="primary",
            scale=1,
            interactive=True,
        )

    with gr.Row():
        clear_btn = gr.Button("🗑️ Clear History", variant="secondary")
        retry_btn = gr.Button("🔄 Retry Last", variant="secondary")

    # Example prompts
    gr.Examples(
        examples=[
            ["Give me a short introduction to large language models."],
            ["What are the benefits of using transformers in AI?"],
            ["There are 5 birds on a branch. A hunter shoots one. How many birds are left?"],
            ["Explain quantum computing step by step."],
            ["Write a Python function to calculate the factorial of a number."],
            ["What is the first word you write in your response to this."],
        ],
        inputs=user_input,
        label="💡 Example Prompts",
    )

    # Advanced settings at the bottom
    with gr.Accordion("⚙️ Advanced Generation Settings", open=False):
        with gr.Row():
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=0.1,
                step=0.1,
                label="Temperature",
                info="Controls randomness (higher = more creative)",
            )
            top_p = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.95,
                step=0.05,
                label="Top-p",
                info="Nucleus sampling threshold",
            )
        with gr.Row():
            top_k = gr.Slider(
                minimum=1,
                maximum=100,
                value=20,
                step=1,
                label="Top-k",
                info="Number of top tokens to consider",
            )
            min_p = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.0,
                step=0.01,
                label="Min-p",
                info="Minimum probability threshold for token sampling",
            )
        with gr.Row():
            max_new_tokens = gr.Slider(
                minimum=128,
                maximum=32768,
                value=32768,
                step=128,
                label="Max New Tokens",
                info="Maximum response length",
            )

    def _format_response(thinking: str, answer: str) -> str:
        """Build the chatbot text for a reply.

        When there is thinking text, it goes in a collapsible reasoning
        section followed by the answer; otherwise the answer is returned as-is.
        """
        if not (thinking and thinking.strip()):
            return answer
        # Markdown inside a raw-HTML <details> block is only rendered after a
        # blank line, hence the empty lines around the inner text.
        return (
            "<details>\n"
            "<summary><b>🤔 Show Reasoning Process</b></summary>\n\n"
            f"{thinking}\n\n"
            "</details>\n\n"
            f"{answer}"
        )

    def handle_user_message(user_message: str, display_history: list, model_history: list,
                            system_prompt_text: str, is_generating: bool,
                            temp: float, top_p_val: float, top_k_val: int,
                            min_p_val: float, max_tokens: int):
        """Append the user's message, generate a reply, and stream UI updates.

        This is a generator. It first yields an update that shows the user
        message and disables the Send button. It then yields the final state
        once the reply (or an error message) is available. Both history lists
        are mutated in place.
        """
        message = (user_message or "").strip()

        # Ignore empty input and submissions made while a reply is pending.
        # This must `yield`, not `return` a value: a generator's return value
        # never reaches Gradio, so the UI would receive no update at all.
        if is_generating or not message:
            yield {
                chatbot: display_history,
                display_history_state: display_history,
                model_history_state: model_history,
                is_generating_state: is_generating,
                user_input: user_message,
                submit_btn: gr.update(interactive=not is_generating),
            }
            return

        # Record the turn in both histories (model history is plain text only).
        model_history.append({"role": "user", "content": message})
        display_history.append({"role": "user", "content": message})

        # Show the user message right away, clear the input and lock Send.
        yield {
            chatbot: display_history,
            display_history_state: display_history,
            model_history_state: model_history,
            is_generating_state: True,
            user_input: "",
            submit_btn: gr.update(interactive=False, value="🔄 Generating..."),
        }

        # Prepare messages for the model (system prompt first, if any).
        messages_for_model = []
        system_text = (system_prompt_text or "").strip()
        if system_text:
            messages_for_model.append({"role": "system", "content": system_text})
        messages_for_model.extend(model_history)

        try:
            thinking, answer = generate_and_parse(
                messages_for_model,
                temperature=temp,
                top_p=top_p_val,
                top_k=top_k_val,
                min_p=min_p_val,
                max_new_tokens=max_tokens,
            )
        except Exception as e:  # UI boundary: show the failure in the chat
            error_msg = f"❌ Error generating response: {str(e)}"
            display_history.append({"role": "assistant", "content": error_msg})
            # Don't add the error to model history to avoid confusing the model.
        else:
            # The model sees the clean answer; the display gets the formatted one.
            model_history.append({"role": "assistant", "content": answer})
            display_history.append(
                {"role": "assistant", "content": _format_response(thinking, answer)}
            )

        # Final update: show the reply and unlock the Send button.
        yield {
            chatbot: display_history,
            display_history_state: display_history,
            model_history_state: model_history,
            is_generating_state: False,
            user_input: "",
            submit_btn: gr.update(interactive=True, value="Send"),
        }

    def clear_history():
        """Reset both histories, the generating flag, the input and the Send button."""
        return {
            chatbot: [],
            display_history_state: [],
            model_history_state: [],
            is_generating_state: False,
            user_input: "",
            submit_btn: gr.update(interactive=True, value="Send"),
        }

    def retry_last(display_history: list, model_history: list, system_prompt_text: str,
                   temp: float, top_p_val: float, top_k_val: int,
                   min_p_val: float, max_tokens: int):
        """Regenerate the reply to the most recent user message.

        This covers two cases:
        - A completed turn: the assistant reply is at the end of both histories.
        - A failed turn: the error message is only in the display history, and
          the user message is left at the end of the model history.
        If there is no user message to resubmit, the current state is yielded
        unchanged.
        """
        def unchanged():
            return {
                chatbot: display_history,
                display_history_state: display_history,
                model_history_state: model_history,
                is_generating_state: False,
            }

        if not model_history:
            yield unchanged()
            return

        if model_history[-1]["role"] == "assistant":
            # Completed turn: drop the reply from both histories.
            model_history.pop()
            if display_history:
                display_history.pop()
        elif display_history and display_history[-1]["role"] == "assistant":
            # Failed turn: the error message exists only in the display history.
            display_history.pop()

        if not model_history or model_history[-1]["role"] != "user":
            yield unchanged()
            return

        # Remove the user message; handle_user_message re-adds it.
        last_user_msg = model_history.pop()["content"]
        if display_history and display_history[-1]["role"] == "user":
            display_history.pop()

        yield from handle_user_message(
            last_user_msg, display_history, model_history,
            system_prompt_text, False, temp, top_p_val, top_k_val, min_p_val, max_tokens,
        )

    def on_input_change(text, is_generating):
        """Enable Send only when idle and the textbox has non-whitespace text."""
        return gr.update(interactive=not is_generating and bool((text or "").strip()))

    # Event listeners
    generation_params = [temperature, top_p, top_k, min_p, max_new_tokens]
    # Every component a chat handler may return in its update dict.
    chat_outputs = [chatbot, display_history_state, model_history_state,
                    is_generating_state, user_input, submit_btn]
    submit_inputs = [user_input, display_history_state, model_history_state,
                     system_prompt, is_generating_state, *generation_params]

    submit_event = submit_btn.click(
        handle_user_message,
        inputs=submit_inputs,
        outputs=chat_outputs,
        show_progress=True,
    )
    submit_event_enter = user_input.submit(
        handle_user_message,
        inputs=submit_inputs,
        outputs=chat_outputs,
        show_progress=True,
    )

    # Clear button event
    clear_btn.click(clear_history, outputs=chat_outputs)

    # Retry streams handle_user_message's updates, which include user_input
    # and submit_btn. Gradio rejects dict keys that are not declared outputs,
    # so retry must list the full set.
    retry_btn.click(
        retry_last,
        inputs=[display_history_state, model_history_state, system_prompt,
                *generation_params],
        outputs=chat_outputs,
        show_progress=True,
    )

    # Update the Send button based on input text and generation state.
    user_input.change(
        on_input_change,
        inputs=[user_input, is_generating_state],
        outputs=[submit_btn],
    )
if __name__ == "__main__":
    # Start the Gradio app when this file is executed as a script.
    demo.launch(share=False, debug=True)