import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import gradio as gr
import threading
import time

# Global variables to store the model and tokenizer
model = None
tokenizer = None
model_loading_lock = threading.Lock()
model_loaded = False  # Status flag to indicate if the model is loaded


def load_model(model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"):
    global model, tokenizer, model_loaded
    with model_loading_lock:
        if not model_loaded:
            print("Loading model...")
            tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                device_map="sequential",
                torch_dtype=torch.float16,
                trust_remote_code=True,
                low_cpu_mem_usage=True,
                offload_folder="offload",
            )
            model_loaded = True
            print("Model loaded successfully.")
        else:
            print("Model already loaded.")


def check_model_status():
    """Check if the model is loaded and reload if necessary."""
    if not model_loaded:
        print("Model not loaded. Reloading...")
        load_model()
    return model_loaded


def chat(message, history, temperature, max_new_tokens):
    # Substring markers that end generation early; with skip_special_tokens=True
    # the model's own special tokens are stripped, so this is only a safety net.
    stop_tokens = ["|im_end|"]

    # Ensure the model is loaded before proceeding
    if not check_model_status():
        yield "Model is not ready. Please try again later.", ""
        return

    prompt = f"Human: {message}\n\nAssistant:"

    # Tokenize the input and move it to the model's device
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Track generation speed
    start_time = time.time()
    token_count = 0

    # TextIteratorStreamer turns model.generate into an iterator of text chunks
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,  # avoids the missing-mask warning
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        streamer=streamer,
    )

    # Run generation in a background thread so this generator can consume
    # the streamer as tokens arrive
    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    outputs = []
    for new_text in streamer:
        outputs.append(new_text)
        token_count += 1  # approximate: the streamer yields decoded text chunks
        elapsed_time = time.time() - start_time
        tokens_per_second = token_count / elapsed_time if elapsed_time > 0 else 0
        token_status_value = f"Tokens Generated: {token_count}, Tokens/Second: {tokens_per_second:.2f}"
        yield "".join(outputs), token_status_value
        if any(stop_token in new_text for stop_token in stop_tokens):
            break


def reload_model_button():
    """Reload the model manually via a button."""
    global model_loaded
    model_loaded = False
    load_model()
    return "Model reloaded successfully."


def update_status():
    """Return the current model status (polled periodically by the UI)."""
    return "Model is loaded and ready." if model_loaded else "Model is not loaded."
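
# Optional: a template-based alternative to the hand-rolled "Human:/Assistant:"
# prompt built in chat() above. A minimal sketch, assuming the tokenizer ships
# a chat template (the DeepSeek-R1 distills do); build_chat_inputs is a
# hypothetical helper and is not wired into the UI.
def build_chat_inputs(message):
    messages = [{"role": "user", "content": message}]
    # apply_chat_template tokenizes by default and returns ready-to-use input_ids
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    )
    return input_ids.to(model.device)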

# Gradio Interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# DeepSeek-R1 Chatbot")
    gr.Markdown("A demo for testing conversations with the DeepSeek-R1-Distill-Qwen-1.5B model.")

    with gr.Row():
        chatbot = gr.Chatbot(height=600)
        textbox = gr.Textbox(placeholder="Enter your message...", container=False, scale=7)

    with gr.Row():
        send_button = gr.Button("Send")
        clear_button = gr.Button("Clear")
        reload_button = gr.Button("Reload Model")

    with gr.Row():
        temperature_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
        max_tokens_slider = gr.Slider(minimum=32, maximum=2048, value=2048, step=32, label="Max New Tokens")

    status_text = gr.Textbox(label="Model Status", value="Model not loaded yet.", interactive=False)
    token_status = gr.Textbox(label="Token Generation Status", value="", interactive=False)

    def respond(message, chat_history, temperature, max_new_tokens):
        # Stream partial responses into the chatbot while clearing the textbox
        bot_message = ""
        for partial_response, token_status_value in chat(message, chat_history, temperature, max_new_tokens):
            bot_message = partial_response
            yield "", chat_history + [(message, bot_message)], gr.update(value=token_status_value)

    send_button.click(
        respond,
        inputs=[textbox, chatbot, temperature_slider, max_tokens_slider],
        outputs=[textbox, chatbot, token_status],
    )
    clear_button.click(lambda: [], None, chatbot)
    reload_button.click(reload_model_button, None, status_text)

    # Refresh the model-status textbox every 5 seconds while a client is
    # connected; assigning status_text.value from a background thread would
    # only change the default shown on the next page load
    demo.load(update_status, None, status_text, every=5)

# Load the model when the server starts
if __name__ == "__main__":
    load_model()  # Pre-load the model
    demo.queue()  # Queuing is required for streaming (generator) handlers
    demo.launch(server_name="0.0.0.0")
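
# Usage note (assumptions: file saved as app.py, default Gradio port):
#   python app.py
# then open http://localhost:7860 in a browser. The first run downloads the
# model weights from the Hugging Face Hub, so startup can take a while.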