File size: 5,040 Bytes
5eb7c5e
53639d5
5eb7c5e
 
881f4c4
5eb7c5e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8858101
5eb7c5e
b576940
 
5eb7c5e
 
8858101
 
5eb7c5e
 
 
8858101
 
5eb7c5e
53639d5
 
 
 
b576940
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5eb7c5e
 
 
 
 
 
 
881f4c4
 
 
 
 
8858101
881f4c4
5eb7c5e
8858101
5eb7c5e
 
 
 
 
 
 
 
 
 
 
 
8858101
 
 
 
5eb7c5e
8858101
5eb7c5e
8858101
 
 
 
 
 
 
 
5eb7c5e
8858101
5eb7c5e
 
 
881f4c4
 
5eb7c5e
 
 
 
881f4c4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import threading
import time

import gradio as gr
import torch
from transformers import TextIteratorStreamer, TextStreamer, pipeline

# Shared, module-level model state. It is written by the loader functions
# below and read by the request handlers.
model_loading_lock = threading.Lock()  # serializes load_model() calls
model_pipeline = None  # the transformers pipeline once loaded, otherwise None
model_loaded = False  # becomes True after load_model() finishes successfully

def load_model(model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"):
    """Create the shared text-generation pipeline unless it already exists.

    ``model_loading_lock`` is held for the whole load, so concurrent callers
    are serialized. The first caller builds the pipeline; later callers only
    report that it is already loaded. On success the pipeline is stored in
    the module-level ``model_pipeline`` and ``model_loaded`` is set to True.

    Args:
        model_name: Model identifier passed as ``model=`` to
            ``transformers.pipeline``.
    """
    global model_pipeline, model_loaded
    with model_loading_lock:
        if model_loaded:
            print("Model already loaded.")
            return

        print("Loading model...")
        # Extra arguments forwarded to the model's from_pretrained call.
        # NOTE(review): "offload" is presumably the on-disk spill directory
        # used if the device map has to offload weights — confirm.
        loader_options = {
            "low_cpu_mem_usage": True,
            "offload_folder": "offload",
        }
        model_pipeline = pipeline(
            "text-generation",
            model=model_name,
            device_map="sequential",
            torch_dtype=torch.float16,
            trust_remote_code=True,
            truncation=True,
            max_new_tokens=2048,
            model_kwargs=loader_options,
        )
        model_loaded = True
        print("Model loaded successfully.")

def check_model_status():
    """Report whether the model is ready, attempting a load first if it is not.

    Returns:
        The value of the module-level ``model_loaded`` flag after the
        (possible) call to ``load_model()``.
    """
    if model_loaded:
        return model_loaded
    print("Model not loaded. Reloading...")
    load_model()
    return model_loaded

def chat(message, history, temperature, max_new_tokens):
    """Stream a model reply to *message* as ``(partial_text, status)`` pairs.

    The prompt is a single ``Human: ... / Assistant:`` turn built from
    *message*. Generation runs in a background thread and pushes decoded text
    into a ``TextIteratorStreamer``. This generator drains that streamer and,
    after every chunk, yields the reply accumulated so far together with a
    short status string. A final pair carries the completed text and a
    summary status.

    Args:
        message: The user's latest message.
        history: Previous chat turns. Currently unused when building the prompt.
        temperature: Sampling temperature forwarded to generation.
        max_new_tokens: Upper bound on generated tokens. It is coerced to ``int``
            because UI sliders may deliver floats.

    Yields:
        ``(text, status)`` tuples. If the model is not available, a single
        ``(error_message, status)`` tuple is yielded instead.
    """
    stop_tokens = ["<|endoftext|>", "<|im_end|>", "|im_end|"]

    # Ensure the model is loaded before proceeding.
    if not check_model_status():
        yield "Model is not ready. Please try again later.", "Model not loaded."
        return

    # Bind the pipeline locally so a concurrent reload cannot swap or clear it
    # in the middle of this request.
    pipe = model_pipeline
    tokenizer = pipe.tokenizer
    prompt = f"Human: {message}\n\nAssistant:"

    # TextStreamer only prints to stdout. TextIteratorStreamer is the
    # iterable variant: generation puts decoded text on a queue read below.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_errors = []

    def _generate():
        """Run the pipeline; on failure, record the error and unblock the reader."""
        try:
            pipe(
                prompt,  # the text input is positional for text-generation pipelines
                max_new_tokens=int(max_new_tokens),
                temperature=temperature,
                do_sample=True,
                truncation=True,
                pad_token_id=tokenizer.eos_token_id,
                streamer=streamer,
            )
        except Exception as exc:  # reported to the UI below instead of hanging it
            generation_errors.append(exc)
            # Generation never signalled end-of-stream; do it so the loop ends.
            streamer.end()

    start_time = time.time()
    worker = threading.Thread(target=_generate, daemon=True)
    worker.start()

    text = ""
    chunks = 0
    for chunk in streamer:
        text += chunk
        chunks += 1
        # A stop string can arrive inside a chunk or split across chunks, so
        # search the accumulated text rather than comparing single chunks.
        cut = min(
            (pos for pos in (text.find(tok) for tok in stop_tokens) if pos != -1),
            default=-1,
        )
        if cut != -1:
            text = text[:cut]
            # Only the displayed reply is cut; the worker thread keeps
            # generating until it finishes on its own.
            break
        yield text, f"Generating... {chunks} chunks, {time.time() - start_time:.1f}s"

    elapsed = time.time() - start_time
    if generation_errors:
        yield text, f"Generation failed after {elapsed:.1f}s: {generation_errors[0]}"
    else:
        yield text, f"Done: {chunks} chunks in {elapsed:.1f}s"
def reload_model_button():
    """Reload the model on demand; this is the "Reload Model" button handler.

    The current pipeline is dropped *before* the new one is built, so its
    weights can be released first. Previously both copies stayed alive during
    the reload.

    Returns:
        A status message for the UI's model-status textbox.
    """
    import gc

    global model_pipeline, model_loaded
    # Reset under the lock so the reset cannot interleave with a load that
    # another caller has in progress.
    with model_loading_lock:
        model_pipeline = None
        model_loaded = False
    gc.collect()
    if torch.cuda.is_available():
        # Return cached allocator blocks before allocating the new weights.
        torch.cuda.empty_cache()
    load_model()
    return "Model reloaded successfully."

# Background poller that mirrors the model-loaded flag into a status component.
def update_status_periodically(status_text):
    """Forever: wait 5 seconds, then write the current load status to *status_text*.

    Intended to run in a daemon thread. It assigns the component's ``.value``
    directly.

    NOTE(review): assigning ``.value`` on an already-rendered Gradio component
    may not push the change to connected browsers — confirm.
    """
    messages = {True: "Model is loaded and ready.", False: "Model is not loaded."}
    while True:
        time.sleep(5)  # polling interval in seconds
        status_text.value = messages[bool(model_loaded)]

# Gradio Interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# DeepSeek-R1 Chatbot")
    # "A demo for testing conversations using the DeepSeek-R1-Distill-Qwen-1.5B model."
    gr.Markdown("DeepSeek-R1-Distill-Qwen-1.5B 모델을 사용한 대화 테스트용 데모입니다.")

    with gr.Row():
        chatbot = gr.Chatbot(height=600)
        textbox = gr.Textbox(placeholder="Enter your message...", container=False, scale=7)

    with gr.Row():
        send_button = gr.Button("Send")
        clear_button = gr.Button("Clear")
        reload_button = gr.Button("Reload Model")

    with gr.Row():
        temperature_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
        max_tokens_slider = gr.Slider(minimum=32, maximum=2048, value=2048, step=32, label="Max New Tokens")

    status_text = gr.Textbox(label="Model Status", value="Model not loaded yet.", interactive=False)
    token_status = gr.Textbox(label="Token Generation Status", value="", interactive=False)

    def respond(message, chat_history, temperature, max_new_tokens):
        """Stream one chat turn into the textbox, chatbot and token-status box.

        Yields ``(textbox_value, chatbot_history, token_status)`` triples. The
        input box is cleared, the reply grows in place as a
        ``(message, reply)`` pair, and the generation status is delivered as
        an event output. Calling ``token_status.update`` inside the handler
        does not update the page.
        """
        chat_history = chat_history or []
        if not message or not message.strip():
            # Nothing to send: leave the UI as it is.
            yield message, chat_history, ""
            return
        for bot_message, status in chat(message, chat_history, temperature, max_new_tokens):
            yield "", chat_history + [(message, bot_message)], status

    def _current_status():
        """Return the status message matching the current ``model_loaded`` flag."""
        return "Model is loaded and ready." if model_loaded else "Model is not loaded."

    chat_inputs = [textbox, chatbot, temperature_slider, max_tokens_slider]
    chat_outputs = [textbox, chatbot, token_status]
    send_button.click(respond, inputs=chat_inputs, outputs=chat_outputs)
    textbox.submit(respond, inputs=chat_inputs, outputs=chat_outputs)  # Enter sends too
    clear_button.click(lambda: ([], ""), None, [chatbot, token_status])
    reload_button.click(reload_model_button, None, status_text)
    # Show the real load state when a page opens, since the model is usually
    # preloaded before the server starts.
    demo.load(_current_status, None, status_text)

    # Start a background thread to update the status text periodically
    threading.Thread(target=update_status_periodically, args=(status_text,), daemon=True).start()

# Script entry point: load the weights first so the first request does not pay
# the load cost, then serve the UI bound to 0.0.0.0 (all interfaces).
if __name__ == "__main__":
    launch_options = {"server_name": "0.0.0.0"}
    load_model()
    demo.launch(**launch_options)