from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from llama_cpp import Llama
import os
import uvicorn
from typing import Optional, List
import logging
from contextlib import asynccontextmanager

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global model variable
model: Optional[Llama] = None

# Lifespan manager to load the model on startup
@asynccontextmanager
async def lifespan(app: FastAPI):
    global model
    model_gguf_path = os.path.join("./model", "gema-4b-indra10k-model1-q4_k_m.gguf")
    
    try:
        if not os.path.exists(model_gguf_path):
            raise RuntimeError(f"Model file not found at: {model_gguf_path}")

        logger.info(f"Loading model from: {model_gguf_path}")
        
        # Load the model using llama-cpp-python
        model = Llama(
            model_path=model_gguf_path,
            n_ctx=2048,          # Context window size, in tokens
            n_gpu_layers=0,      # Set to a positive number to offload layers if a GPU is available
            n_threads=os.cpu_count() or 1,
            verbose=True,
        )
        logger.info("Model loaded successfully using llama-cpp-python!")
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise e
    
    yield
    # Cleanup code if needed on shutdown
    logger.info("Application is shutting down.")


app = FastAPI(title="Gema 4B Model API", version="1.0.0", lifespan=lifespan)


# Request model
class TextRequest(BaseModel):
    inputs: str
    system_prompt: Optional[str] = None
    max_tokens: Optional[int] = 512
    temperature: Optional[float] = 0.7
    top_k: Optional[int] = 50
    top_p: Optional[float] = 0.9
    repeat_penalty: Optional[float] = 1.1
    stop: Optional[List[str]] = None

# Response model
class TextResponse(BaseModel):
    generated_text: str
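
# Illustrative request/response for POST /generate (the field names mirror the
# models above; the concrete values are arbitrary examples, not model output):
#
#   request:  {"inputs": "Hello, who are you?", "max_tokens": 64, "temperature": 0.7}
#   response: {"generated_text": "..."}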


@app.post("/generate", response_model=TextResponse)
async def generate_text(request: TextRequest):
    if model is None:
        raise HTTPException(status_code=503, detail="Model is not ready or failed to load.")
    
    try:
        # Create prompt
        if request.system_prompt:
            full_prompt = f"{request.system_prompt}\n\nUser: {request.inputs}\nAssistant:"
        else:
            full_prompt = request.inputs
        
        # Generate text with the llama-cpp-python completion call
        output = model(
            prompt=full_prompt,
            max_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            top_k=request.top_k,
            repeat_penalty=request.repeat_penalty,
            stop=request.stop or []
        )
        
        # Extract the generated text from the response structure
        generated_text = output['choices'][0]['text'].strip()
        
        return TextResponse(generated_text=generated_text)
    
    except Exception as e:
        logger.error(f"Generation error: {e}")
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")

@app.get("/health")
async def health_check():
    return {"status": "healthy", "model_loaded": model is not None}

@app.get("/")
async def root():
    return {"message": "Gema 4B Model API", "docs": "/docs"}

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")
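
# Example client call (illustrative sketch; assumes the server is running on
# localhost:8000 as configured above and that the `requests` package is installed):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/generate",
#       json={"inputs": "Write a haiku about the sea.", "max_tokens": 128},
#   )
#   print(resp.json()["generated_text"])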