File size: 3,577 Bytes
6d2ea02
 
 
 
 
 
 
 
1cc8acd
6d2ea02
 
 
b10a1be
6d2ea02
b10a1be
6d2ea02
 
 
8f0965e
6d2ea02
 
 
8f0965e
6d2ea02
 
 
 
 
 
 
 
 
 
 
 
b10a1be
 
 
 
6d2ea02
b10a1be
 
 
 
 
6d2ea02
b10a1be
 
6d2ea02
b10a1be
6d2ea02
b10a1be
6d2ea02
 
 
 
b10a1be
6d2ea02
 
 
 
 
b10a1be
6d2ea02
 
b10a1be
6d2ea02
 
 
 
 
b10a1be
6d2ea02
 
 
 
 
 
 
 
 
 
b10a1be
6d2ea02
 
 
 
 
 
 
 
 
 
 
b10a1be
 
6d2ea02
 
 
 
 
 
 
8f0965e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import asyncio
import logging
import os
import threading
from typing import Optional, List

import uvicorn
from ctransformers import AutoModelForCausalLM
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# Configure root logging at INFO level and obtain this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI application object; every route below is registered on it.
app = FastAPI(version="1.0.0", title="Gema 4B Model API")

# Request body schema for POST /generate.
class TextRequest(BaseModel):
    """Input text plus optional prompt/sampling settings for one generation.

    Sampling fields are forwarded to the model call as-is (an explicit
    ``null`` is passed through as ``None``); ``stop`` of ``None`` is sent
    as an empty list.
    """
    # User text. Sent verbatim as the prompt unless system_prompt is set,
    # in which case it is embedded after "User:" in a chat-style template.
    inputs: str
    # Optional instruction text; when non-empty it is placed at the top of
    # the prompt, followed by the "User:"/"Assistant:" template.
    system_prompt: Optional[str] = None
    # Forwarded as max_new_tokens.
    max_tokens: Optional[int] = 512
    # Sampling temperature, forwarded unchanged.
    temperature: Optional[float] = 0.7
    # Top-k sampling cutoff, forwarded unchanged.
    top_k: Optional[int] = 50
    # Nucleus (top-p) sampling threshold, forwarded unchanged.
    top_p: Optional[float] = 0.9
    # Forwarded as repetition_penalty.
    repeat_penalty: Optional[float] = 1.1
    # Stop sequences; None is replaced by [] before the model call.
    stop: Optional[List[str]] = None

# Response body schema for POST /generate.
class TextResponse(BaseModel):
    """Generated completion returned to the client."""
    # Model output. If it contains "Assistant:", only the text after the
    # last occurrence is kept (stripped of surrounding whitespace).
    generated_text: str

# Global model handle: None until load_model() assigns the loaded
# ctransformers model at startup; read by generate_text and health_check.
model = None

@app.on_event("startup")
async def load_model():
    """Load the quantized GGUF model from ``./model`` into the global ``model``.

    Runs once when the application starts. Any failure is logged with its
    traceback and re-raised so the application does not start serving
    without a model.

    Raises:
        RuntimeError: if the GGUF file is not present at the expected path.
        Exception: anything raised by ``AutoModelForCausalLM.from_pretrained``.
    """
    global model
    # Local folder populated at image build time (per the error message below).
    model_path = "./model"
    model_file = "gema-4b-indra10k-model1-q4_k_m.gguf"
    model_file_path = os.path.join(model_path, model_file)

    try:
        # isfile() implies the folder exists, so one check covers both cases
        # and also rejects a directory that happens to share the file's name.
        if not os.path.isfile(model_file_path):
            raise RuntimeError("Model files not found. Ensure the model was downloaded in the Docker build.")

        logger.info("Loading model from local path: %s", model_path)
        model = AutoModelForCausalLM.from_pretrained(
            model_path,  # load from the local folder, not a remote repo
            model_file=model_file,  # which GGUF file inside that folder
            # NOTE(review): architecture hint for ctransformers; confirm it
            # matches the GGUF's actual architecture.
            model_type="llama",
            gpu_layers=0,  # CPU-only inference
            context_length=2048,
            threads=os.cpu_count() or 1,  # cpu_count() may return None
        )
        logger.info("Model loaded successfully!")
    except Exception as e:
        logger.exception("Failed to load model from %s: %s", model_file_path, e)
        # Bare raise keeps the original traceback; aborting startup is intended.
        raise

# Serializes generation calls on the single shared model instance. The
# original in-loop call was implicitly one-at-a-time; moving inference to a
# worker thread must not allow concurrent calls on the same model.
_generation_lock = threading.Lock()


@app.post("/generate", response_model=TextResponse)
async def generate_text(request: TextRequest):
    """Generate a completion for ``request`` using the loaded model.

    With a ``system_prompt``, the input is wrapped in a simple
    ``User:`` / ``Assistant:`` template; otherwise ``inputs`` is sent verbatim.
    The blocking inference call runs in a worker thread so the event loop
    (and endpoints such as ``/health``) stays responsive during generation.

    Raises:
        HTTPException: 503 if the model is not loaded; 500 if generation fails.
    """
    # Snapshot the global once so the readiness check and the call agree.
    llm = model
    if llm is None:
        raise HTTPException(status_code=503, detail="Model is not ready or failed to load. Please try again later.")

    try:
        if request.system_prompt:
            full_prompt = f"{request.system_prompt}\n\nUser: {request.inputs}\nAssistant:"
        else:
            full_prompt = request.inputs

        def _run_inference():
            # Blocking, CPU-bound call, executed off the event loop.
            with _generation_lock:
                return llm(
                    full_prompt,
                    max_new_tokens=request.max_tokens,
                    temperature=request.temperature,
                    top_p=request.top_p,
                    top_k=request.top_k,
                    repetition_penalty=request.repeat_penalty,
                    stop=request.stop or [],
                )

        loop = asyncio.get_running_loop()
        generated_text = await loop.run_in_executor(None, _run_inference)

        # Keep only text after the last "Assistant:" marker, if one appears.
        # NOTE(review): if the model emits several turns, this keeps the final
        # one rather than the first reply; confirm that is intended.
        if "Assistant:" in generated_text:
            generated_text = generated_text.split("Assistant:")[-1].strip()

        return TextResponse(generated_text=generated_text)

    except Exception as e:
        logger.exception("Generation error: %s", e)
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}") from e

@app.get("/health")
async def health_check():
    """Report liveness together with whether the global model is set.

    load_model re-raises on failure, which aborts startup, so once startup
    has completed ``model_loaded`` is expected to be true.
    """
    is_loaded = model is not None
    payload = {"status": "healthy", "model_loaded": is_loaded}
    return payload

@app.get("/")
async def root():
    """Return a short banner naming the API and the ``/docs`` path."""
    banner = "Gema 4B Model API"
    return {"message": banner, "docs": "/docs"}

if __name__ == "__main__":
    # Direct execution: serve the app on all interfaces, port 8000.
    server_options = {"host": "0.0.0.0", "port": 8000, "log_level": "info"}
    uvicorn.run(app, **server_options)