"""HTTP API exposing text generation from a local GGUF model via ctransformers."""

import logging
import os
import threading
from typing import List, Optional

import uvicorn
from ctransformers import AutoModelForCausalLM
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Gema 4B Model API", version="1.0.0")

# Location of the model weights inside the container image.
MODEL_PATH = "./model"
MODEL_FILE = "gema-4b-indra10k-model1-q4_k_m.gguf"


# Request model
class TextRequest(BaseModel):
    """Body of a ``POST /generate`` request."""

    # Prompt text supplied by the caller.
    inputs: str
    # Optional preamble; when set, the prompt is wrapped in a
    # "User: ... / Assistant:" template.
    system_prompt: Optional[str] = None
    # Sampling parameters forwarded to the model call as-is.
    max_tokens: Optional[int] = 512
    temperature: Optional[float] = 0.7
    top_k: Optional[int] = 50
    top_p: Optional[float] = 0.9
    repeat_penalty: Optional[float] = 1.1
    # Stop sequences; ``None`` is forwarded as an empty list.
    stop: Optional[List[str]] = None


# Response model
class TextResponse(BaseModel):
    """Body of a successful ``POST /generate`` response."""

    generated_text: str


# Global model variable, assigned once by the startup hook.
model = None

# Inference runs in a worker thread (see generate_text). The original code
# ran one generation at a time because it blocked the event loop; this lock
# keeps that one-at-a-time behaviour, since the model object is not assumed
# to be safe for concurrent calls.
_model_lock = threading.Lock()


# NOTE: recent FastAPI versions deprecate on_event in favour of lifespan
# handlers; it is kept here so the app still runs on older versions.
@app.on_event("startup")
async def load_model():
    """Load the GGUF model from the local model directory at startup.

    Raises:
        RuntimeError: If the model file is missing.
        Exception: Any loader error is logged and re-raised so the app
            does not start without a model.
    """
    global model

    model_file_path = os.path.join(MODEL_PATH, MODEL_FILE)

    try:
        # Checking the file also covers a missing model directory.
        if not os.path.isfile(model_file_path):
            raise RuntimeError(
                "Model files not found. Ensure the model was downloaded in the Docker build."
            )

        logger.info("Loading model from local path: %s", MODEL_PATH)

        # Load the model from the local directory downloaded during the Docker build
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_PATH,  # Load from the local folder
            model_file=MODEL_FILE,  # Specify the GGUF file name
            model_type="llama",
            gpu_layers=0,
            context_length=2048,
            threads=os.cpu_count() or 1,
        )
        logger.info("Model loaded successfully!")
    except Exception as e:
        logger.error("Failed to load model: %s", e)
        # Re-raising prevents the app from starting if the model fails to load.
        raise


def _run_inference(prompt: str, request: TextRequest) -> str:
    """Run one blocking generation call while holding the model lock."""
    with _model_lock:
        return model(
            prompt,
            max_new_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            top_k=request.top_k,
            repetition_penalty=request.repeat_penalty,
            stop=request.stop or [],
        )


@app.post("/generate", response_model=TextResponse)
async def generate_text(request: TextRequest):
    """Generate a completion for the supplied prompt.

    Responds with 503 while the model is not loaded and with 500 if
    generation fails.
    """
    if model is None:
        raise HTTPException(
            status_code=503,
            detail=(
                "Model is not ready or failed to load. "
                "Please try again later."
            ),
        )

    try:
        # Create prompt
        if request.system_prompt:
            full_prompt = f"{request.system_prompt}\n\nUser: {request.inputs}\nAssistant:"
        else:
            full_prompt = request.inputs

        # Generation is a long blocking call; running it in the threadpool
        # keeps the event loop (and /health) responsive in the meantime.
        generated_text = await run_in_threadpool(_run_inference, full_prompt, request)

        # Clean up the response: keep only the text after the last
        # "Assistant:" marker, if the model emitted one.
        if "Assistant:" in generated_text:
            generated_text = generated_text.split("Assistant:")[-1].strip()

        return TextResponse(generated_text=generated_text)

    except Exception as e:
        # logger.exception records the traceback, which would otherwise be
        # lost when the error is converted into an HTTP response.
        logger.exception("Generation error: %s", e)
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}") from e


@app.get("/health")
async def health_check():
    """Report service status and whether the model is loaded."""
    # A failure in load_model stops the app from starting, so a response
    # from this endpoint implies the startup hook did not raise.
    return {"status": "healthy", "model_loaded": model is not None}


@app.get("/")
async def root():
    """Return the API name and the location of the interactive docs."""
    return {"message": "Gema 4B Model API", "docs": "/docs"}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")