"""FastAPI service that serves a GGUF Llama 3.2 3B Instruct model via ctransformers.

Endpoints:
    GET  /          -- health/configuration report (does not load the model).
    POST /generate  -- run a text completion for a prompt.

The model file is downloaded from the Hugging Face Hub and loaded lazily on
the first call to ``get_model()``.
"""

import os
import sys
import threading
import traceback
from typing import List, Optional

from ctransformers import AutoModelForCausalLM
from fastapi import FastAPI, HTTPException
from huggingface_hub import hf_hub_download
from pydantic import BaseModel, Field

# --- Model & cache config ---
REPO_ID = "bartowski/Llama-3.2-3B-Instruct-GGUF"
FILENAME = "Llama-3.2-3B-Instruct-Q4_K_M.gguf"  # safer quant than Q4_K_L
MODEL_TYPE = "llama"

CACHE_DIR = os.environ.get("HUGGINGFACE_HUB_CACHE", "/data/hf_cache")
os.makedirs(CACHE_DIR, exist_ok=True)

# Conservative defaults (override via env if needed).
# Note: the batch size is read from the env var named "BATCH", not "BATCH_SIZE".
CTX_LEN = int(os.environ.get("CTX_LEN", "1024"))
BATCH_SIZE = int(os.environ.get("BATCH", "16"))
THREADS = int(os.environ.get("THREADS", "4"))
GPU_LAYERS = int(os.environ.get("GPU_LAYERS", "0"))

app = FastAPI(title="Llama 3.2 3B Instruct (ctransformers)")

# Lazily-initialised model singleton; populated by get_model().
_model = None
# Guards the one-time download/load so concurrent first requests (sync
# endpoints are dispatched to worker threads) do not each load the model.
_model_lock = threading.Lock()


def get_model():
    """Return the shared model instance, downloading and loading it on first use.

    The GGUF file ``FILENAME`` from ``REPO_ID`` is fetched into ``CACHE_DIR``
    and loaded with the module-level settings (``CTX_LEN``, ``BATCH_SIZE``,
    ``THREADS``, ``GPU_LAYERS``). Later calls return the cached instance.

    Returns:
        The object returned by ``AutoModelForCausalLM.from_pretrained``.

    Raises:
        Whatever ``hf_hub_download`` or ``from_pretrained`` raise (network,
        filesystem, or model-loading errors); ``_model`` stays ``None`` then,
        so the next call retries.
    """
    global _model
    # Fast path: already loaded, no locking needed.
    if _model is not None:
        return _model

    with _model_lock:
        # Re-check under the lock: another thread may have finished loading
        # while this one was waiting.
        if _model is None:
            local_path = hf_hub_download(
                repo_id=REPO_ID,
                filename=FILENAME,
                local_dir=CACHE_DIR,
                local_dir_use_symlinks=False,
            )
            # NOTE(review): the previous ``f16_kv=True`` argument was dropped.
            # It does not appear among ctransformers' documented config
            # options, and unknown keyword arguments to from_pretrained() are
            # rejected with TypeError -- confirm against the installed version.
            _model = AutoModelForCausalLM.from_pretrained(
                model_path_or_repo_id=os.path.dirname(local_path),
                model_file=os.path.basename(local_path),
                model_type=MODEL_TYPE,
                context_length=CTX_LEN,
                batch_size=BATCH_SIZE,
                threads=THREADS,
                gpu_layers=GPU_LAYERS,
            )
    return _model


class GenerateIn(BaseModel):
    """Request body for ``POST /generate``.

    Numeric sampling parameters are range-checked so that obviously invalid
    values are rejected by request validation instead of reaching the model.
    """

    # Text the model should continue.
    prompt: str
    # Upper bound on the number of generated tokens; must be positive.
    max_new_tokens: int = Field(256, ge=1)
    # Sampling temperature; must not be negative.
    temperature: float = Field(0.7, ge=0.0)
    # Nucleus-sampling probability mass, a fraction in [0, 1].
    top_p: float = Field(0.95, ge=0.0, le=1.0)
    # Top-k sampling cutoff; must not be negative.
    top_k: int = Field(40, ge=0)
    # Repetition penalty factor; must be positive.
    repetition_penalty: float = Field(1.1, gt=0.0)
    # Optional stop sequences forwarded to the model as-is.
    stop: Optional[List[str]] = None


class GenerateOut(BaseModel):
    """Response body for ``POST /generate``."""

    # Text produced by the model for the given prompt.
    completion: str


@app.get("/")
def health():
    """Report service status and the effective model/runtime configuration.

    Only echoes module-level settings; it does not call ``get_model()``.
    """
    return {
        "status": "ok",
        "cache_dir": CACHE_DIR,
        "model": {"repo": REPO_ID, "file": FILENAME, "type": MODEL_TYPE},
        "settings": {
            "CTX_LEN": CTX_LEN,
            "BATCH_SIZE": BATCH_SIZE,
            "THREADS": THREADS,
            "GPU_LAYERS": GPU_LAYERS,
        },
    }


@app.post("/generate", response_model=GenerateOut)
def generate(body: GenerateIn):
    """Generate a completion for ``body.prompt`` with the given sampling settings.

    Args:
        body: Validated request payload (prompt plus sampling parameters).

    Returns:
        ``GenerateOut`` whose ``completion`` is the model output.

    Raises:
        HTTPException: status 500 with the error text as ``detail`` when
            loading the model or generating fails.
    """
    # NOTE(review): every request shares the single model instance; whether it
    # tolerates concurrent calls from several worker threads is not visible
    # here -- confirm, and serialise generation if it does not.
    try:
        model = get_model()
        text = model(
            body.prompt,
            max_new_tokens=body.max_new_tokens,
            temperature=body.temperature,
            top_p=body.top_p,
            top_k=body.top_k,
            repetition_penalty=body.repetition_penalty,
            stop=body.stop,
        )
        return GenerateOut(completion=text)
    except Exception as e:
        # helpful for debugging in Space logs
        traceback.print_exc(file=sys.stderr)
        raise HTTPException(status_code=500, detail=str(e)) from e