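"""FastAPI server that exposes a local GGUF model for text generation via llama-cpp-python.

Endpoints:
  POST /generate  - generate text from a prompt
  GET  /health    - report whether the model is loaded
  GET  /          - basic service info
"""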
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from llama_cpp import Llama
import os
import uvicorn
from typing import Optional, List
import logging
from contextlib import asynccontextmanager
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Global model handle, populated by the lifespan manager on startup
model: Optional[Llama] = None
# Lifespan manager to load the model on startup
@asynccontextmanager
async def lifespan(app: FastAPI):
    global model
    model_gguf_path = os.path.join("./model", "gema-4b-indra10k-model1-q4_k_m.gguf")
    try:
        if not os.path.exists(model_gguf_path):
            raise RuntimeError(f"Model file not found at: {model_gguf_path}")
        logger.info(f"Loading model from: {model_gguf_path}")
        # Load the model using llama-cpp-python
        model = Llama(
            model_path=model_gguf_path,
            n_ctx=2048,  # Context length
            n_gpu_layers=0,  # Set to a positive number (or -1 for all layers) if a GPU is available
            n_threads=os.cpu_count() or 1,
            verbose=True,
        )
        logger.info("Model loaded successfully using llama-cpp-python!")
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise
    yield
    # Cleanup code if needed on shutdown
    logger.info("Application is shutting down.")


app = FastAPI(title="Gema 4B Model API", version="1.0.0", lifespan=lifespan)
# Request model
class TextRequest(BaseModel):
    inputs: str
    system_prompt: Optional[str] = None
    max_tokens: Optional[int] = 512
    temperature: Optional[float] = 0.7
    top_k: Optional[int] = 50
    top_p: Optional[float] = 0.9
    repeat_penalty: Optional[float] = 1.1
    stop: Optional[List[str]] = None


# Response model
class TextResponse(BaseModel):
    generated_text: str
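# Example request body for /generate (illustrative values only):
#   {"inputs": "Write a haiku about the sea.", "max_tokens": 64, "temperature": 0.7}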
@app.post("/generate", response_model=TextResponse)
async def generate_text(request: TextRequest):
    if model is None:
        raise HTTPException(status_code=503, detail="Model is not ready or failed to load.")
    try:
        # Build the prompt; an optional system prompt is prepended in a simple
        # User/Assistant format (no model-specific chat template is applied)
        if request.system_prompt:
            full_prompt = f"{request.system_prompt}\n\nUser: {request.inputs}\nAssistant:"
        else:
            full_prompt = request.inputs
        # Generate text using llama-cpp-python
        output = model(
            prompt=full_prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            top_k=request.top_k,
            repeat_penalty=request.repeat_penalty,
            stop=request.stop or [],
        )
        # Extract the generated text from the response structure
        generated_text = output['choices'][0]['text'].strip()
        return TextResponse(generated_text=generated_text)
    except Exception as e:
        logger.error(f"Generation error: {e}")
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
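# Example call (the server listens on 0.0.0.0:8000 when started via __main__ below):
#   curl -X POST http://localhost:8000/generate \
#     -H "Content-Type: application/json" \
#     -d '{"inputs": "Hello!", "max_tokens": 32}'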
@app.get("/health")
async def health_check():
    return {"status": "healthy", "model_loaded": model is not None}


@app.get("/")
async def root():
    return {"message": "Gema 4B Model API", "docs": "/docs"}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")