Dnfs committed · verified
Commit 0504f7a · 1 Parent(s): b24565b

Update app.py

Files changed (1): app.py (+23 −25)
app.py CHANGED
@@ -1,6 +1,6 @@
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
-from ctransformers import AutoModelForCausalLM
+from llama_cpp import Llama
 import os
 import uvicorn
 from typing import Optional, List
@@ -17,34 +17,30 @@ model = None
 # Lifespan manager to load the model on startup
 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    # This code runs on startup
     global model
-    model_path = "./model"
-    model_file = "gema-4b-indra10k-model1-q4_k_m.gguf"
+    model_gguf_path = os.path.join("./model", "gema-4b-indra10k-model1-q4_k_m.gguf")

     try:
-        if not os.path.exists(model_path) or not os.path.exists(os.path.join(model_path, model_file)):
-            raise RuntimeError("Model files not found. Ensure the model was downloaded in the Docker build.")
+        if not os.path.exists(model_gguf_path):
+            raise RuntimeError(f"Model file not found at: {model_gguf_path}")

-        logger.info(f"Loading model from local path: {model_path}")
+        logger.info(f"Loading model from: {model_gguf_path}")

-        # FIX: Changed model_type from "llama" to "gemma"
-        model = AutoModelForCausalLM.from_pretrained(
-            model_path,
-            model_file=model_file,
-            model_type="gemma",  # This was the main cause of the error
-            gpu_layers=0,
-            context_length=2048,
-            threads=os.cpu_count() or 1
+        # Load the model using llama-cpp-python
+        model = Llama(
+            model_path=model_gguf_path,
+            n_ctx=2048,          # Context length
+            n_gpu_layers=0,      # Set to a positive number if GPU is available
+            n_threads=os.cpu_count() or 1,
+            verbose=True,
         )
-        logger.info("Model loaded successfully!")
+        logger.info("Model loaded successfully using llama-cpp-python!")
     except Exception as e:
         logger.error(f"Failed to load model: {e}")
-        # Raising an exception during startup will prevent the app from starting
         raise e

     yield
-    # This code runs on shutdown (optional)
+    # Cleanup code if needed on shutdown
     logger.info("Application is shutting down.")


@@ -70,26 +66,28 @@ class TextResponse(BaseModel)
 @app.post("/generate", response_model=TextResponse)
 async def generate_text(request: TextRequest):
     if model is None:
-        raise HTTPException(status_code=503, detail="Model is not ready or failed to load. Please check logs.")
+        raise HTTPException(status_code=503, detail="Model is not ready or failed to load.")

     try:
+        # Create prompt
         if request.system_prompt:
             full_prompt = f"{request.system_prompt}\n\nUser: {request.inputs}\nAssistant:"
         else:
             full_prompt = request.inputs

-        generated_text = model(
-            full_prompt,
-            max_new_tokens=request.max_tokens,
+        # Generate text using llama-cpp-python syntax
+        output = model(
+            prompt=full_prompt,
+            max_tokens=request.max_tokens,
             temperature=request.temperature,
             top_p=request.top_p,
             top_k=request.top_k,
-            repetition_penalty=request.repeat_penalty,
+            repeat_penalty=request.repeat_penalty,
             stop=request.stop or []
         )

-        if "Assistant:" in generated_text:
-            generated_text = generated_text.split("Assistant:")[-1].strip()
+        # Extract the generated text from the response structure
+        generated_text = output['choices'][0]['text'].strip()

         return TextResponse(generated_text=generated_text)
93