sdafd committed on
Commit 96e31d2 · verified · 1 Parent(s): 58636ea

Create app.py

Files changed (1)
  1. app.py +112 -0
app.py ADDED
@@ -0,0 +1,112 @@
+ from fastapi import FastAPI, HTTPException
+ from fastapi.responses import StreamingResponse
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+ import torch
+ import threading
+ import time
+
+ app = FastAPI()
+
+ # Global variables holding the model and tokenizer
+ model = None
+ tokenizer = None
+ model_loading_lock = threading.Lock()
+ model_loaded = False  # Status flag indicating whether the model is loaded
+
+ def load_model(model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"):
+     global model, tokenizer, model_loaded
+     with model_loading_lock:
+         if not model_loaded:
+             print("Loading model...")
+             tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+             model = AutoModelForCausalLM.from_pretrained(
+                 model_name,
+                 device_map="sequential",
+                 torch_dtype=torch.float16,
+                 trust_remote_code=True,
+                 low_cpu_mem_usage=True,
+                 offload_folder="offload",
+             )
+             model_loaded = True
+             print("Model loaded successfully.")
+         else:
+             print("Model already loaded.")
+
+ def check_model_status():
+     """Check whether the model is loaded and reload it if necessary."""
+     global model_loaded
+     if not model_loaded:
+         print("Model not loaded. Reloading...")
+         load_model()
+     return model_loaded
+
+ @app.post("/chat")
+ async def chat_endpoint(message: str, temperature: float = 0.7, max_new_tokens: int = 2048):
+     global model, tokenizer
+
+     # Ensure the model is loaded before proceeding
+     if not check_model_status():
+         raise HTTPException(status_code=503, detail="Model is not ready. Please try again later.")
+
+     stop_tokens = ["|im_end|"]
+     prompt = f"Human: {message}\n\nAssistant:"
+
+     # Tokenize the input and move it to the model's device
+     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+     # Track generation speed
+     start_time = time.time()
+
+     # Create a streamer that yields decoded tokens as they are generated
+     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+     generate_kwargs = dict(
+         input_ids=inputs.input_ids,
+         attention_mask=inputs.attention_mask,
+         max_new_tokens=max_new_tokens,
+         temperature=temperature,
+         do_sample=True,
+         pad_token_id=tokenizer.eos_token_id,
+         streamer=streamer,  # Stream tokens through the TextIteratorStreamer
+     )
+
+     # Run generation in a separate thread so the streamer can be consumed here
+     threading.Thread(target=model.generate, kwargs=generate_kwargs).start()
+
+     def generate_response():
+         token_count = 0
+         for new_token in streamer:
+             token_count += 1
+
+             # Tokens per second, available for logging or a status payload
+             elapsed_time = time.time() - start_time
+             tokens_per_second = token_count / elapsed_time if elapsed_time > 0 else 0
+
+             # Yield the current token as a server-sent event
+             yield f"data: {new_token}\n\n"
+
+             if any(stop_token in new_token for stop_token in stop_tokens):
+                 break
+
+     return StreamingResponse(generate_response(), media_type="text/event-stream")
+
+ @app.post("/reload-model")
+ async def reload_model():
+     """Reload the model manually via an API endpoint."""
+     global model_loaded
+     model_loaded = False
+     load_model()
+     return {"message": "Model reloaded successfully."}
+
+ @app.get("/status")
+ async def get_model_status():
+     """Check the status of the model."""
+     status = "Model is loaded and ready." if model_loaded else "Model is not loaded."
+     return {"status": status}
+
+ # Load the model when the server starts
+ if __name__ == "__main__":
+     load_model()  # Pre-load the model
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=8000)
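
For reference, a minimal client sketch (not part of this commit) showing how the streaming /chat endpoint above could be consumed. It assumes the server is running locally at http://localhost:8000 and uses the requests library to read the server-sent events line by line; the message text and sampling parameters are placeholder values.

import requests

# Hypothetical client for the /chat endpoint defined in app.py above.
# Assumes the FastAPI server is reachable at http://localhost:8000.
response = requests.post(
    "http://localhost:8000/chat",
    params={"message": "Hello!", "temperature": 0.7, "max_new_tokens": 256},
    stream=True,
)
response.raise_for_status()

# Each server-sent event arrives as a line of the form "data: <token>"
for line in response.iter_lines(decode_unicode=True):
    if line and line.startswith("data: "):
        print(line[len("data: "):], end="", flush=True)
print()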