|
import logging
import os
import threading
from contextlib import asynccontextmanager
from typing import List, Optional

import uvicorn
from ctransformers import AutoModelForCausalLM
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, Field
|
|
|
|
|
# Root logging at INFO so the startup/shutdown and error messages are visible.
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)

# Shared model handle; stays None until lifespan() assigns the loaded model.
model = None
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the GGUF model at startup and release it at shutdown.

    The model directory and file name default to the values the Docker image
    ships with, and may be overridden with the ``MODEL_PATH`` and
    ``MODEL_FILE`` environment variables.

    Raises:
        RuntimeError: the model file does not exist.
        Exception: anything raised by ``AutoModelForCausalLM.from_pretrained``.
            Every failure is logged with its traceback and re-raised, which
            aborts application startup.
    """
    global model

    model_path = os.environ.get("MODEL_PATH", "./model")
    model_file = os.environ.get("MODEL_FILE", "gema-4b-indra10k-model1-q4_k_m.gguf")

    try:
        # isfile() on the joined path also implies the directory exists, and
        # rejects a directory that merely happens to carry the file's name.
        if not os.path.isfile(os.path.join(model_path, model_file)):
            raise RuntimeError("Model files not found. Ensure the model was downloaded in the Docker build.")

        logger.info("Loading model from local path: %s", model_path)

        # Loading is synchronous; that is acceptable here because the server
        # does not accept requests until this context manager yields.
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            model_file=model_file,
            model_type="gemma",
            gpu_layers=0,  # CPU-only inference
            context_length=2048,
            threads=os.cpu_count() or 1,  # cpu_count() may return None
        )
        logger.info("Model loaded successfully!")
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception("Failed to load model: %s", e)
        raise

    try:
        yield
    finally:
        logger.info("Application is shutting down.")
        # Drop the global reference so the model's memory can be reclaimed.
        model = None
|
|
|
|
|
# Application instance; model loading and teardown are handled by lifespan().
app = FastAPI(
    title="Gema 4B Model API",
    version="1.0.0",
    lifespan=lifespan,
)
|
|
|
|
|
|
|
class TextRequest(BaseModel):
    """Request body for ``POST /generate``.

    All sampling fields are optional and have defaults. An explicit ``null``
    is accepted and forwarded to the model call unchanged. Numeric values are
    range-checked so that nonsensical or abusive settings are rejected with a
    422 before they reach the model.
    """

    # Raw user text; wrapped in a chat template only when system_prompt is set.
    inputs: str
    system_prompt: Optional[str] = None
    # Upper bound matches the context_length=2048 the model is loaded with;
    # asking for more new tokens than the whole context window is never useful
    # and would otherwise let one request monopolise the CPU.
    max_tokens: Optional[int] = Field(512, ge=1, le=2048)
    temperature: Optional[float] = Field(0.7, ge=0.0)
    top_k: Optional[int] = Field(50, ge=0)
    # top_p is a cumulative probability mass, so it must lie in [0, 1].
    top_p: Optional[float] = Field(0.9, ge=0.0, le=1.0)
    # The penalty divides/multiplies logits, so zero or negative values are invalid.
    repeat_penalty: Optional[float] = Field(1.1, gt=0.0)
    stop: Optional[List[str]] = None
|
|
|
|
|
class TextResponse(BaseModel):
    """Response body for ``POST /generate``: the post-processed model output."""

    generated_text: str
|
|
|
|
|
# The original handler ran inference directly on the event loop, which
# implicitly allowed only one generation at a time. Generation now runs in a
# worker thread, so this lock preserves that one-at-a-time behaviour rather
# than assuming a single model instance is safe to call from several threads.
_generation_lock = threading.Lock()


def _run_generation(llm, prompt: str, request: TextRequest) -> str:
    """Call *llm* on *prompt* with the request's sampling settings (blocking).

    Intended to run in a worker thread; holds ``_generation_lock`` for the
    whole call.
    """
    with _generation_lock:
        return llm(
            prompt,
            max_new_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            top_k=request.top_k,
            repetition_penalty=request.repeat_penalty,
            stop=request.stop or [],
        )


@app.post("/generate", response_model=TextResponse)
async def generate_text(request: TextRequest):
    """Generate a completion for ``request.inputs``.

    If ``system_prompt`` is set, the input is wrapped in a simple
    ``User:`` / ``Assistant:`` template; otherwise ``inputs`` is sent verbatim.
    The blocking model call is run in a worker thread, so the event loop (and
    therefore ``/health``) stays responsive while text is being generated.

    Raises:
        HTTPException: 503 if the model is not loaded; 500 if the model call
            raises (the error is logged with its traceback).
    """
    # Take a local reference so the None-check and the call use the same object.
    llm = model
    if llm is None:
        raise HTTPException(status_code=503, detail="Model is not ready or failed to load. Please check logs.")

    if request.system_prompt:
        full_prompt = f"{request.system_prompt}\n\nUser: {request.inputs}\nAssistant:"
    else:
        full_prompt = request.inputs

    try:
        generated_text = await run_in_threadpool(_run_generation, llm, full_prompt, request)
    except Exception as e:
        logger.exception("Generation error: %s", e)
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}") from e

    # NOTE(review): this keeps only the text after the LAST "Assistant:" marker,
    # for raw prompts as well as templated ones. If the model goes on to emit
    # further turns, the text before that marker is discarded — confirm intent.
    if "Assistant:" in generated_text:
        generated_text = generated_text.split("Assistant:")[-1].strip()

    return TextResponse(generated_text=generated_text)
|
|
|
@app.get("/health")
async def health_check():
    """Report liveness and whether the shared model handle is currently set."""
    loaded = model is not None
    return {"status": "healthy", "model_loaded": loaded}
|
|
|
@app.get("/")
async def root():
    """Identify the service and point callers at the interactive API docs."""
    info = dict(message="Gema 4B Model API", docs="/docs")
    return info
|
|
|
if __name__ == "__main__":
    # Defaults match the original hard-coded bind address and port. The HOST
    # and PORT environment variables override them, e.g. when the hosting
    # platform assigns the port.
    uvicorn.run(
        app,
        host=os.environ.get("HOST", "0.0.0.0"),
        port=int(os.environ.get("PORT", "8000")),
        log_level="info",
    )