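"""Gradio chat demo that streams responses from DeepSeek-R1-Distill-Qwen-1.5B.

The model and tokenizer are loaded once into module-level globals, generation
runs in a background thread, and tokens are streamed to the UI through a
TextIteratorStreamer. Requires torch, transformers, and gradio.
"""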
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import gradio as gr
import threading
import time

# Global variables to store the model and tokenizer
model = None
tokenizer = None
model_loading_lock = threading.Lock()  # guards against concurrent load attempts
model_loaded = False  # Status flag to indicate if the model is loaded

def load_model(model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"):
    global model, tokenizer, model_loaded
    with model_loading_lock:
        if not model_loaded:
            print("Loading model...")
            tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                device_map="sequential",
                torch_dtype=torch.float16,
                trust_remote_code=True,
                low_cpu_mem_usage=True,
                offload_folder="offload"
            )
            model_loaded = True
            print("Model loaded successfully.")
        else:
            print("Model already loaded.")

def check_model_status():
    """Check if the model is loaded and reload if necessary."""
    global model_loaded
    if not model_loaded:
        print("Model not loaded. Reloading...")
        load_model()
    return model_loaded

def chat(message, history, temperature, max_new_tokens):
    global model, tokenizer
    stop_tokens = ["|im_end|"]  # fallback stop marker; special tokens are already stripped by the streamer
    
    # Ensure the model is loaded before proceeding
    if not check_model_status():
        yield "Model is not ready. Please try again later.", ""
        return
    
    prompt = f"Human: {message}\n\nAssistant:"
    
    # Tokenize the input
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    
    # Stream the response
    start_time = time.time()
    token_count = 0
    
    # Create a TextIteratorStreamer so generated text can be consumed token by token
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    
    generate_kwargs = dict(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,  # pass the mask explicitly to avoid generate() warnings
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        streamer=streamer  # stream tokens back through the TextIteratorStreamer
    )
    
    # Run generation in a background thread so tokens can be read from the streamer as they arrive
    t = threading.Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    
    outputs = []
    for new_token in streamer:
        outputs.append(new_token)
        token_count += 1
        
        # Calculate tokens per second
        elapsed_time = time.time() - start_time
        tokens_per_second = token_count / elapsed_time if elapsed_time > 0 else 0
        
        # Update the token status
        token_status_value = f"Tokens Generated: {token_count}, Tokens/Second: {tokens_per_second:.2f}"
        yield "".join(outputs), token_status_value
        
        if any(stop_token in new_token for stop_token in stop_tokens):
            break

def reload_model_button():
    """Reload the model manually via a button."""
    global model_loaded
    model_loaded = False
    load_model()
    return "Model reloaded successfully."

# Return the current model status string; polled below to keep the status
# textbox fresh (setting status_text.value from a background thread would only
# change server-side state, not what connected clients see).
def get_model_status():
    return "Model is loaded and ready." if model_loaded else "Model is not loaded."

# Gradio Interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# DeepSeek-R1 Chatbot")
    gr.Markdown("DeepSeek-R1-Distill-Qwen-1.5B ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•œ ๋Œ€ํ™” ํ…Œ์ŠคํŠธ์šฉ ๋ฐ๋ชจ์ž…๋‹ˆ๋‹ค.")
    
    with gr.Row():
        chatbot = gr.Chatbot(height=600)
        textbox = gr.Textbox(placeholder="Enter your message...", container=False, scale=7)
    
    with gr.Row():
        send_button = gr.Button("Send")
        clear_button = gr.Button("Clear")
        reload_button = gr.Button("Reload Model")
    
    with gr.Row():
        temperature_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
        max_tokens_slider = gr.Slider(minimum=32, maximum=2048, value=2048, step=32, label="Max New Tokens")
    
    status_text = gr.Textbox(label="Model Status", value="Model not loaded yet.", interactive=False)
    token_status = gr.Textbox(label="Token Generation Status", value="", interactive=False)
    
    def respond(message, chat_history, temperature, max_new_tokens):
        bot_message = ""
        for partial_response, token_status_value in chat(message, chat_history, temperature, max_new_tokens):
            bot_message = partial_response
            yield "", chat_history + [(message, bot_message)], gr.update(value=token_status_value)
    
    send_button.click(respond, inputs=[textbox, chatbot, temperature_slider, max_tokens_slider], outputs=[textbox, chatbot, token_status])
    clear_button.click(lambda: [], None, chatbot)
    reload_button.click(reload_model_button, None, status_text)
    
    # Refresh the model status every 5 seconds (`every=` on Blocks.load is the
    # Gradio 3.x/4.x polling API; Gradio 5 exposes gr.Timer for the same purpose).
    demo.load(get_model_status, None, status_text, every=5)

# Load the model when the server starts
if __name__ == "__main__":
    load_model()  # Pre-load the model
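    # server_name="0.0.0.0" binds to all interfaces so the demo is reachable
    # from other machines on the network.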
    demo.launch(server_name="0.0.0.0")