# llm-apiku / app.py
# Last change: "Update app.py" by Dnfs (commit 0504f7a, verified), 3.31 kB.
import asyncio
import logging
import os
import threading
from contextlib import asynccontextmanager
from typing import Optional, List

import uvicorn
from fastapi import FastAPI, HTTPException
from llama_cpp import Llama
from pydantic import BaseModel
# Global handle to the llama-cpp model. It is None until the lifespan
# handler successfully loads it at startup.
model = None

# This module's logger, plus root logging configured at INFO level.
logger = logging.getLogger(name=__name__)
logging.basicConfig(level=logging.INFO)
# Lifespan manager: load the model on startup and release it on shutdown.
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the GGUF model before serving requests and drop it on shutdown.

    The loaded ``Llama`` instance is stored in the module-level ``model``
    global, which the endpoints read.

    Args:
        app: The FastAPI application. It is required by the lifespan
            protocol and unused here.

    Raises:
        RuntimeError: If the model file does not exist.
        Exception: Any error raised while constructing ``Llama`` is logged
            with its traceback and re-raised, which aborts startup.
    """
    global model
    model_gguf_path = os.path.join("./model", "gema-4b-indra10k-model1-q4_k_m.gguf")
    try:
        if not os.path.exists(model_gguf_path):
            raise RuntimeError(f"Model file not found at: {model_gguf_path}")
        logger.info("Loading model from: %s", model_gguf_path)
        # Load the model using llama-cpp-python
        model = Llama(
            model_path=model_gguf_path,
            n_ctx=2048,  # Context length
            n_gpu_layers=0,  # Set to a positive number if GPU is available
            n_threads=os.cpu_count() or 1,
            verbose=True,
        )
        logger.info("Model loaded successfully using llama-cpp-python!")
    except Exception as e:
        # logger.exception records the traceback. A bare `raise` re-raises the
        # original exception untouched, so startup still fails loudly.
        logger.exception("Failed to load model: %s", e)
        raise
    try:
        yield
    finally:
        # Shutdown (normal or via cancellation): log, then drop the global
        # reference so the model can be garbage-collected and /health stops
        # reporting it as loaded.
        logger.info("Application is shutting down.")
        model = None
# The ASGI application. `lifespan` runs the model load before any request is served.
app = FastAPI(
    lifespan=lifespan,
    title="Gema 4B Model API",
    version="1.0.0",
)
# Request model
class TextRequest(BaseModel):
    """Body of a POST /generate request.

    inputs is the user text. When system_prompt is non-empty, the endpoint
    places it before a User:/Assistant: template wrapping inputs. Otherwise
    inputs is sent to the model verbatim. The remaining fields are passed
    to the llama-cpp-python completion call.
    """
    inputs: str
    system_prompt: Optional[str] = None
    # Passed to the model call as max_tokens.
    max_tokens: Optional[int] = 512
    # Sampling parameters, passed to the model call under the same names.
    # NOTE(review): Optional lets clients send an explicit JSON null here;
    # confirm the generation call handles None values.
    temperature: Optional[float] = 0.7
    top_k: Optional[int] = 50
    top_p: Optional[float] = 0.9
    repeat_penalty: Optional[float] = 1.1
    # Stop strings for generation. None is sent to the model as an empty list.
    stop: Optional[List[str]] = None
# Response model
class TextResponse(BaseModel):
    """Body returned by POST /generate."""
    # Text of the first completion choice, with surrounding whitespace stripped.
    generated_text: str
# Serializes calls into the shared Llama instance. Generation now runs on
# worker threads, and the original async handler only ever ran one model
# call at a time on the event loop. Keep that guarantee: the Llama object
# is not documented as thread-safe.
_model_lock = threading.Lock()


@app.post("/generate", response_model=TextResponse)
async def generate_text(request: TextRequest):
    """Generate a completion for the prompt described by ``request``.

    If ``request.system_prompt`` is non-empty, the prompt is the system
    prompt, a blank line, then ``User: <inputs>`` and a final ``Assistant:``
    line. Otherwise ``request.inputs`` is used verbatim.

    Raises:
        HTTPException: 503 if the model is not loaded; 500 wrapping any error
            raised during generation or while reading the model output.
    """
    # Snapshot the global so a concurrent shutdown cannot swap it mid-request.
    llm = model
    if llm is None:
        raise HTTPException(status_code=503, detail="Model is not ready or failed to load.")

    # Create prompt
    if request.system_prompt:
        full_prompt = f"{request.system_prompt}\n\nUser: {request.inputs}\nAssistant:"
    else:
        full_prompt = request.inputs

    # The Optional sampling fields may arrive as explicit JSON nulls.
    # llama-cpp-python types these parameters as plain numbers, so drop None
    # values and let the library's own defaults apply instead of failing.
    sampling = {
        name: value
        for name, value in (
            ("temperature", request.temperature),
            ("top_p", request.top_p),
            ("top_k", request.top_k),
            ("repeat_penalty", request.repeat_penalty),
        )
        if value is not None
    }

    def _run_generation():
        """Run one blocking completion call while holding the model lock."""
        with _model_lock:
            return llm(
                prompt=full_prompt,
                max_tokens=request.max_tokens,
                stop=request.stop or [],
                **sampling,
            )

    try:
        # llama.cpp inference is CPU-bound and blocking. Running it in the
        # default executor keeps the event loop (and /health) responsive.
        loop = asyncio.get_running_loop()
        output = await loop.run_in_executor(None, _run_generation)
        # Extract the generated text from the response structure
        generated_text = output['choices'][0]['text'].strip()
    except Exception as e:
        logger.exception("Generation error: %s", e)
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}") from e
    return TextResponse(generated_text=generated_text)
@app.get("/health")
async def health_check():
    """Liveness endpoint that also reports whether the model is loaded.

    The status is always "healthy"; use model_loaded to check readiness.
    """
    loaded = model is not None
    payload = {"status": "healthy"}
    payload["model_loaded"] = loaded
    return payload
@app.get("/")
async def root():
    """Name the service and point callers to the interactive API docs."""
    info = dict(message="Gema 4B Model API", docs="/docs")
    return info
if __name__ == "__main__":
    # Direct-run entry point: serve the app on all interfaces at port 8000.
    uvicorn.run(
        app,
        log_level="info",
        host="0.0.0.0",
        port=8000,
    )