"""
app.py: FastAPI service exposing a quantized OpenChat-3.5 (GGUF) model via ctransformers
"""

import logging
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from model_loader import model
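# `model` is assumed to be a ctransformers LLM instance constructed in
# model_loader; the handlers below rely on its tokenize(), context_length,
# and callable generation interface.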
import uvicorn

# Logger
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


# FastAPI app
app = FastAPI(
    title="masx-openchat-llm",
    description="MASX AI service exposing a quantized OpenChat-3.5 model (GGUF)",
    version="1.0.0",
)


# Request schema
class PromptRequest(BaseModel):
    prompt: str
    max_tokens: int = 256
    temperature: float = 0.0
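    # Defaults: max_tokens caps the number of newly generated tokens, and
    # temperature=0.0 gives effectively deterministic (greedy-style) decoding;
    # raise it for more varied completions.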


# Response schema
class ChatResponse(BaseModel):
    response: str

@app.get("/")
def root():
    return {"message": "MASX OpenChat API is running"}

@app.get("/status")
async def status():
    try:
        return {
            "status": "ok",
            "model_path": getattr(model, "model_path", "unknown"),
            "model_type": getattr(model, "model_type", "unknown"),
            "context_length": getattr(model, "context_length", "unknown"),
            "gpu_layers": getattr(model, "gpu_layers", 0),
        }
    except Exception as e:
        logger.error("Status check failed: %s", str(e), exc_info=True)
        raise HTTPException(status_code=500, detail="Model status check failed")


@app.post("/chat", response_model=ChatResponse)
async def chat(req: PromptRequest):
    try:
        logger.info("Prompt: %s", req.prompt)

        prompt_tokens = model.tokenize(req.prompt)
        if len(prompt_tokens) > model.context_length:
            raise HTTPException(
                status_code=400,
                detail=f"Prompt too long ({len(prompt_tokens)} tokens). Max context: {model.context_length}",
            )

        response = model(
            req.prompt,
            max_new_tokens=req.max_tokens,
            temperature=req.temperature,
            stop=["</s>"],
        )
        logger.info("Response: %s", response)
        return ChatResponse(response=response.strip())
    except HTTPException:
        # Re-raise intentional HTTP errors (e.g., the 400 for oversized prompts)
        # so they are not masked as a generic 500 below.
        raise
    except Exception as e:
        logger.error("Chat error: %s", str(e), exc_info=True)
        raise HTTPException(status_code=500, detail="Inference failure")


if __name__ == "__main__":
    uvicorn.run("app:app", host="0.0.0.0", port=7860, log_level="info")