ateetvatan committed
Commit 5ee2484
1 parent: 24d6d7e

formatting

Files changed (2):
  1. app.py +16 -7
  2. model_loader.py +1 -1
app.py CHANGED
@@ -17,29 +17,37 @@ logging.basicConfig(level=logging.INFO)
 app = FastAPI(
     title="masx-openchat-llm",
     description="MASX AI service exposing the OpenChat-3.5 LLM as an inference endpoint",
-    version="1.0.0"
+    version="1.0.0",
 )
 
+
 # Request schema
 class PromptRequest(BaseModel):
     prompt: str
     max_tokens: int = 256
     temperature: float = 0.0  # Deterministic by default
 
+
 # Response schema
 class ChatResponse(BaseModel):
     response: str
 
+
 @app.get("/status")
 async def status():
     """Check model status and max supported tokens."""
     try:
         max_context = getattr(model.config, "max_position_embeddings", "unknown")
-        return {"status": "ok", "model": model.name_or_path, "max_context_tokens": max_context}
+        return {
+            "status": "ok",
+            "model": model.name_or_path,
+            "max_context_tokens": max_context,
+        }
     except Exception as e:
         logger.error("Status error: %s", str(e))
         raise HTTPException(status_code=500, detail=str(e))
 
+
 @app.post("/chat", response_model=ChatResponse)
 async def chat(req: PromptRequest):
     """Run OpenChat-3.5 inference on the given prompt."""
@@ -49,12 +57,12 @@ async def chat(req: PromptRequest):
     # Dynamically choose device at request time
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     logger.info(f"Using device: {device}")
-
-    # Move model to device if not already there
+
+    # Move model to device if not already there
     if next(model.parameters()).device != device:
         logger.info("Moving model to %s", device)
         model.to(device)
-
+
     # Tokenize input
     inputs = tokenizer(req.prompt, return_tensors="pt").to(device)
 
@@ -70,7 +78,7 @@ async def chat(req: PromptRequest):
     generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
     # Trim echoed prompt if present
-    response_text = generated_text[len(req.prompt):].strip()
+    response_text = generated_text[len(req.prompt) :].strip()
 
     logger.info("Generated response: %s", response_text)
     return ChatResponse(response=response_text)
@@ -79,5 +87,6 @@ async def chat(req: PromptRequest):
     logger.error("Inference failed: %s", str(e), exc_info=True)
     raise HTTPException(status_code=500, detail="Inference failure: " + str(e))
 
-if __name__ == "__main__":
+
+if __name__ == "__main__":
     uvicorn.run("app:app", host="0.0.0.0", port=8080, log_level="info")
model_loader.py CHANGED
@@ -11,4 +11,4 @@ MODEL_NAME = os.getenv("MODEL_NAME", "openchat/openchat-3.5-1210")
 tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
 
 # Load model initially on CPU
-model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to("cpu")
+model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to("cpu")
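
Taken together, the two files implement a load-on-CPU, move-on-first-request pattern: model_loader.py loads the weights onto the CPU at import time, and /chat moves them to CUDA only if it is available when a request arrives. (The diff itself is consistent with a pass of an autoformatter such as black: trailing commas, two blank lines between top-level definitions, the space in `[len(req.prompt) :]`.) Below is a consolidated sketch of that pattern; the generate() arguments are assumptions, since the diff elides the middle of the /chat handler.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "openchat/openchat-3.5-1210"  # default from model_loader.py

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to("cpu")  # start on CPU


def run_inference(prompt: str, max_tokens: int = 256, temperature: float = 0.0) -> str:
    # Choose the device per request and move the model only when needed,
    # mirroring the /chat handler above.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if next(model.parameters()).device != device:
        model.to(device)

    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Assumed generation step -- the diff does not show the actual
    # generate() call inside /chat. Sample only when temperature > 0.
    gen_kwargs = {"max_new_tokens": max_tokens}
    if temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature)
    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)

    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return text[len(prompt) :].strip()  # trim echoed prompt if present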