# deepseek-r1-cpu / app.py
# Hugging Face Space page header captured with this file: author "sdafd",
# commit 5eb7c5e ("Create app.py", verified), file size 3.82 kB.
import torch
from transformers import pipeline
import gradio as gr
import threading
# Shared model state: written by load_model() and read by the request handlers.
model_loaded = False  # Flips to True once model_pipeline holds a ready pipeline.
model_pipeline = None  # The transformers text-generation pipeline, once loaded.
model_loading_lock = threading.Lock()  # Serializes loads so concurrent callers don't load twice.
def load_model(model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B", *, torch_dtype=torch.float16):
    """Build the shared text-generation pipeline once, under a lock.

    If ``model_loaded`` is already True this only logs and returns, so a
    caller that wants a fresh load must clear that flag first. Note that a
    different *model_name* is NOT loaded while a model is marked as loaded.

    Args:
        model_name: Hugging Face model id (or local path) passed to
            ``transformers.pipeline``.
        torch_dtype: Dtype for the model weights. Defaults to float16, the
            value that used to be hard-coded.
            NOTE(review): float16 on a CPU-only host can be slow or hit
            unsupported kernels depending on the PyTorch version;
            torch.float32 or torch.bfloat16 may be a better choice - confirm
            for the deployment target.

    Raises:
        Whatever ``transformers.pipeline`` raises. On failure the lock is
        released and ``model_pipeline`` / ``model_loaded`` are left unchanged.
    """
    global model_pipeline, model_loaded
    with model_loading_lock:
        # Checked under the lock: another thread may have completed a load
        # while this one was waiting to acquire it.
        if model_loaded:
            print("Model already loaded.")
            return

        print("Loading model...")
        pipe = pipeline(
            "text-generation",
            model=model_name,
            device_map="sequential",
            torch_dtype=torch_dtype,
            trust_remote_code=True,
            # Extra pipeline() kwargs are applied by transformers as default
            # per-call options; chat() also passes these explicitly.
            truncation=True,
            max_new_tokens=2048,
            model_kwargs={
                "low_cpu_mem_usage": True,
                # Disk location for weights that device_map offloads.
                "offload_folder": "offload",
            },
        )
        # Publish the pipeline before the flag so readers that see
        # model_loaded == True also see a usable pipeline.
        model_pipeline = pipe
        model_loaded = True
        print("Model loaded successfully.")
def check_model_status():
    """Report whether the model is ready, triggering a load when it is not.

    Returns:
        The module-level ``model_loaded`` flag, read after any load attempt.
    """
    if model_loaded:
        return model_loaded
    print("Model not loaded. Reloading...")
    load_model()
    return model_loaded
def chat(message, history):
    """Generate one assistant reply for *message* using the shared pipeline.

    Args:
        message: The user's latest chat message.
        history: Earlier turns as supplied by the Gradio chatbot.
            NOTE(review): currently unused - the prompt is built from
            *message* alone, so the bot has no memory of earlier turns.

    Returns:
        The reply text, with everything up to a closing ``</think>`` tag
        removed. If the model is unavailable or generation fails, a
        human-readable message is returned instead.
    """
    # Ensure the model is loaded before proceeding.
    if not check_model_status():
        return "Model is not ready. Please try again later."

    prompt = f"Human: {message}\n\nAssistant:"

    try:
        tokenizer = model_pipeline.tokenizer
        # Token ids are vocabulary-specific, so take the padding id from this
        # model's tokenizer (falling back to EOS) instead of hard-coding one
        # (50256 is GPT-2's end-of-text id).
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = tokenizer.eos_token_id

        response = model_pipeline(
            prompt,
            max_new_tokens=2048,
            temperature=0.7,
            do_sample=True,
            truncation=True,
            pad_token_id=pad_token_id,
            # Return only the completion, so the prompt never has to be
            # parsed back out of the output.
            return_full_text=False,
        )
        bot_text = response[0]["generated_text"]
    except Exception as e:  # UI boundary: report the failure in the chat.
        print(f"Generation failed: {e!r}")
        return f"Sorry, there was a problem generating the response: {str(e)}"

    # Keep only what follows the reasoning section closed by </think>.
    if "</think>" in bot_text:
        bot_text = bot_text.split("</think>")[-1]
    # The plain "Human:/Assistant:" prompt can lead the model to keep writing
    # imaginary follow-up turns; keep only the first assistant turn.
    bot_text = bot_text.split("\nHuman:")[0]
    return bot_text.strip()
def reload_model_button():
    """Handle the "Reload Model" button by forcing a fresh model load.

    Returns:
        The status message shown in the model-status textbox.
    """
    global model_loaded
    # load_model() short-circuits while the flag is set, so clear it first.
    # NOTE(review): the reset happens outside model_loading_lock, so a load
    # already in progress on another thread can make this a no-op.
    model_loaded = False
    load_model()
    status_message = "Model reloaded successfully."
    return status_message
# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("# DeepSeek-R1 Chatbot")
    # Korean subtitle: "A demo for testing conversations using the
    # DeepSeek-R1-Distill-Qwen-1.5B model." (restored from mis-decoded text)
    gr.Markdown("DeepSeek-R1-Distill-Qwen-1.5B 모델을 사용한 대화 테스트용 데모입니다.")
    with gr.Row():
        chatbot = gr.Chatbot(height=600)
    textbox = gr.Textbox(placeholder="Enter your message...", container=False, scale=7)
    with gr.Row():
        send_button = gr.Button("Send")
        clear_button = gr.Button("Clear")
    reload_button = gr.Button("Reload Model")
    status_text = gr.Textbox(label="Model Status", value="Model not loaded yet.", interactive=False)

    def update_status():
        """Return a human-readable description of the model load state."""
        if model_loaded:
            return "Model is loaded and ready."
        return "Model is not loaded."

    def respond(message, chat_history):
        """Generate a reply to *message* and append the turn to the history.

        Returns:
            An empty string (clears the textbox) and the updated history.
            Blank or whitespace-only messages leave the history unchanged
            rather than triggering a full generation.
        """
        chat_history = chat_history or []
        if not message or not message.strip():
            return "", chat_history
        bot_message = chat(message, chat_history)
        chat_history.append((message, bot_message))
        return "", chat_history

    send_button.click(respond, inputs=[textbox, chatbot], outputs=[textbox, chatbot])
    # Let Enter in the textbox send the message, like the Send button.
    textbox.submit(respond, inputs=[textbox, chatbot], outputs=[textbox, chatbot])
    clear_button.click(lambda: [], None, chatbot)
    reload_button.click(reload_model_button, None, status_text)
    # Periodically check and update model status.
    demo.load(update_status, None, status_text, every=5)
def _serve():
    """Load the model eagerly, then start the Gradio server on all interfaces."""
    # Loading up front means the first chat request doesn't pay the load cost.
    load_model()
    demo.launch(server_name="0.0.0.0")


if __name__ == "__main__":
    _serve()