import torch
from transformers import pipeline
import gradio as gr
import threading

# Global variable to store the model pipeline
model_pipeline = None
model_loading_lock = threading.Lock()
model_loaded = False  # Status flag to indicate if the model is loaded

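# Load the pipeline exactly once, guarded by a lock so concurrent requests
# cannot trigger duplicate loads; `model_loaded` lets callers check readiness cheaply.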
def load_model(model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"):
    global model_pipeline, model_loaded
    with model_loading_lock:
        if not model_loaded:
            print("Loading model...")
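            # device_map="sequential" fills available devices in order and, together
            # with offload_folder, spills weights that do not fit onto disk;
            # float16 halves the memory footprint. Requires `accelerate` to be installed.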
            pipe = pipeline(
                "text-generation",
                model=model_name,
                device_map="sequential",
                torch_dtype=torch.float16,
                trust_remote_code=True,
                truncation=True,
                max_new_tokens=2048,
                model_kwargs={
                    "low_cpu_mem_usage": True,
                    "offload_folder": "offload"
                }
            )
            model_pipeline = pipe
            model_loaded = True
            print("Model loaded successfully.")
        else:
            print("Model already loaded.")

def check_model_status():
    """Check if the model is loaded and reload if necessary."""
    global model_loaded
    if not model_loaded:
        print("Model not loaded. Reloading...")
        load_model()
    return model_loaded

def chat(message, history):
    global model_pipeline
    
    # Ensure the model is loaded before proceeding
    if not check_model_status():
        return "Model is not ready. Please try again later."
    
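    # A plain Human/Assistant prompt for simplicity; the model's own chat
    # template (tokenizer.apply_chat_template) would likely give better results.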
    prompt = f"Human: {message}\n\nAssistant:"
    
    # Generate response using the pre-loaded model
    response = model_pipeline(
        prompt,
        max_new_tokens=2048,
        temperature=0.7,
        do_sample=True,
        truncation=True,
        pad_token_id=model_pipeline.tokenizer.eos_token_id  # use the model's own EOS id; 50256 is GPT-2's and mismatches Qwen tokenizers
    )
    
    try:
        bot_text = response[0]["generated_text"]
        bot_text = bot_text.split("Assistant:")[-1].strip()
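        # DeepSeek-R1 models emit their chain of thought inside <think>...</think>;
        # keep only the final answer after the closing tag.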
        if "</think>" in bot_text:
            bot_text = bot_text.split("</think>")[-1].strip()
    except Exception as e:
        bot_text = f"Sorry, there was a problem generating the response: {str(e)}"
    
    return bot_text

def reload_model_button():
    """Reload the model manually via a button."""
    global model_loaded
    model_loaded = False
    load_model()
    return "Model reloaded successfully."

# Return a status string for the UI; polled periodically via demo.load below.
def get_model_status():
    return "Model is loaded and ready." if model_loaded else "Model is not loaded."

# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# DeepSeek-R1 Chatbot")
    gr.Markdown("A demo for testing conversations with the DeepSeek-R1-Distill-Qwen-1.5B model.")
    
    with gr.Row():
        chatbot = gr.Chatbot(height=600)
        textbox = gr.Textbox(placeholder="Enter your message...", container=False, scale=7)
    
    with gr.Row():
        send_button = gr.Button("Send")
        clear_button = gr.Button("Clear")
        reload_button = gr.Button("Reload Model")
    
    status_text = gr.Textbox(label="Model Status", value="Model not loaded yet.", interactive=False)
    
    def respond(message, chat_history):
        bot_message = chat(message, chat_history)
        chat_history.append((message, bot_message))
        return "", chat_history
    
    send_button.click(respond, inputs=[textbox, chatbot], outputs=[textbox, chatbot])
    textbox.submit(respond, inputs=[textbox, chatbot], outputs=[textbox, chatbot])  # pressing Enter sends as well
    clear_button.click(lambda: [], None, chatbot)
    reload_button.click(reload_model_button, None, status_text)
    
    # Poll the model status every 5 seconds. Calling Component.update() from a
    # background thread never reaches the browser, so use a polled event instead
    # (`every=` here; Gradio 5 replaces this with gr.Timer).
    demo.load(get_model_status, None, status_text, every=5)

# Load the model when the server starts
if __name__ == "__main__":
    load_model()  # Pre-load the model
    demo.launch(server_name="0.0.0.0")