from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from ctransformers import AutoModelForCausalLM
import os
import uvicorn
from typing import Optional, List
import logging
from contextlib import asynccontextmanager

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global model variable
model = None

# Lifespan manager to load the model on startup
@asynccontextmanager
async def lifespan(app: FastAPI):
    # This code runs on startup
    global model
    model_path = "./model"
    model_file = "gema-4b-indra10k-model1-q4_k_m.gguf"
    
    try:
        if not os.path.exists(model_path) or not os.path.exists(os.path.join(model_path, model_file)):
            raise RuntimeError("Model files not found. Ensure the model was downloaded in the Docker build.")

        logger.info(f"Loading model from local path: {model_path}")
        
        # FIX: Changed model_type from "llama" to "gemma"
        model = AutoModelForCausalLM.from_pretrained(
            model_path,
            model_file=model_file,
            model_type="gemma",  # This was the main cause of the error
            gpu_layers=0,
            context_length=2048,
            threads=os.cpu_count() or 1
        )
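        # Note: gpu_layers=0 keeps inference entirely on the CPU; raising it
        # should offload layers to a GPU, assuming ctransformers was built
        # with CUDA support (an assumption, not verified here).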
        logger.info("Model loaded successfully!")
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        # Raising an exception during startup will prevent the app from starting
        raise
    
    yield
    # This code runs on shutdown (optional)
    logger.info("Application is shutting down.")


app = FastAPI(title="Gema 4B Model API", version="1.0.0", lifespan=lifespan)


# Request model
class TextRequest(BaseModel):
    inputs: str
    system_prompt: Optional[str] = None
    max_tokens: Optional[int] = 512
    temperature: Optional[float] = 0.7
    top_k: Optional[int] = 50
    top_p: Optional[float] = 0.9
    repeat_penalty: Optional[float] = 1.1
    stop: Optional[List[str]] = None
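
# An illustrative /generate payload (the field values below are examples,
# not from the source):
#   {"inputs": "Explain Docker in one sentence.",
#    "system_prompt": "You are a concise assistant.",
#    "max_tokens": 128, "temperature": 0.5}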

# Response model
class TextResponse(BaseModel):
    generated_text: str


@app.post("/generate", response_model=TextResponse)
async def generate_text(request: TextRequest):
    if model is None:
        raise HTTPException(status_code=503, detail="Model is not ready or failed to load. Please check logs.")
    
    try:
        # Wrap the raw input in a simple chat-style template when a system
        # prompt is supplied; otherwise pass the input through unchanged.
        if request.system_prompt:
            full_prompt = f"{request.system_prompt}\n\nUser: {request.inputs}\nAssistant:"
        else:
            full_prompt = request.inputs
        
        # ctransformers model objects are callable; calling one returns the
        # generated completion as a single string.
        generated_text = model(
            full_prompt,
            max_new_tokens=request.max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            top_k=request.top_k,
            repetition_penalty=request.repeat_penalty,
            stop=request.stop or []
        )
        
        if "Assistant:" in generated_text:
            generated_text = generated_text.split("Assistant:")[-1].strip()
        
        return TextResponse(generated_text=generated_text)
    
    except Exception as e:
        logger.error(f"Generation error: {e}")
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")

@app.get("/health")
async def health_check():
    return {"status": "healthy", "model_loaded": model is not None}

@app.get("/")
async def root():
    return {"message": "Gema 4B Model API", "docs": "/docs"}

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")
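
# A minimal client sketch (assumes the server is reachable at
# localhost:8000, per the uvicorn call above, and that the `requests`
# package is installed):
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/generate",
#       json={"inputs": "Write a haiku about containers.", "max_tokens": 64},
#       timeout=120,
#   )
#   resp.raise_for_status()
#   print(resp.json()["generated_text"])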