Dnfs committed
Commit b24565b · verified · 1 Parent(s): 4af78cc

Update app.py

Files changed (1)
  1. app.py +37 -32
app.py CHANGED
@@ -5,35 +5,20 @@ import os
 import uvicorn
 from typing import Optional, List
 import logging
+from contextlib import asynccontextmanager
 
-# Set up loggings
+# Set up logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-app = FastAPI(title="Gema 4B Model API", version="1.0.0")
-
-# Request model
-class TextRequest(BaseModel):
-    inputs: str
-    system_prompt: Optional[str] = None
-    max_tokens: Optional[int] = 512
-    temperature: Optional[float] = 0.7
-    top_k: Optional[int] = 50
-    top_p: Optional[float] = 0.9
-    repeat_penalty: Optional[float] = 1.1
-    stop: Optional[List[str]] = None
-
-# Response model
-class TextResponse(BaseModel):
-    generated_text: str
-
 # Global model variable
 model = None
 
-@app.on_event("startup")
-async def load_model():
+# Lifespan manager to load the model on startup
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    # This code runs on startup
     global model
-    # Define the local model path
     model_path = "./model"
     model_file = "gema-4b-indra10k-model1-q4_k_m.gguf"
 
@@ -42,11 +27,12 @@ async def load_model():
         raise RuntimeError("Model files not found. Ensure the model was downloaded in the Docker build.")
 
     logger.info(f"Loading model from local path: {model_path}")
-    # Load the model from the local directory downloaded during the Docker build
+
+    # FIX: Changed model_type from "llama" to "gemma"
     model = AutoModelForCausalLM.from_pretrained(
-        model_path,             # Load from the local folder
-        model_file=model_file,  # Specify the GGUF file name
-        model_type="llama",
+        model_path,
+        model_file=model_file,
+        model_type="gemma",  # This was the main cause of the error
         gpu_layers=0,
         context_length=2048,
         threads=os.cpu_count() or 1
@@ -54,22 +40,44 @@ async def load_model():
         logger.info("Model loaded successfully!")
     except Exception as e:
         logger.error(f"Failed to load model: {e}")
-        # Raising the exception will prevent the app from starting if the model fails to load
+        # Raising an exception during startup will prevent the app from starting
         raise e
+
+    yield
+    # This code runs on shutdown (optional)
+    logger.info("Application is shutting down.")
+
+
+app = FastAPI(title="Gema 4B Model API", version="1.0.0", lifespan=lifespan)
+
+
+# Request model
+class TextRequest(BaseModel):
+    inputs: str
+    system_prompt: Optional[str] = None
+    max_tokens: Optional[int] = 512
+    temperature: Optional[float] = 0.7
+    top_k: Optional[int] = 50
+    top_p: Optional[float] = 0.9
+    repeat_penalty: Optional[float] = 1.1
+    stop: Optional[List[str]] = None
+
+# Response model
+class TextResponse(BaseModel):
+    generated_text: str
+
 
 @app.post("/generate", response_model=TextResponse)
 async def generate_text(request: TextRequest):
     if model is None:
-        raise HTTPException(status_code=503, detail="Model is not ready or failed to load. Please try again later.")
+        raise HTTPException(status_code=503, detail="Model is not ready or failed to load. Please check logs.")
 
     try:
-        # Create prompt
         if request.system_prompt:
             full_prompt = f"{request.system_prompt}\n\nUser: {request.inputs}\nAssistant:"
         else:
             full_prompt = request.inputs
 
-        # Generate text with parameters from the request
         generated_text = model(
             full_prompt,
             max_new_tokens=request.max_tokens,
@@ -80,7 +88,6 @@ async def generate_text(request: TextRequest):
             stop=request.stop or []
         )
 
-        # Clean up the response
         if "Assistant:" in generated_text:
            generated_text = generated_text.split("Assistant:")[-1].strip()
 
@@ -92,8 +99,6 @@ async def generate_text(request: TextRequest):
 
 @app.get("/health")
 async def health_check():
-    # The health check now also implicitly checks if the model has been loaded
-    # because a failure in load_model will stop the app from running.
     return {"status": "healthy", "model_loaded": model is not None}
 
 @app.get("/")