import time
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from llama_cpp import Llama
from huggingface_hub import login, hf_hub_download
import logging
import os
import asyncio
import psutil  # For RAM usage tracking

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Global lock for model access
model_lock = asyncio.Lock()

# Authenticate with Hugging Face
hf_token = os.getenv("HF_TOKEN")
if not hf_token:
    logger.error("HF_TOKEN environment variable not set.")
    raise ValueError("HF_TOKEN not set")
login(token=hf_token)

# Models Configuration
repo_id = "unsloth/Qwen3-1.7B-GGUF" # "bartowski/deepcogito_cogito-v1-preview-llama-3B-GGUF" # "bartowski/deepcogito_cogito-v1-preview-llama-8B-GGUF"
filename = "Qwen3-1.7B-Q4_K_M.gguf" # "deepcogito_cogito-v1-preview-llama-3B-Q4_K_M.gguf"


try:
    # Load the model with optimized parameters
    logger.info(f"Loading {filename} model")
    model_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        local_dir="/app/cache" if os.getenv("HF_HOME") else None,
        token=hf_token,
    )
    llm = Llama(
        model_path=model_path,
        n_ctx=3072,
        n_threads=2,
        n_batch=64,
        n_gpu_layers=0,
        use_mlock=True,
        f16_kv=True,
        verbose=True,
        batch_prefill=True,
        prefill_logits=False,
    )
    logger.info(f"{filename} model loaded")

except Exception as e:
    logger.error(f"Startup error: {str(e)}", exc_info=True)
    raise


# RAM Usage Tracking Function
def get_ram_usage():
    memory = psutil.virtual_memory()
    total_ram = memory.total / (1024 ** 3)  # Convert to GB
    used_ram = memory.used / (1024 ** 3)   # Convert to GB
    free_ram = memory.available / (1024 ** 3)  # Convert to GB
    percent_used = memory.percent
    return {
        "total_ram_gb": round(total_ram, 2),
        "used_ram_gb": round(used_ram, 2),
        "free_ram_gb": round(free_ram, 2),
        "percent_used": percent_used
    }

@app.get("/health")
async def health_check():
    return {"status": "healthy"}

@app.get("/model_info")
async def model_info():
    return {
        "model_name": repo_id,
        "model_size": "1.7B",
        "quantization": "Q4_K_M",
    }

@app.get("/ram_usage")
async def ram_usage():
    """Endpoint to get current RAM usage."""
    try:
        ram_stats = get_ram_usage()
        return ram_stats
    except Exception as e:
        logger.error(f"Error retrieving RAM usage: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error retrieving RAM usage: {str(e)}")

# @app.on_event("startup")
# async def warm_up_model():
#     logger.info("Warming up the model...")
#     dummy_query = "Hello"
#     dummy_history = []
#     async for _ in stream_response(dummy_query, dummy_history):
#         pass
#     logger.info("Model warm-up completed.")
#     # Log initial RAM usage
#     ram_stats = get_ram_usage()
#     logger.info(f"Initial RAM usage after startup: {ram_stats}")

# Add a background task to keep the model warm
@app.on_event("startup")
async def setup_periodic_tasks():
    asyncio.create_task(keep_model_warm())
    logger.info("Periodic model warm-up task scheduled")

async def keep_model_warm():
    """Background task that keeps the model warm by sending periodic requests"""
    while True:
        try:
            logger.info("Performing periodic model warm-up")
            dummy_query = "Say only the word 'ok.'"
            dummy_history = []
            # Run a one-token dummy completion to keep the model warm,
            # holding the lock so it doesn't collide with a live request
            async with model_lock:
                await asyncio.to_thread(
                    llm.create_chat_completion,
                    messages=[{"role": "user", "content": dummy_query}],
                    max_tokens=1,
                    temperature=0.0,
                    top_p=1.0,
                    stream=False,
                )
            logger.info("Periodic warm-up completed")
        except Exception as e:
            logger.error(f"Error in periodic warm-up: {str(e)}")
        
        # Wait for 13 minutes before the next warm-up
        await asyncio.sleep(13 * 60)

# ─── OpenAI‐compatible endpoint ─────────────────────────────────────────────
@app.post("/v1/chat/completions")
async def chat(req: dict):
    if req.get("model") != "llama-cpp":
        raise HTTPException(404, "Model not found")
    # Serialize access to the model and run the blocking call in a worker
    # thread so it doesn't stall the event loop
    async with model_lock:
        resp = await asyncio.to_thread(
            llm.create_chat_completion,
            messages=req["messages"],
            max_tokens=req.get("max_tokens", 256),
            temperature=req.get("temperature", 0.7),
            top_p=req.get("top_p", 1.0),
            stream=False,
        )
    return JSONResponse({
        "id":       resp["id"],
        "object":   "chat.completion",
        "created":  resp.get("created", int(time.time())),
        "model":    "llama-cpp",
        "choices": [{
            "index": 0,
            "message": {
                "role":    resp["choices"][0]["message"]["role"],
                "content": resp["choices"][0]["message"]["content"],
            },
            "finish_reason": resp["choices"][0].get("finish_reason", "stop"),
        }],
        "usage": resp.get("usage", {}),
    })