File size: 3,818 Bytes
5eb7c5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import torch
from transformers import pipeline
import gradio as gr
import threading

# Global variable to store the model pipeline
# (set once by load_model(); read by chat()).
model_pipeline = None
model_loading_lock = threading.Lock()  # serializes load_model() across threads
model_loaded = False  # Status flag to indicate if the model is loaded

def load_model(model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"):
    """Initialize the global text-generation pipeline exactly once.

    Thread-safe: the module lock serializes concurrent callers, and the
    ``model_loaded`` flag turns repeated calls into a no-op.

    Args:
        model_name: Hugging Face model id to load.
    """
    global model_pipeline, model_loaded
    with model_loading_lock:
        # Fast exit if another caller already finished loading.
        if model_loaded:
            print("Model already loaded.")
            return

        print("Loading model...")
        model_pipeline = pipeline(
            "text-generation",
            model=model_name,
            device_map="sequential",
            torch_dtype=torch.float16,
            trust_remote_code=True,
            truncation=True,
            max_new_tokens=2048,
            model_kwargs={
                "low_cpu_mem_usage": True,
                "offload_folder": "offload",
            },
        )
        model_loaded = True
        print("Model loaded successfully.")

def check_model_status():
    """Return True once the model is available, loading it on demand."""
    if model_loaded:
        return True
    # Lazy (re)load: the flag may have been reset by reload_model_button().
    print("Model not loaded. Reloading...")
    load_model()
    return model_loaded

def chat(message, history):
    """Generate a single chat reply for the Gradio UI.

    Args:
        message: The user's latest message.
        history: Prior (user, bot) turns from the Chatbot widget; unused
            here because each prompt is built from the latest message only.

    Returns:
        The model's reply text, or a human-readable error message if
        loading or generation failed.
    """
    global model_pipeline

    # Ensure the model is loaded before proceeding.
    if not check_model_status():
        return "Model is not ready. Please try again later."

    prompt = f"Human: {message}\n\nAssistant:"

    # BUGFIX: the pipeline call was previously outside the try block, so a
    # generation-time error (OOM, tokenizer failure) crashed the handler
    # instead of returning the friendly error message below.
    try:
        # NOTE(review): pad_token_id=50256 is GPT-2's EOS id — verify it
        # matches the tokenizer of the configured DeepSeek/Qwen model.
        response = model_pipeline(
            prompt,
            max_new_tokens=2048,
            temperature=0.7,
            do_sample=True,
            truncation=True,
            pad_token_id=50256,
        )
        bot_text = response[0]["generated_text"]
        # Keep only the text after the final "Assistant:" marker.
        bot_text = bot_text.split("Assistant:")[-1].strip()
        # DeepSeek-R1 emits reasoning wrapped in <think>...</think>; show
        # only the final answer after the closing tag.
        if "</think>" in bot_text:
            bot_text = bot_text.split("</think>")[-1].strip()
    except Exception as e:
        bot_text = f"Sorry, there was a problem generating the response: {str(e)}"

    return bot_text

def reload_model_button():
    """Force a model reload (wired to the "Reload Model" button).

    Returns:
        A status string shown in the UI's status textbox.
    """
    global model_loaded
    # BUGFIX: reset the flag under the lock — a bare write could interleave
    # with another thread inside load_model()'s locked check-and-load
    # section and leave the flag/pipeline pair inconsistent.
    with model_loading_lock:
        model_loaded = False
    load_model()
    return "Model reloaded successfully."

# Gradio Interface: chat window, send/clear/reload controls, and a
# periodically refreshed model-status readout.
with gr.Blocks() as demo:
    gr.Markdown("# DeepSeek-R1 Chatbot")
    gr.Markdown("DeepSeek-R1-Distill-Qwen-1.5B ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•œ ๋Œ€ํ™” ํ…Œ์ŠคํŠธ์šฉ ๋ฐ๋ชจ์ž…๋‹ˆ๋‹ค.")

    with gr.Row():
        chat_window = gr.Chatbot(height=600)
        user_input = gr.Textbox(placeholder="Enter your message...", container=False, scale=7)

    with gr.Row():
        send_btn = gr.Button("Send")
        clear_btn = gr.Button("Clear")
        reload_btn = gr.Button("Reload Model")

    status_box = gr.Textbox(label="Model Status", value="Model not loaded yet.", interactive=False)

    def refresh_status():
        """Report whether the model pipeline is currently loaded."""
        return "Model is loaded and ready." if model_loaded else "Model is not loaded."

    def handle_message(user_message, conversation):
        """Run one chat turn, append it to the history, and clear the input."""
        reply = chat(user_message, conversation)
        conversation.append((user_message, reply))
        return "", conversation

    send_btn.click(handle_message, inputs=[user_input, chat_window], outputs=[user_input, chat_window])
    clear_btn.click(lambda: [], None, chat_window)
    reload_btn.click(reload_model_button, None, status_box)

    # Poll every 5 seconds so the status box tracks background reloads.
    demo.load(refresh_status, None, status_box, every=5)

# Load the model when the server starts
if __name__ == "__main__":
    load_model()  # Pre-load the model so the first chat request is not delayed
    demo.launch(server_name="0.0.0.0")  # bind on all interfaces