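# Gradio app for a Hugging Face Space: a chat demo that serves
# deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B through the transformers
# text-generation pipeline, streaming tokens back to the UI.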
import torch
from transformers import pipeline, TextIteratorStreamer
import gradio as gr
import threading
import time
# Global variable to store the model pipeline
model_pipeline = None
model_loading_lock = threading.Lock()
model_loaded = False # Status flag to indicate if the model is loaded

def load_model(model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"):
    """Load the text-generation pipeline once, guarded by a lock so concurrent callers don't race."""
    global model_pipeline, model_loaded
    with model_loading_lock:
        if not model_loaded:
            print("Loading model...")
            pipe = pipeline(
                "text-generation",
                model=model_name,
                device_map="sequential",
                torch_dtype=torch.float16,
                trust_remote_code=True,
                truncation=True,
                max_new_tokens=2048,
                model_kwargs={
                    "low_cpu_mem_usage": True,
                    "offload_folder": "offload"
                }
            )
            model_pipeline = pipe
            model_loaded = True
            print("Model loaded successfully.")
        else:
            print("Model already loaded.")

def check_model_status():
    """Check if the model is loaded and reload it if necessary."""
    global model_loaded
    if not model_loaded:
        print("Model not loaded. Reloading...")
        load_model()
    return model_loaded

def chat(message, history, temperature, max_new_tokens):
    global model_pipeline
    stop_tokens = ["<|endoftext|>", "<|im_end|>", "|im_end|"]
    # Ensure the model is loaded before proceeding
    if not check_model_status():
        yield "Model is not ready. Please try again later.", ""
        return
    prompt = f"Human: {message}\n\nAssistant:"
    # Stream the response
    start_time = time.time()
    # Create a TextIteratorStreamer so generated tokens can be iterated over as they arrive
    tokenizer = model_pipeline.tokenizer
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    pipeline_kwargs = dict(
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=True,
        truncation=True,
        pad_token_id=tokenizer.eos_token_id,
        streamer=streamer  # Use the streamer here
    )
    # Run generation in a background thread; tokens are consumed from the streamer below
    t = threading.Thread(target=lambda: model_pipeline(prompt, **pipeline_kwargs))
    t.start()
    outputs = []
    for new_token in streamer:
        print(new_token)
        outputs.append(new_token)
        if new_token in stop_tokens:
            break
        yield "".join(outputs), "not implemented"
    t.join()

def reload_model_button():
    """Reload the model manually via a button."""
    global model_loaded
    model_loaded = False
    load_model()
    return "Model reloaded successfully."

# Function to periodically update the status text
def update_status_periodically(status_text):
    while True:
        time.sleep(5)  # Update every 5 seconds
        status = "Model is loaded and ready." if model_loaded else "Model is not loaded."
        # Note: mutating .value only changes the server-side default; it does not push
        # an update to browser sessions that are already open.
        status_text.value = status

# Gradio Interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# DeepSeek-R1 Chatbot")
    gr.Markdown("A demo for testing conversations with the DeepSeek-R1-Distill-Qwen-1.5B model.")
    with gr.Row():
        chatbot = gr.Chatbot(height=600)
        textbox = gr.Textbox(placeholder="Enter your message...", container=False, scale=7)
    with gr.Row():
        send_button = gr.Button("Send")
        clear_button = gr.Button("Clear")
        reload_button = gr.Button("Reload Model")
    with gr.Row():
        temperature_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
        max_tokens_slider = gr.Slider(minimum=32, maximum=2048, value=2048, step=32, label="Max New Tokens")
    status_text = gr.Textbox(label="Model Status", value="Model not loaded yet.", interactive=False)
    token_status = gr.Textbox(label="Token Generation Status", value="", interactive=False)

    def respond(message, chat_history, temperature, max_new_tokens):
        bot_message = ""
        status = ""
        for partial_response, partial_status in chat(message, chat_history, temperature, max_new_tokens):
            bot_message = partial_response
            status = partial_status
            # Stream the partial reply into the chat window and surface the generation status
            yield "", chat_history + [(message, bot_message)], status

    send_button.click(
        respond,
        inputs=[textbox, chatbot, temperature_slider, max_tokens_slider],
        outputs=[textbox, chatbot, token_status],
    )
    clear_button.click(lambda: [], None, chatbot)
    reload_button.click(reload_model_button, None, status_text)

    # Start a background thread to update the status text periodically
    threading.Thread(target=update_status_periodically, args=(status_text,), daemon=True).start()

# Load the model when the server starts
if __name__ == "__main__":
    load_model()  # Pre-load the model
    demo.launch(server_name="0.0.0.0")
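
# Usage sketch (assumptions, not taken from the source: the file is saved as app.py
# and run locally rather than on the hosted Space):
#   pip install torch transformers accelerate gradio
#   python app.py
# then open http://localhost:7860 (Gradio's default port) in a browser.
# `accelerate` is needed because the pipeline is created with device_map.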