import torch
from transformers import pipeline, TextIteratorStreamer
import gradio as gr
import threading
import time

# Global state for the model pipeline
model_pipeline = None
model_loading_lock = threading.Lock()
model_loaded = False  # Status flag indicating whether the model is loaded


def load_model(model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"):
    global model_pipeline, model_loaded
    with model_loading_lock:
        if not model_loaded:
            print("Loading model...")
            pipe = pipeline(
                "text-generation",
                model=model_name,
                device_map="sequential",
                torch_dtype=torch.float16,
                trust_remote_code=True,
                truncation=True,
                max_new_tokens=2048,
                model_kwargs={
                    "low_cpu_mem_usage": True,
                    "offload_folder": "offload"
                }
            )
            model_pipeline = pipe
            model_loaded = True
            print("Model loaded successfully.")
        else:
            print("Model already loaded.")


def check_model_status():
    """Check if the model is loaded and reload it if necessary."""
    if not model_loaded:
        print("Model not loaded. Reloading...")
        load_model()
    return model_loaded


def chat(message, history, temperature, max_new_tokens):
    # Ensure the model is loaded before proceeding
    if not check_model_status():
        yield "Model is not ready. Please try again later.", ""
        return

    # Note: only the current message goes into the prompt; history is not used.
    prompt = f"Human: {message}\n\nAssistant:"

    start_time = time.time()
    generated_tokens = 0
    stop_tokens = ["<|endoftext|>", "<|im_end|>", "|im_end|"]

    # TextIteratorStreamer exposes generated text as an iterator, so generation
    # can run in a background thread while this function streams the output.
    tokenizer = model_pipeline.tokenizer
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    def generate():
        model_pipeline(
            prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            do_sample=True,
            truncation=True,
            pad_token_id=tokenizer.eos_token_id,
            streamer=streamer  # the pipeline forwards the streamer to model.generate()
        )

    threading.Thread(target=generate, daemon=True).start()

    # Collect streamed text chunks and yield the partial response plus a status string
    outputs = []
    for new_token in streamer:
        if new_token in stop_tokens:
            break
        outputs.append(new_token)
        generated_tokens += 1
        elapsed = time.time() - start_time
        yield "".join(outputs), f"Streamed {generated_tokens} chunks in {elapsed:.1f}s"


def reload_model_button():
    """Reload the model manually via a button."""
    global model_loaded
    model_loaded = False
    load_model()
    return "Model reloaded successfully."


def get_model_status():
    """Return a human-readable status string for the status textbox."""
    return "Model is loaded and ready." if model_loaded else "Model is not loaded."
# Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# DeepSeek-R1 Chatbot")
    gr.Markdown("A demo for testing conversations with the DeepSeek-R1-Distill-Qwen-1.5B model.")

    with gr.Row():
        chatbot = gr.Chatbot(height=600)
        textbox = gr.Textbox(placeholder="Enter your message...", container=False, scale=7)
    with gr.Row():
        send_button = gr.Button("Send")
        clear_button = gr.Button("Clear")
        reload_button = gr.Button("Reload Model")
    with gr.Row():
        temperature_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
        max_tokens_slider = gr.Slider(minimum=32, maximum=2048, value=2048, step=32, label="Max New Tokens")

    status_text = gr.Textbox(label="Model Status", value="Model not loaded yet.", interactive=False)
    token_status = gr.Textbox(label="Token Generation Status", value="", interactive=False)

    def respond(message, chat_history, temperature, max_new_tokens):
        bot_message = ""
        for partial_response, partial_status in chat(message, chat_history, temperature, max_new_tokens):
            bot_message = partial_response
            # Clear the textbox, append the partial reply, and update the token status
            yield "", chat_history + [(message, bot_message)], partial_status

    send_button.click(
        respond,
        inputs=[textbox, chatbot, temperature_slider, max_tokens_slider],
        outputs=[textbox, chatbot, token_status]
    )
    clear_button.click(lambda: [], None, chatbot)
    reload_button.click(reload_model_button, None, status_text)

    # Refresh the model status textbox every 5 seconds
    # (requires a Gradio version that supports the `every` argument on load events)
    demo.load(get_model_status, None, status_text, every=5)


# Load the model when the server starts
if __name__ == "__main__":
    load_model()  # Pre-load the model
    demo.launch(server_name="0.0.0.0")