import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import gradio as gr
import threading
import time
# Global variables to store the model and tokenizer
model = None
tokenizer = None
model_loading_lock = threading.Lock()
model_loaded = False # Status flag to indicate if the model is loaded
def load_model(model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"):
    """Load the model and tokenizer once; the lock keeps concurrent calls safe."""
    global model, tokenizer, model_loaded
    with model_loading_lock:
        if not model_loaded:
            print("Loading model...")
            tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                device_map="sequential",  # requires the `accelerate` package
                torch_dtype=torch.float16,
                trust_remote_code=True,
                low_cpu_mem_usage=True,
                offload_folder="offload",  # spill weights to disk if they do not fit in memory
            )
            model_loaded = True
            print("Model loaded successfully.")
        else:
            print("Model already loaded.")
def check_model_status():
    """Check if the model is loaded and reload if necessary."""
    global model_loaded
    if not model_loaded:
        print("Model not loaded. Reloading...")
        load_model()
    return model_loaded
def chat(message, history, temperature, max_new_tokens):
    global model, tokenizer
    stop_tokens = ["\n", "|im_end|"]  # generation is cut off at the first newline or end marker
    # Ensure the model is loaded before proceeding
    if not check_model_status():
        yield "Model is not ready. Please try again later.", ""
        return
    prompt = f"Human: {message}\n\nAssistant:"
    # Tokenize the input
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Stream the response
    start_time = time.time()
    token_count = 0
    # Create a TextIteratorStreamer for token-by-token streaming
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        streamer=streamer,  # stream generated tokens into the iterator below
    )
    # Run generation in a background thread so this generator can consume the streamer
    t = threading.Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    outputs = []
    for new_token in streamer:
        outputs.append(new_token)
        token_count += 1
        # Calculate tokens per second
        elapsed_time = time.time() - start_time
        tokens_per_second = token_count / elapsed_time if elapsed_time > 0 else 0
        # Update the token status
        token_status_value = f"Tokens Generated: {token_count}, Tokens/Second: {tokens_per_second:.2f}"
        yield "".join(outputs), token_status_value
        if any(stop_token in new_token for stop_token in stop_tokens):
            break
def reload_model_button():
    """Reload the model manually via a button."""
    global model_loaded
    model_loaded = False
    load_model()
    return "Model reloaded successfully."
# Function to periodically update the status text
def update_status_periodically(status_text):
    while True:
        time.sleep(5)  # Update every 5 seconds
        status = "Model is loaded and ready." if model_loaded else "Model is not loaded."
        # Note: assigning to .value from a background thread does not push the
        # change to an already-rendered page; the Reload Model button below is
        # the reliable way to refresh the status display.
        status_text.value = status
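# Periodic refresh sketch (an assumption, requiring a Gradio version that ships
# gr.Timer): inside the Blocks below, a timer event could push the status into
# the textbox instead of mutating .value from a thread, e.g.
#   timer = gr.Timer(5)
#   timer.tick(
#       lambda: "Model is loaded and ready." if model_loaded else "Model is not loaded.",
#       outputs=status_text,
#   )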
# Gradio Interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("# DeepSeek-R1 Chatbot")
gr.Markdown("DeepSeek-R1-Distill-Qwen-1.5B ๋ชจ๋ธ์ ์ฌ์ฉํ ๋ํ ํ
์คํธ์ฉ ๋ฐ๋ชจ์
๋๋ค.")
with gr.Row():
chatbot = gr.Chatbot(height=600)
textbox = gr.Textbox(placeholder="Enter your message...", container=False, scale=7)
with gr.Row():
send_button = gr.Button("Send")
clear_button = gr.Button("Clear")
reload_button = gr.Button("Reload Model")
with gr.Row():
temperature_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature")
max_tokens_slider = gr.Slider(minimum=32, maximum=2048, value=2048, step=32, label="Max New Tokens")
status_text = gr.Textbox(label="Model Status", value="Model not loaded yet.", interactive=False)
token_status = gr.Textbox(label="Token Generation Status", value="", interactive=False)
    def respond(message, chat_history, temperature, max_new_tokens):
        bot_message = ""
        for partial_response, token_status_value in chat(message, chat_history, temperature, max_new_tokens):
            bot_message = partial_response
            # Components are updated by yielding values to the declared outputs,
            # not by calling .update() inside the handler, so the token stats are
            # yielded alongside the cleared textbox and the updated history.
            yield "", chat_history + [(message, bot_message)], token_status_value
    send_button.click(
        respond,
        inputs=[textbox, chatbot, temperature_slider, max_tokens_slider],
        outputs=[textbox, chatbot, token_status],
    )
    clear_button.click(lambda: [], None, chatbot)
    reload_button.click(reload_model_button, None, status_text)
    # Start a background thread to update the status text periodically
    threading.Thread(target=update_status_periodically, args=(status_text,), daemon=True).start()
# Load the model when the server starts
if __name__ == "__main__":
    load_model()  # Pre-load the model
    demo.launch(server_name="0.0.0.0")
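# Usage sketch, assuming torch, transformers, accelerate, and gradio are
# installed: running `python app.py` serves the UI on http://0.0.0.0:7860,
# Gradio's default port.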