ateetvatan committed
Commit 5ee2484 · 1 Parent(s): 24d6d7e

formatting

Files changed:
- app.py +16 -7
- model_loader.py +1 -1
app.py CHANGED

@@ -17,29 +17,37 @@ logging.basicConfig(level=logging.INFO)
 app = FastAPI(
     title="masx-openchat-llm",
     description="MASX AI service exposing the OpenChat-3.5 LLM as an inference endpoint",
-    version="1.0.0"
+    version="1.0.0",
 )
 
+
 # Request ********schema*******
 class PromptRequest(BaseModel):
     prompt: str
     max_tokens: int = 256
     temperature: float = 0.0  # Deterministic by default
 
+
 # Response ********schema*******
 class ChatResponse(BaseModel):
     response: str
 
+
 @app.get("/status")
 async def status():
     """Check model status and max supported tokens."""
     try:
         max_context = getattr(model.config, "max_position_embeddings", "unknown")
-        return {"status": "ok", "model": model.name_or_path, "max_context_tokens": max_context}
+        return {
+            "status": "ok",
+            "model": model.name_or_path,
+            "max_context_tokens": max_context,
+        }
     except Exception as e:
         logger.error("Status error: %s", str(e))
         raise HTTPException(status_code=500, detail=str(e))
 
+
 @app.post("/chat", response_model=ChatResponse)
 async def chat(req: PromptRequest):
     """OpenChat-3.5 Run inference prompt"""
@@ -49,12 +57,12 @@ async def chat(req: PromptRequest):
         # Dynamically choose device at request time
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         logger.info(f"Using device: {device}")
-
-        # Move model to device if not
+
+        # Move model to device if not
         if next(model.parameters()).device != device:
             logger.info("Moving model to %s", device)
             model.to(device)
-
+
         # Tokenize input
         inputs = tokenizer(req.prompt, return_tensors="pt").to(device)
 
@@ -70,7 +78,7 @@ async def chat(req: PromptRequest):
         generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
         # Trim echoed prompt if present
-        response_text = generated_text[len(req.prompt):].strip()
+        response_text = generated_text[len(req.prompt) :].strip()
 
         logger.info("Generated response: %s", response_text)
         return ChatResponse(response=response_text)
@@ -79,5 +87,6 @@ async def chat(req: PromptRequest):
         logger.error("Inference failed: %s", str(e), exc_info=True)
         raise HTTPException(status_code=500, detail="Inference failure: " + str(e))
 
-
-uvicorn.run("app:app", host="0.0.0.0", port=8080, log_level="info")
+
+if __name__ == "__main__":
+    uvicorn.run("app:app", host="0.0.0.0", port=8080, log_level="info")
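
For reference, a minimal client sketch for the two endpoints touched by this diff (an illustration, not part of the commit): it assumes the service is running locally on port 8080, as configured in the uvicorn.run call above, and that the requests package is installed; the prompt text and timeout values are placeholders.

# Client sketch: call the /status and /chat endpoints defined in app.py.
import requests

BASE_URL = "http://localhost:8080"  # assumed local deployment (port taken from uvicorn.run)

# GET /status returns {"status": "ok", "model": ..., "max_context_tokens": ...}
status = requests.get(f"{BASE_URL}/status", timeout=30)
print(status.json())

# POST /chat takes the PromptRequest fields: prompt, max_tokens, temperature
payload = {"prompt": "Hello, OpenChat!", "max_tokens": 64, "temperature": 0.0}
chat = requests.post(f"{BASE_URL}/chat", json=payload, timeout=300)
print(chat.json()["response"])  # ChatResponse exposes a single "response" field
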
model_loader.py CHANGED

@@ -11,4 +11,4 @@ MODEL_NAME = os.getenv("MODEL_NAME", "openchat/openchat-3.5-1210")
 tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
 
 # Load model initially on CPU
-model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to("cpu")
+model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to("cpu")
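
model_loader.py changes only in whitespace here, so its behaviour is unchanged: the model name comes from the MODEL_NAME environment variable (default "openchat/openchat-3.5-1210") and the weights are loaded onto CPU first, with app.py moving them to CUDA per request when available. A small usage sketch (illustration only, not repository code; the override value is a hypothetical example):

# Override the model before importing the loader, so os.getenv picks it up.
import os

os.environ["MODEL_NAME"] = "openchat/openchat-3.5-0106"  # hypothetical override

import model_loader  # loads tokenizer and model (on CPU) at import time

print(model_loader.MODEL_NAME)
print(next(model_loader.model.parameters()).device)  # expected: cpu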