Dnfs committed on
Commit b10a1be · verified · 1 Parent(s): 3080fd5

Update app.py

Files changed (1)
  app.py  +22 -11
app.py CHANGED
@@ -10,9 +10,9 @@ import logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-app = FastAPI(title="llm-apiku", version="1.0.0")
+app = FastAPI(title="Gema 4B Model API", version="1.0.0")
 
-# Request model - fleksibel untuk menerima semua parameter
+# Request model
 class TextRequest(BaseModel):
     inputs: str
     system_prompt: Optional[str] = None
@@ -33,34 +33,43 @@ model = None
 @app.on_event("startup")
 async def load_model():
     global model
+    # Define the local model path
+    model_path = "./model"
+    model_file = "gema-4b-indra10k-model1-q4_k_m.gguf"
+
     try:
-        logger.info("Loading model...")
+        if not os.path.exists(model_path) or not os.path.exists(os.path.join(model_path, model_file)):
+            raise RuntimeError("Model files not found. Ensure the model was downloaded in the Docker build.")
+
+        logger.info(f"Loading model from local path: {model_path}")
+        # Load the model from the local directory downloaded during the Docker build
         model = AutoModelForCausalLM.from_pretrained(
-            "Dnfs/gema-4b-indra10k-model1-Q4_K_M-GGUF",
-            model_file="gema-4b-indra10k-model1-q4_k_m.gguf",
+            model_path,  # Load from the local folder
+            model_file=model_file,  # Specify the GGUF file name
             model_type="llama",
-            gpu_layers=0,  # Set to appropriate number if using GPU
+            gpu_layers=0,
             context_length=2048,
-            threads=os.cpu_count()
+            threads=os.cpu_count() or 1
         )
         logger.info("Model loaded successfully!")
     except Exception as e:
         logger.error(f"Failed to load model: {e}")
+        # Raising the exception will prevent the app from starting if the model fails to load
         raise e
 
 @app.post("/generate", response_model=TextResponse)
 async def generate_text(request: TextRequest):
     if model is None:
-        raise HTTPException(status_code=500, detail="Model not loaded")
+        raise HTTPException(status_code=503, detail="Model is not ready or failed to load. Please try again later.")
 
     try:
-        # Buat prompt - gunakan system_prompt jika ada, atau langsung input user
+        # Create prompt
         if request.system_prompt:
             full_prompt = f"{request.system_prompt}\n\nUser: {request.inputs}\nAssistant:"
         else:
             full_prompt = request.inputs
 
-        # Generate text dengan parameter dari request
+        # Generate text with parameters from the request
         generated_text = model(
             full_prompt,
             max_new_tokens=request.max_tokens,
@@ -71,7 +80,7 @@ async def generate_text(request: TextRequest):
             stop=request.stop or []
         )
 
-        # Bersihkan response dari system prompt jika ada
+        # Clean up the response
         if "Assistant:" in generated_text:
             generated_text = generated_text.split("Assistant:")[-1].strip()
 
@@ -83,6 +92,8 @@ async def generate_text(request: TextRequest):
 
 @app.get("/health")
 async def health_check():
+    # The health check now also implicitly checks if the model has been loaded
+    # because a failure in load_model will stop the app from running.
     return {"status": "healthy", "model_loaded": model is not None}
 
 @app.get("/")
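The new startup path expects the GGUF file to already sit at ./model/gema-4b-indra10k-model1-q4_k_m.gguf when the container starts; the commit's comments say it is fetched during the Docker build. A minimal sketch of such a build-time download step, assuming the file still comes from the Dnfs/gema-4b-indra10k-model1-Q4_K_M-GGUF repo referenced in the removed line and that huggingface_hub is available in the build image (the helper script and its place in the Dockerfile are assumptions, not part of this commit):

    # download_model.py - hypothetical build-time helper (not part of this commit)
    from huggingface_hub import hf_hub_download

    hf_hub_download(
        repo_id="Dnfs/gema-4b-indra10k-model1-Q4_K_M-GGUF",  # assumed source repo, taken from the removed line
        filename="gema-4b-indra10k-model1-q4_k_m.gguf",      # file name app.py expects
        local_dir="./model",                                 # matches model_path in app.py
    )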
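The prompt template and the "Assistant:" cleanup are unchanged by this commit; for clarity, this is what they do for a request that carries a system prompt (the strings are illustrative only):

    # Illustration of the prompt template and cleanup used in generate_text
    system_prompt = "You are a helpful assistant."
    inputs = "What is a GGUF file?"

    full_prompt = f"{system_prompt}\n\nUser: {inputs}\nAssistant:"
    # -> "You are a helpful assistant.\n\nUser: What is a GGUF file?\nAssistant:"

    # If the raw output were to echo the template, the split keeps only the
    # text after the last "Assistant:" marker.
    raw_output = full_prompt + " A GGUF file is a quantized model container."
    answer = raw_output.split("Assistant:")[-1].strip()
    # -> "A GGUF file is a quantized model container."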
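For reference, a call to the updated /generate endpoint built only from the fields this diff shows being read (inputs, system_prompt, max_tokens, stop); the host, port, and the exact fields of TextResponse are assumptions, since they do not appear in these hunks:

    # Hypothetical client call; URL and response shape are assumptions
    import requests

    resp = requests.post(
        "http://localhost:8000/generate",  # assumed host/port for the FastAPI app
        json={
            "inputs": "Summarize what a GGUF model file is.",
            "system_prompt": "You are a helpful assistant.",
            "max_tokens": 128,    # forwarded as max_new_tokens
            "stop": ["\nUser:"],  # forwarded to the model's stop list
        },
        timeout=120,
    )
    resp.raise_for_status()  # a 503 here means the model is not loaded yet
    print(resp.json())       # TextResponse body; its fields are not shown in this diff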
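Because the not-loaded case now returns 503 instead of 500 and /health reports model_loaded, a caller can wait for readiness before sending work. A sketch under the same assumed base URL and an arbitrary polling interval:

    # Hypothetical readiness probe built on the /health endpoint
    import time
    import requests

    def wait_until_ready(base_url: str = "http://localhost:8000", retries: int = 30) -> bool:
        for _ in range(retries):
            try:
                health = requests.get(f"{base_url}/health", timeout=5).json()
                if health.get("model_loaded"):
                    return True
            except requests.RequestException:
                pass  # the server may not be accepting connections yet
            time.sleep(2)
        return False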