sdafd committed (verified)
Commit 6ea0840 · Parent(s): 704f6d2

Update app.py

Files changed (1): app.py (+26 -34)
app.py CHANGED
@@ -1,33 +1,29 @@
  import torch
- from transformers import pipeline, TextIteratorStreamer
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
  import gradio as gr
  import threading
  import time
 
- # Global variable to store the model pipeline
- model_pipeline = None
+ # Global variables to store the model and tokenizer
+ model = None
+ tokenizer = None
  model_loading_lock = threading.Lock()
  model_loaded = False  # Status flag to indicate if the model is loaded
 
  def load_model(model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"):
-     global model_pipeline, model_loaded
+     global model, tokenizer, model_loaded
      with model_loading_lock:
          if not model_loaded:
              print("Loading model...")
-             pipe = pipeline(
-                 "text-generation",
-                 model=model_name,
+             tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+             model = AutoModelForCausalLM.from_pretrained(
+                 model_name,
                  device_map="sequential",
                  torch_dtype=torch.float16,
                  trust_remote_code=True,
-                 truncation=True,
-                 max_new_tokens=2048,
-                 model_kwargs={
-                     "low_cpu_mem_usage": True,
-                     "offload_folder": "offload"
-                 }
+                 low_cpu_mem_usage=True,
+                 offload_folder="offload"
              )
-             model_pipeline = pipe
              model_loaded = True
              print("Model loaded successfully.")
          else:
@@ -42,9 +38,9 @@ def check_model_status():
      return model_loaded
 
  def chat(message, history, temperature, max_new_tokens):
-     global model_pipeline
-     stop_tokens = ["<|endoftext|>", "<|im_end|>","|im_end|"]
-
+     global model, tokenizer
+     stop_tokens = ["\n", "|im_end|"]
+
      # Ensure the model is loaded before proceeding
      if not check_model_status():
          yield "Model is not ready. Please try again later."
@@ -52,36 +48,35 @@ def chat(message, history, temperature, max_new_tokens):
 
      prompt = f"Human: {message}\n\nAssistant:"
 
+     # Tokenize the input
+     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
      # Stream the response
      start_time = time.time()
 
      # Create a TextStreamer for token streaming
-     tokenizer = model_pipeline.tokenizer
      streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 
-
-
-     pipeline_kwargs = dict(
-         prompt=prompt,
+     generate_kwargs = dict(
+         input_ids=inputs.input_ids,
          max_new_tokens=max_new_tokens,
          temperature=temperature,
          do_sample=True,
-         truncation=True,
          pad_token_id=tokenizer.eos_token_id,
          streamer=streamer  # Use the TextStreamer here
      )
 
-     # Create and start the thread with the model_pipeline function
-     t = threading.Thread(target=lambda: model_pipeline(**pipeline_kwargs))
+     # Create and start the thread with the model.generate function
+     t = threading.Thread(target=model.generate, kwargs=generate_kwargs)
      t.start()
-
+
+     outputs = []
      for new_token in streamer:
-         print(new_token)
          outputs.append(new_token)
-         if new_token in stop_tokens:
-
+         if any(stop_token in new_token for stop_token in stop_tokens):
              break
-     yield "".join(outputs), "not implemented"
+         yield "".join(outputs)
+
  def reload_model_button():
      """Reload the model manually via a button."""
      global model_loaded
@@ -119,11 +114,8 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
 
      def respond(message, chat_history, temperature, max_new_tokens):
          bot_message = ""
-         status = ""
-         for partial_response, partial_status in chat(message, chat_history, temperature, max_new_tokens):
+         for partial_response in chat(message, chat_history, temperature, max_new_tokens):
              bot_message = partial_response
-             status = partial_status
-             token_status.update(value=status)
              yield "", chat_history + [(message, bot_message)]
 
      send_button.click(respond, inputs=[textbox, chatbot, temperature_slider, max_tokens_slider], outputs=[textbox, chatbot])
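For reference, the new generation path follows the usual transformers streaming pattern: model.generate() runs in a background thread while a TextIteratorStreamer yields decoded text chunks to the consuming loop. A minimal, self-contained sketch of that pattern is below; the tiny model name "sshleifer/tiny-gpt2" and the prompt are placeholders chosen only to keep the example cheap to run, and are not part of this Space's code.

import threading

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Placeholder model, used here only to keep the sketch small and fast.
model_name = "sshleifer/tiny-gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

prompt = "Human: Hello!\n\nAssistant:"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# The streamer decodes tokens as generate() produces them and exposes
# them through an iterator on the consuming thread.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

generate_kwargs = dict(
    input_ids=inputs.input_ids,
    max_new_tokens=32,
    do_sample=True,
    temperature=0.7,
    pad_token_id=tokenizer.eos_token_id,
    streamer=streamer,
)

# generate() blocks until generation finishes, so it runs in a worker
# thread while the main thread consumes partial text from the streamer.
thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
thread.start()

pieces = []
for chunk in streamer:
    pieces.append(chunk)
    print("".join(pieces))  # the response grows chunk by chunk

thread.join()

Passing kwargs directly to threading.Thread, as the commit does, keeps the generate() arguments explicit and avoids wrapping the call in a lambda.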