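"""FastAPI service that streams chat completions from a DeepSeek-R1-Distill-Qwen-1.5B
model loaded with Hugging Face Transformers.

Endpoints: POST /chat (streamed response), POST /reload-model, GET /status.
"""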
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
import threading
import time

app = FastAPI()

# Global variables to store the model and tokenizer
model = None
tokenizer = None
model_loading_lock = threading.Lock()
model_loaded = False # Status flag to indicate if the model is loaded

def load_model(model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"):
    global model, tokenizer, model_loaded
    with model_loading_lock:
        if not model_loaded:
            print("Loading model...")
            tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                device_map="sequential",
                torch_dtype=torch.float16,
                trust_remote_code=True,
                low_cpu_mem_usage=True,
                offload_folder="offload",
            )
            model_loaded = True
            print("Model loaded successfully.")
        else:
            print("Model already loaded.")

def check_model_status():
    """Check if the model is loaded and reload if necessary."""
    global model_loaded
    if not model_loaded:
        print("Model not loaded. Reloading...")
        load_model()
    return model_loaded
@app.post("/chat")
async def chat_endpoint(message: str, temperature: float = 0.7, max_new_tokens: int = 2048):
global model, tokenizer
# Ensure the model is loaded before proceeding
if not check_model_status():
raise HTTPException(status_code=503, detail="Model is not ready. Please try again later.")
stop_tokens = ["|im_end|"]
prompt = f"Human: {message}\n\nAssistant:"
# Tokenize the input
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
# Stream the response
start_time = time.time()
token_count = 0
# Create a TextStreamer for token streaming
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
generate_kwargs = dict(
input_ids=inputs.input_ids,
max_new_tokens=max_new_tokens,
temperature=temperature,
do_sample=True,
pad_token_id=tokenizer.eos_token_id,
streamer=streamer # Use the TextStreamer here
)
# Start generation in a separate thread
threading.Thread(target=model.generate, kwargs=generate_kwargs).start()

    def generate_response():
        # token_count is assigned inside this generator, so it must be declared
        # nonlocal to refer to the counter defined in chat_endpoint
        nonlocal token_count
        outputs = []
        for new_token in streamer:
            outputs.append(new_token)
            token_count += 1
            # Track generation speed (informational; not sent to the client)
            elapsed_time = time.time() - start_time
            tokens_per_second = token_count / elapsed_time if elapsed_time > 0 else 0
            # Yield the current token as a server-sent event
            yield f"data: {new_token}\n\n"
            if any(stop_token in new_token for stop_token in stop_tokens):
                break

    return StreamingResponse(generate_response(), media_type="text/event-stream")
@app.post("/reload-model")
async def reload_model():
"""Reload the model manually via an API endpoint."""
global model_loaded
model_loaded = False
load_model()
return {"message": "Model reloaded successfully."}
@app.get("/status")
async def get_model_status():
"""Check the status of the model."""
status = "Model is loaded and ready." if model_loaded else "Model is not loaded."
return {"status": status}

# Pre-load the model and start the server when this file is run directly.
# Note: when the app is launched by an external process manager (e.g. `uvicorn app:app`),
# this block does not run and the model is loaded lazily on the first request.
if __name__ == "__main__":
    load_model()  # Pre-load the model
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
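
# Example usage (illustrative sketch; assumes the server is running locally on
# port 8000). Because `message`, `temperature`, and `max_new_tokens` are simple
# typed parameters, FastAPI exposes them as query parameters:
#
#   curl -N -X POST "http://localhost:8000/chat?message=Hello&temperature=0.7"
#   curl "http://localhost:8000/status"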