from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from ctransformers import AutoModelForCausalLM
import os
import uvicorn
from typing import Optional, List
import logging
from contextlib import asynccontextmanager

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global model variable
model = None

# Lifespan manager to load the model on startup
@asynccontextmanager
async def lifespan(app: FastAPI):
    # This code runs on startup
    global model
    model_path = "./model"
    model_file = "gema-4b-indra10k-model1-q4_k_m.gguf"
    try:
        if not os.path.exists(os.path.join(model_path, model_file)):
            raise RuntimeError("Model files not found. Ensure the model was downloaded in the Docker build.")
        logger.info(f"Loading model from local path: {model_path}")
        # FIX: Changed model_type from "llama" to "gemma"
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            model_file=model_file,
            model_type="gemma",  # This was the main cause of the error
            gpu_layers=0,
            context_length=2048,
            threads=os.cpu_count() or 1,
        )
        logger.info("Model loaded successfully!")
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        # Re-raising during startup prevents the app from starting with no model
        raise
    yield
    # This code runs on shutdown (optional)
    logger.info("Application is shutting down.")

app = FastAPI(title="Gema 4B Model API", version="1.0.0", lifespan=lifespan)


# Request model
class TextRequest(BaseModel):
    inputs: str
    system_prompt: Optional[str] = None
    max_tokens: Optional[int] = 512
    temperature: Optional[float] = 0.7
    top_k: Optional[int] = 50
    top_p: Optional[float] = 0.9
    repeat_penalty: Optional[float] = 1.1
    stop: Optional[List[str]] = None

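# Illustrative request body for POST /generate (example values only; every
# field except "inputs" is optional and falls back to the defaults above):
# {
#     "inputs": "What is the capital of Thailand?",
#     "system_prompt": "You are a helpful assistant.",
#     "max_tokens": 256,
#     "stop": ["User:"]
# }
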
# Response model
class TextResponse(BaseModel):
    generated_text: str

@app.post("/generate", response_model=TextResponse)
async def generate_text(request: TextRequest):
    if model is None:
        raise HTTPException(status_code=503, detail="Model is not ready or failed to load. Please check logs.")
    try:
        # Wrap the raw input in a simple User/Assistant template when a system prompt is given
        if request.system_prompt:
            full_prompt = f"{request.system_prompt}\n\nUser: {request.inputs}\nAssistant:"
        else:
            full_prompt = request.inputs
        generated_text = model(
            full_prompt,
            max_new_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            top_k=request.top_k,
            repetition_penalty=request.repeat_penalty,
            stop=request.stop or [],
        )
        # Strip any echoed "Assistant:" prefix so only the reply is returned
        if "Assistant:" in generated_text:
            generated_text = generated_text.split("Assistant:")[-1].strip()
        return TextResponse(generated_text=generated_text)
    except Exception as e:
        logger.error(f"Generation error: {e}")
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")

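# Quick smoke test once the server is up (assumes the default port 8000 used
# below; adjust host/port for your deployment):
#   curl -X POST http://localhost:8000/generate \
#     -H "Content-Type: application/json" \
#     -d '{"inputs": "Hello, who are you?"}'
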
@app.get("/health")
async def health_check():
return {"status": "healthy", "model_loaded": model is not None}
@app.get("/")
async def root():
return {"message": "Gema 4B Model API", "docs": "/docs"}
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")
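
# Alternatively, start the server with the uvicorn CLI (assuming this file is
# saved as main.py):
#   uvicorn main:app --host 0.0.0.0 --port 8000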