Update app.py
app.py CHANGED
@@ -10,9 +10,9 @@ import logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-app = FastAPI(title="
+app = FastAPI(title="Gema 4B Model API", version="1.0.0")
 
-# Request model
+# Request model
 class TextRequest(BaseModel):
     inputs: str
     system_prompt: Optional[str] = None
@@ -33,34 +33,43 @@ model = None
 @app.on_event("startup")
 async def load_model():
     global model
+    # Define the local model path
+    model_path = "./model"
+    model_file = "gema-4b-indra10k-model1-q4_k_m.gguf"
+
     try:
-
+        if not os.path.exists(model_path) or not os.path.exists(os.path.join(model_path, model_file)):
+            raise RuntimeError("Model files not found. Ensure the model was downloaded in the Docker build.")
+
+        logger.info(f"Loading model from local path: {model_path}")
+        # Load the model from the local directory downloaded during the Docker build
         model = AutoModelForCausalLM.from_pretrained(
-
-            model_file=
+            model_path,  # Load from the local folder
+            model_file=model_file,  # Specify the GGUF file name
             model_type="llama",
-            gpu_layers=0,
+            gpu_layers=0,
             context_length=2048,
-            threads=os.cpu_count()
+            threads=os.cpu_count() or 1
         )
         logger.info("Model loaded successfully!")
     except Exception as e:
         logger.error(f"Failed to load model: {e}")
+        # Raising the exception will prevent the app from starting if the model fails to load
         raise e
 
 @app.post("/generate", response_model=TextResponse)
 async def generate_text(request: TextRequest):
     if model is None:
-        raise HTTPException(status_code=
+        raise HTTPException(status_code=503, detail="Model is not ready or failed to load. Please try again later.")
 
     try:
-        #
+        # Create prompt
         if request.system_prompt:
             full_prompt = f"{request.system_prompt}\n\nUser: {request.inputs}\nAssistant:"
         else:
             full_prompt = request.inputs
 
-        # Generate text
+        # Generate text with parameters from the request
         generated_text = model(
             full_prompt,
             max_new_tokens=request.max_tokens,
@@ -71,7 +80,7 @@ async def generate_text(request: TextRequest):
             stop=request.stop or []
         )
 
-        #
+        # Clean up the response
         if "Assistant:" in generated_text:
             generated_text = generated_text.split("Assistant:")[-1].strip()
 
@@ -83,6 +92,8 @@ async def generate_text(request: TextRequest):
 
 @app.get("/health")
 async def health_check():
+    # The health check now also implicitly checks if the model has been loaded
+    # because a failure in load_model will stop the app from running.
     return {"status": "healthy", "model_loaded": model is not None}
 
 @app.get("/")
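For reference, a minimal client for the endpoints touched by this change might look like the sketch below. It assumes the container exposes the API at http://localhost:7860 (the host and port are not part of this commit) and only uses request fields visible in the diff (inputs, system_prompt, max_tokens, stop); the exact shape of TextResponse is defined elsewhere in app.py.

# Minimal client sketch for the endpoints above.
# Assumptions (not part of this commit): the API is reachable at
# http://localhost:7860 and the `requests` package is installed.
import requests

BASE_URL = "http://localhost:7860"

# /health reports whether the startup hook finished loading the GGUF model
print(requests.get(f"{BASE_URL}/health", timeout=10).json())

payload = {
    "inputs": "Summarize what this API does in one sentence.",
    "system_prompt": "You are a concise assistant.",  # optional
    "max_tokens": 128,   # forwarded to max_new_tokens in generate_text
    "stop": ["User:"],   # optional stop sequences
}
resp = requests.post(f"{BASE_URL}/generate", json=payload, timeout=120)
resp.raise_for_status()
print(resp.json())  # TextResponse body as defined elsewhere in app.py

A 503 from /generate means the model is not loaded; with the startup changes above, a missing ./model/gema-4b-indra10k-model1-q4_k_m.gguf now fails fast when the app boots rather than surfacing only at request time.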