omaryasserhassan committed
Commit f5558db · verified · 1 Parent(s): 9a4b46f

Update app.py

Files changed (1): app.py (+28 -10)
app.py CHANGED
@@ -5,14 +5,20 @@ from typing import List, Optional
 from huggingface_hub import hf_hub_download
 from ctransformers import AutoModelForCausalLM
 
-REPO_ID = "bartowski/Llama-3.2-3B-Instruct-GGUF"
-FILENAME = "Llama-3.2-3B-Instruct-Q4_K_M.gguf" # <- safer quant
+# --- Model & cache config ---
+REPO_ID = "bartowski/Llama-3.2-3B-Instruct-GGUF"
+FILENAME = "Llama-3.2-3B-Instruct-Q4_K_M.gguf" # safer quant than Q4_K_L
 MODEL_TYPE = "llama"
 
-
 CACHE_DIR = os.environ.get("HUGGINGFACE_HUB_CACHE", "/data/hf_cache")
 os.makedirs(CACHE_DIR, exist_ok=True)
 
+# Conservative defaults (override via env if needed)
+CTX_LEN = int(os.environ.get("CTX_LEN", "1024"))
+BATCH_SIZE = int(os.environ.get("BATCH", "16"))
+THREADS = int(os.environ.get("THREADS", "4"))
+GPU_LAYERS = int(os.environ.get("GPU_LAYERS", "0"))
+
 app = FastAPI(title="Llama 3.2 3B Instruct (ctransformers)")
 _model = None
 
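The constants above feed the model download step inside get_model(), which this diff only shows in part. A minimal sketch of that step, assuming the GGUF file is fetched with hf_hub_download into the shared cache (the resulting local_path is the variable referenced in the next hunk):

from huggingface_hub import hf_hub_download

# Assumed resolution step (not shown in the diff): download the quantized
# model file once and reuse it from CACHE_DIR on later startups.
local_path = hf_hub_download(
    repo_id=REPO_ID,      # "bartowski/Llama-3.2-3B-Instruct-GGUF"
    filename=FILENAME,    # "Llama-3.2-3B-Instruct-Q4_K_M.gguf"
    cache_dir=CACHE_DIR,  # "/data/hf_cache" unless HUGGINGFACE_HUB_CACHE is set
)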
 
@@ -32,12 +38,11 @@ def get_model():
         model_path_or_repo_id=os.path.dirname(local_path),
         model_file=os.path.basename(local_path),
         model_type=MODEL_TYPE,
-        context_length=int(os.environ.get("CTX_LEN", "4096")),
-        CTX_LEN = int(os.environ.get("CTX_LEN", "1024"))
-        BATCH = int(os.environ.get("BATCH", "16"))
-        THREADS = int(os.environ.get("THREADS", "4"))
-        GPU_LAY = int(os.environ.get("GPU_LAYERS","0"))
-
+        context_length=CTX_LEN,
+        batch_size=BATCH_SIZE,
+        threads=THREADS,
+        gpu_layers=GPU_LAYERS,
+        f16_kv=True,
     )
     return _model
 
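This hunk removes the stray module-level assignments that had been pasted inside the from_pretrained() call (a syntax error in the old version, since consecutive keyword arguments lacked commas) and wires the call to the new module constants. For context, a sketch of the lazy-load pattern the hunk sits inside; only the keyword arguments are confirmed by the diff, the surrounding function body is an assumption:

import os
from huggingface_hub import hf_hub_download
from ctransformers import AutoModelForCausalLM

_model = None

def get_model():
    # Load the GGUF model once on first use, then reuse it across requests.
    global _model
    if _model is None:
        local_path = hf_hub_download(
            repo_id=REPO_ID, filename=FILENAME, cache_dir=CACHE_DIR
        )
        _model = AutoModelForCausalLM.from_pretrained(
            model_path_or_repo_id=os.path.dirname(local_path),
            model_file=os.path.basename(local_path),
            model_type=MODEL_TYPE,
            context_length=CTX_LEN,  # 1024 by default after this commit
            batch_size=BATCH_SIZE,
            threads=THREADS,
            gpu_layers=GPU_LAYERS,   # 0 = CPU-only
        )
    return _model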
 
@@ -55,7 +60,17 @@ class GenerateOut(BaseModel):
 
 @app.get("/")
 def health():
-    return {"status": "ok", "cache_dir": CACHE_DIR}
+    return {
+        "status": "ok",
+        "cache_dir": CACHE_DIR,
+        "model": {"repo": REPO_ID, "file": FILENAME, "type": MODEL_TYPE},
+        "settings": {
+            "CTX_LEN": CTX_LEN,
+            "BATCH_SIZE": BATCH_SIZE,
+            "THREADS": THREADS,
+            "GPU_LAYERS": GPU_LAYERS,
+        },
+    }
 
 @app.post("/generate", response_model=GenerateOut)
 def generate(body: GenerateIn):
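The expanded health endpoint now reports the model identity and runtime settings, so a deployment can be verified at a glance. A quick local check using FastAPI's test client (requires the httpx package; the response shape is taken from the hunk above):

from fastapi.testclient import TestClient

client = TestClient(app)  # `app` is the FastAPI instance defined in app.py
resp = client.get("/")
assert resp.status_code == 200
print(resp.json()["settings"])  # e.g. {"CTX_LEN": 1024, "BATCH_SIZE": 16, ...}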
@@ -72,4 +87,7 @@ def generate(body: GenerateIn):
         )
         return GenerateOut(completion=text)
     except Exception as e:
+        # helpful for debugging in Space logs
+        import sys, traceback
+        traceback.print_exc(file=sys.stderr)
         raise HTTPException(status_code=500, detail=str(e))
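The handler now prints a full traceback to stderr before re-raising as an HTTP 500, so failures surface in the Space logs instead of being reduced to a one-line detail string. For reference, a sketch of the request/response models the endpoint uses; GenerateOut's completion field is confirmed by the diff, while GenerateIn's fields are hypothetical (the file imports List and Optional, which suggests optional sampling knobs):

from typing import Optional
from pydantic import BaseModel

class GenerateIn(BaseModel):
    prompt: str                          # hypothetical field name
    max_new_tokens: Optional[int] = 256  # hypothetical sampling knob

class GenerateOut(BaseModel):
    completion: str  # matches `return GenerateOut(completion=text)` above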
 