from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
import threading
import time

app = FastAPI()

# Global variables to store the model and tokenizer
model = None
tokenizer = None
model_loading_lock = threading.Lock()
model_loaded = False  # Status flag to indicate if the model is loaded


def load_model(model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"):
    global model, tokenizer, model_loaded
    with model_loading_lock:
        if not model_loaded:
            print("Loading model...")
            tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                device_map="sequential",
                torch_dtype=torch.float16,
                trust_remote_code=True,
                low_cpu_mem_usage=True,
                offload_folder="offload",
            )
            model_loaded = True
            print("Model loaded successfully.")
        else:
            print("Model already loaded.")


def check_model_status():
    """Check if the model is loaded and reload if necessary."""
    global model_loaded
    if not model_loaded:
        print("Model not loaded. Reloading...")
        load_model()
    return model_loaded


@app.post("/chat")
async def chat_endpoint(message: str, temperature: float = 0.7, max_new_tokens: int = 2048):
    global model, tokenizer

    # Ensure the model is loaded before proceeding
    if not check_model_status():
        raise HTTPException(status_code=503, detail="Model is not ready. Please try again later.")

    stop_tokens = ["|im_end|"]
    prompt = f"Human: {message}\n\nAssistant:"

    # Tokenize the input and move it to the model's device
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Track streaming throughput from the moment generation starts
    start_time = time.time()

    # TextIteratorStreamer yields decoded tokens as the model produces them
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        input_ids=inputs.input_ids,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        streamer=streamer,  # Generation pushes new tokens into the streamer
    )

    # Start generation in a separate thread so this coroutine can stream the output
    threading.Thread(target=model.generate, kwargs=generate_kwargs).start()

    def generate_response():
        token_count = 0
        for new_token in streamer:
            token_count += 1

            # Calculate tokens per second
            elapsed_time = time.time() - start_time
            tokens_per_second = token_count / elapsed_time if elapsed_time > 0 else 0

            # Yield the current token as a server-sent event
            yield f"data: {new_token}\n\n"

            if any(stop_token in new_token for stop_token in stop_tokens):
                break

    return StreamingResponse(generate_response(), media_type="text/event-stream")


@app.post("/reload-model")
async def reload_model():
    """Reload the model manually via an API endpoint."""
    global model_loaded
    model_loaded = False
    load_model()
    return {"message": "Model reloaded successfully."}


@app.get("/status")
async def get_model_status():
    """Check the status of the model."""
    status = "Model is loaded and ready." if model_loaded else "Model is not loaded."
    return {"status": status}


# Pre-load the model and start the server when run directly
if __name__ == "__main__":
    load_model()
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
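
# Example client call (a minimal sketch, assuming the server is running locally on
# port 8000 as configured above; the route and parameter names match the /chat
# endpoint defined in this file):
#
#   curl -N -X POST "http://localhost:8000/chat?message=Hello&temperature=0.7&max_new_tokens=256"
#
# FastAPI treats `message`, `temperature`, and `max_new_tokens` as query parameters
# here, and the response streams back as server-sent events ("data: ..." lines)
# until a stop token appears or max_new_tokens is reached.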