Update app.py
app.py
CHANGED
@@ -5,35 +5,20 @@ import os
 import uvicorn
 from typing import Optional, List
 import logging
+from contextlib import asynccontextmanager
 
-# Set up
+# Set up logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-app = FastAPI(title="Gema 4B Model API", version="1.0.0")
-
-# Request model
-class TextRequest(BaseModel):
-    inputs: str
-    system_prompt: Optional[str] = None
-    max_tokens: Optional[int] = 512
-    temperature: Optional[float] = 0.7
-    top_k: Optional[int] = 50
-    top_p: Optional[float] = 0.9
-    repeat_penalty: Optional[float] = 1.1
-    stop: Optional[List[str]] = None
-
-# Response model
-class TextResponse(BaseModel):
-    generated_text: str
-
 # Global model variable
 model = None
 
-@app.on_event("startup")
-async def load_model():
+# Lifespan manager to load the model on startup
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    # This code runs on startup
     global model
-    # Define the local model path
     model_path = "./model"
     model_file = "gema-4b-indra10k-model1-q4_k_m.gguf"
 
@@ -42,11 +27,12 @@ async def load_model():
             raise RuntimeError("Model files not found. Ensure the model was downloaded in the Docker build.")
 
         logger.info(f"Loading model from local path: {model_path}")
 
+        # FIX: Changed model_type from "llama" to "gemma"
         model = AutoModelForCausalLM.from_pretrained(
             model_path,
             model_file=model_file,
-            model_type="llama",
+            model_type="gemma",  # This was the main cause of the error
             gpu_layers=0,
             context_length=2048,
             threads=os.cpu_count() or 1
@@ -54,22 +40,44 @@ async def load_model():
         logger.info("Model loaded successfully!")
     except Exception as e:
         logger.error(f"Failed to load model: {e}")
-        # Raising
+        # Raising an exception during startup will prevent the app from starting
         raise e
+
+    yield
+    # This code runs on shutdown (optional)
+    logger.info("Application is shutting down.")
+
+
+app = FastAPI(title="Gema 4B Model API", version="1.0.0", lifespan=lifespan)
+
+
+# Request model
+class TextRequest(BaseModel):
+    inputs: str
+    system_prompt: Optional[str] = None
+    max_tokens: Optional[int] = 512
+    temperature: Optional[float] = 0.7
+    top_k: Optional[int] = 50
+    top_p: Optional[float] = 0.9
+    repeat_penalty: Optional[float] = 1.1
+    stop: Optional[List[str]] = None
+
+# Response model
+class TextResponse(BaseModel):
+    generated_text: str
+
 
 @app.post("/generate", response_model=TextResponse)
 async def generate_text(request: TextRequest):
     if model is None:
         raise HTTPException(status_code=503, detail="Model is not ready or failed to load. Please check logs.")
 
     try:
-        # Create prompt
         if request.system_prompt:
             full_prompt = f"{request.system_prompt}\n\nUser: {request.inputs}\nAssistant:"
         else:
             full_prompt = request.inputs
 
-        # Generate text with parameters from the request
         generated_text = model(
             full_prompt,
             max_new_tokens=request.max_tokens,
@@ -80,7 +88,6 @@ async def generate_text(request: TextRequest):
             stop=request.stop or []
         )
 
-        # Clean up the response
         if "Assistant:" in generated_text:
             generated_text = generated_text.split("Assistant:")[-1].strip()
 
@@ -92,8 +99,6 @@ async def generate_text(request: TextRequest):
 
 @app.get("/health")
 async def health_check():
-    # The health check now also implicitly checks if the model has been loaded
-    # because a failure in load_model will stop the app from running.
     return {"status": "healthy", "model_loaded": model is not None}
 
 @app.get("/")
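The endpoints themselves are unchanged, so the API can be exercised the same way as before. A minimal client sketch for the updated service, assuming it is reachable at http://localhost:8000 (the host, port, and example prompt are placeholders, not part of this commit):

import requests

BASE_URL = "http://localhost:8000"  # assumed local address; adjust for your deployment

# Check that the lifespan hook finished loading the model
print(requests.get(f"{BASE_URL}/health").json())
# expected shape: {"status": "healthy", "model_loaded": true}

# Call /generate with a payload matching the TextRequest schema
payload = {
    "inputs": "Summarize what a lifespan handler does in FastAPI.",
    "system_prompt": "You are a helpful assistant.",
    "max_tokens": 128,
    "temperature": 0.7,
}
resp = requests.post(f"{BASE_URL}/generate", json=payload)
resp.raise_for_status()
print(resp.json()["generated_text"])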
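The move from an @app.on_event("startup") hook to a lifespan context manager follows FastAPI's current recommendation: on_event is deprecated, the code before yield runs at startup, and the code after it runs at shutdown. Because the load error is re-raised, a bad model file now stops the app from starting instead of leaving /generate returning 503 indefinitely. If loading still fails, the from_pretrained call can be debugged outside FastAPI; a sketch assuming the app uses ctransformers (whose AutoModelForCausalLM matches the model_file/model_type/gpu_layers signature above) and that the GGUF file is already in ./model:

from ctransformers import AutoModelForCausalLM

# Reproduce the app's load step in isolation (ctransformers assumed)
llm = AutoModelForCausalLM.from_pretrained(
    "./model",
    model_file="gema-4b-indra10k-model1-q4_k_m.gguf",
    model_type="gemma",
    gpu_layers=0,
    context_length=2048,
)
print(llm("Hello", max_new_tokens=16))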