Update app.py
app.py CHANGED
```diff
@@ -1,6 +1,6 @@
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
-from
+from llama_cpp import Llama
 import os
 import uvicorn
 from typing import Optional, List
```
```diff
@@ -17,34 +17,30 @@ model = None
 # Lifespan manager to load the model on startup
 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    # This code runs on startup
     global model
-
-    model_file = "gema-4b-indra10k-model1-q4_k_m.gguf"
+    model_gguf_path = os.path.join("./model", "gema-4b-indra10k-model1-q4_k_m.gguf")

     try:
-        if not os.path.exists(
-
+        if not os.path.exists(model_gguf_path):
+            raise RuntimeError(f"Model file not found at: {model_gguf_path}")

-        logger.info(f"Loading model from
+        logger.info(f"Loading model from: {model_gguf_path}")

-        #
-        model =
-            model_path,
-
-
-
-
-            threads=os.cpu_count() or 1
+        # Load the model using llama-cpp-python
+        model = Llama(
+            model_path=model_gguf_path,
+            n_ctx=2048,       # Context length
+            n_gpu_layers=0,   # Set to a positive number if GPU is available
+            n_threads=os.cpu_count() or 1,
+            verbose=True,
         )
-        logger.info("Model loaded successfully!")
+        logger.info("Model loaded successfully using llama-cpp-python!")
     except Exception as e:
         logger.error(f"Failed to load model: {e}")
-        # Raising an exception during startup will prevent the app from starting
         raise e

     yield
-    #
+    # Cleanup code if needed on shutdown
     logger.info("Application is shutting down.")
```
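This hunk only touches the lifespan body. The `logger`, the `model = None` global visible in the hunk header, and the app wiring live in unchanged lines that the diff does not display. A minimal sketch of that surrounding scaffolding, assuming standard `logging` and that the lifespan is passed to `FastAPI`; it is a plausible reconstruction, not the file's actual lines:

```python
import logging
from contextlib import asynccontextmanager

from fastapi import FastAPI

# Module-level handle filled in at startup; matches the `model = None`
# context shown in the hunk header above.
model = None

# Assumed logger setup -- the diff only shows that a `logger` exists.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@asynccontextmanager
async def lifespan(app: FastAPI):
    # The diff's startup code (locating and loading the GGUF file) goes here.
    global model
    yield
    logger.info("Application is shutting down.")

# The lifespan hook is registered when the app is constructed.
app = FastAPI(lifespan=lifespan)
```

Loading inside the lifespan rather than at import time lets a missing or corrupt GGUF file abort startup cleanly, which is what the `raise e` in the except block achieves.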
```diff
@@ -70,26 +66,28 @@ class TextResponse(BaseModel):
 @app.post("/generate", response_model=TextResponse)
 async def generate_text(request: TextRequest):
     if model is None:
-        raise HTTPException(status_code=503, detail="Model is not ready or failed to load.
+        raise HTTPException(status_code=503, detail="Model is not ready or failed to load.")

     try:
+        # Create prompt
         if request.system_prompt:
             full_prompt = f"{request.system_prompt}\n\nUser: {request.inputs}\nAssistant:"
         else:
             full_prompt = request.inputs

-
-
-
+        # Generate text using llama-cpp-python syntax
+        output = model(
+            prompt=full_prompt,
+            max_tokens=request.max_tokens,
             temperature=request.temperature,
             top_p=request.top_p,
             top_k=request.top_k,
-
+            repeat_penalty=request.repeat_penalty,
             stop=request.stop or []
         )

-
-
+        # Extract the generated text from the response structure
+        generated_text = output['choices'][0]['text'].strip()

         return TextResponse(generated_text=generated_text)
```
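The hunk header shows `class TextResponse(BaseModel):` as context, so the Pydantic models themselves are unchanged and not displayed. Every field the endpoint reads is visible in the diff, though; a sketch consistent with those accesses follows, where the defaults are illustrative guesses rather than the repository's actual values:

```python
from typing import List, Optional

from pydantic import BaseModel

class TextRequest(BaseModel):
    # Field names come from the `request.` accesses in the diff;
    # the defaults here are assumptions, not the file's real values.
    inputs: str
    system_prompt: Optional[str] = None
    max_tokens: int = 256
    temperature: float = 0.7
    top_p: float = 0.95
    top_k: int = 40
    repeat_penalty: float = 1.1
    stop: Optional[List[str]] = None

class TextResponse(BaseModel):
    generated_text: str
```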
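Once the GGUF file is in `./model/`, the changed endpoint can be exercised end to end. The host, port, and payload values below are illustrative, and the uvicorn invocation is not part of the displayed hunks:

```python
# Start the server first, e.g.: uvicorn app:app --host 0.0.0.0 --port 8000
import requests

resp = requests.post(
    "http://localhost:8000/generate",
    json={
        "inputs": "What is the capital of Thailand?",
        "system_prompt": "You are a helpful assistant.",
        "max_tokens": 128,
    },
)
resp.raise_for_status()
# The endpoint returns {"generated_text": "..."} per TextResponse.
print(resp.json()["generated_text"])
```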