sdafd commited on
Commit
5eb7c5e
ยท
verified ยท
1 Parent(s): 11f173a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +117 -0
app.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import pipeline
3
+ import gradio as gr
4
+ import threading
5
+
6
+ # Global variable to store the model pipeline
7
+ model_pipeline = None
8
+ model_loading_lock = threading.Lock()
9
+ model_loaded = False # Status flag to indicate if the model is loaded
10
+
11
+ def load_model(model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"):
12
+ global model_pipeline, model_loaded
13
+ with model_loading_lock:
14
+ if not model_loaded:
15
+ print("Loading model...")
16
+ pipe = pipeline(
17
+ "text-generation",
18
+ model=model_name,
19
+ device_map="sequential",
20
+ torch_dtype=torch.float16,
21
+ trust_remote_code=True,
22
+ truncation=True,
23
+ max_new_tokens=2048,
24
+ model_kwargs={
25
+ "low_cpu_mem_usage": True,
26
+ "offload_folder": "offload"
27
+ }
28
+ )
29
+ model_pipeline = pipe
30
+ model_loaded = True
31
+ print("Model loaded successfully.")
32
+ else:
33
+ print("Model already loaded.")
34
+
35
+ def check_model_status():
36
+ """Check if the model is loaded and reload if necessary."""
37
+ global model_loaded
38
+ if not model_loaded:
39
+ print("Model not loaded. Reloading...")
40
+ load_model()
41
+ return model_loaded
42
+
43
+ def chat(message, history):
44
+ global model_pipeline
45
+
46
+ # Ensure the model is loaded before proceeding
47
+ if not check_model_status():
48
+ return "Model is not ready. Please try again later."
49
+
50
+ prompt = f"Human: {message}\n\nAssistant:"
51
+
52
+ # Generate response using the pre-loaded model
53
+ response = model_pipeline(
54
+ prompt,
55
+ max_new_tokens=2048,
56
+ temperature=0.7,
57
+ do_sample=True,
58
+ truncation=True,
59
+ pad_token_id=50256
60
+ )
61
+
62
+ try:
63
+ bot_text = response[0]["generated_text"]
64
+ bot_text = bot_text.split("Assistant:")[-1].strip()
65
+ if "</think>" in bot_text:
66
+ bot_text = bot_text.split("</think>")[-1].strip()
67
+ except Exception as e:
68
+ bot_text = f"Sorry, there was a problem generating the response: {str(e)}"
69
+
70
+ return bot_text
71
+
72
+ def reload_model_button():
73
+ """Reload the model manually via a button."""
74
+ global model_loaded
75
+ model_loaded = False
76
+ load_model()
77
+ return "Model reloaded successfully."
78
+
79
+ # Gradio Interface
80
+ with gr.Blocks() as demo:
81
+ gr.Markdown("# DeepSeek-R1 Chatbot")
82
+ gr.Markdown("DeepSeek-R1-Distill-Qwen-1.5B ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•œ ๋Œ€ํ™” ํ…Œ์ŠคํŠธ์šฉ ๋ฐ๋ชจ์ž…๋‹ˆ๋‹ค.")
83
+
84
+ with gr.Row():
85
+ chatbot = gr.Chatbot(height=600)
86
+ textbox = gr.Textbox(placeholder="Enter your message...", container=False, scale=7)
87
+
88
+ with gr.Row():
89
+ send_button = gr.Button("Send")
90
+ clear_button = gr.Button("Clear")
91
+ reload_button = gr.Button("Reload Model")
92
+
93
+ status_text = gr.Textbox(label="Model Status", value="Model not loaded yet.", interactive=False)
94
+
95
+ def update_status():
96
+ """Update the model status text."""
97
+ if model_loaded:
98
+ return "Model is loaded and ready."
99
+ else:
100
+ return "Model is not loaded."
101
+
102
+ def respond(message, chat_history):
103
+ bot_message = chat(message, chat_history)
104
+ chat_history.append((message, bot_message))
105
+ return "", chat_history
106
+
107
+ send_button.click(respond, inputs=[textbox, chatbot], outputs=[textbox, chatbot])
108
+ clear_button.click(lambda: [], None, chatbot)
109
+ reload_button.click(reload_model_button, None, status_text)
110
+
111
+ # Periodically check and update model status
112
+ demo.load(update_status, None, status_text, every=5)
113
+
114
+ # Load the model when the server starts
115
+ if __name__ == "__main__":
116
+ load_model() # Pre-load the model
117
+ demo.launch(server_name="0.0.0.0")