Lyon28 committed on
Commit 621ada9 · verified · 1 Parent(s): f4764ff

Update main.py

Files changed (1)
  1. main.py +70 -19
main.py CHANGED
@@ -3,17 +3,25 @@ from pydantic import BaseModel
 from transformers import pipeline
 import torch
 from fastapi.middleware.cors import CORSMiddleware
+from typing import Dict, Any
 
-app = FastAPI(title="Model Inference API")
+# Initialize the FastAPI application
+app = FastAPI(
+    title="Lyon28 Model Inference API",
+    description="API for accessing 11 machine learning models",
+    version="1.0.0"
+)
 
-# Allow CORS for external frontend
+# Configure CORS for external frontends
 app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"],
+    allow_credentials=True,
     allow_methods=["*"],
     allow_headers=["*"],
 )
 
+# Model configuration
 MODEL_MAP = {
     "tinny-llama": "Lyon28/Tinny-Llama",
     "pythia": "Lyon28/Pythia",
@@ -38,44 +46,85 @@ class InferenceRequest(BaseModel):
     text: str
     max_length: int = 100
     temperature: float = 0.9
+    top_p: float = 0.95
 
-def get_task(model_id: str):
+# Helper functions
+def get_task(model_id: str) -> str:
     for task, models in TASK_MAP.items():
         if model_id in models:
             return task
     return "text-generation"
 
+# Startup event for model initialization
 @app.on_event("startup")
 async def load_models():
-    # Initialize models (optional: pre-load critical models)
     app.state.pipelines = {}
-    print("Models initialized in memory")
+    print("🟢 All models ready to use!")
 
+# Root endpoint
+@app.get("/")
+async def root():
+    return {
+        "message": "Welcome to the Lyon28 Model API",
+        "endpoints": {
+            "documentation": "/docs",
+            "model_list": "/models",
+            "health_check": "/health",
+            "inference": "/inference/{model_id}"
+        },
+        "total_models": len(MODEL_MAP)
+    }
+
+# Model list endpoint
+@app.get("/models")
+async def list_models():
+    return {
+        "available_models": list(MODEL_MAP.keys()),
+        "total_models": len(MODEL_MAP)
+    }
+
+# Health check endpoint
+@app.get("/health")
+async def health_check():
+    return {
+        "status": "healthy",
+        "gpu_available": torch.cuda.is_available(),
+        "gpu_type": torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU-only"
+    }
+
+# Main inference endpoint
 @app.post("/inference/{model_id}")
 async def model_inference(model_id: str, request: InferenceRequest):
     try:
+        # Validate the model ID
         if model_id not in MODEL_MAP:
-            raise HTTPException(status_code=404, detail="Model not found")
+            raise HTTPException(
+                status_code=404,
+                detail=f"Model {model_id} not found. Check /models for the list of available models."
+            )
 
+        # Resolve the matching task
        task = get_task(model_id)
 
-        # Load pipeline with caching
+        # Load the model if it is not already cached in memory
         if model_id not in app.state.pipelines:
             app.state.pipelines[model_id] = pipeline(
                 task=task,
                 model=MODEL_MAP[model_id],
-                device_map="auto",
+                device=0 if torch.cuda.is_available() else -1,
                 torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
             )
+            print(f"✅ Model {model_id} loaded successfully!")
 
         pipe = app.state.pipelines[model_id]
 
-        # Process based on task
+        # Dispatch on the task type
         if task == "text-generation":
             result = pipe(
                 request.text,
                 max_length=request.max_length,
-                temperature=request.temperature
+                temperature=request.temperature,
+                top_p=request.top_p
             )[0]['generated_text']
 
         elif task == "text-classification":
@@ -86,17 +135,19 @@ async def model_inference(model_id: str, request: InferenceRequest):
             }
 
         elif task == "text2text-generation":
-            result = pipe(request.text)[0]['generated_text']
+            result = pipe(
+                request.text,
+                max_length=request.max_length
+            )[0]['generated_text']
 
         return {"result": result}
 
     except Exception as e:
-        raise HTTPException(status_code=500, detail=str(e))
-
-@app.get("/models")
-async def list_models():
-    return {"available_models": list(MODEL_MAP.keys())}
+        raise HTTPException(
+            status_code=500,
+            detail=f"Error processing request: {str(e)}"
+        )
 
-@app.get("/health")
-async def health_check():
-    return {"status": "healthy"}
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=7860)
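For context, a minimal client-side sketch of how the updated endpoints could be exercised once the server is running. This example is not part of the commit; the local base URL, the requests dependency, and the choice of the tinny-llama model are assumptions for illustration.

# Hypothetical usage sketch, assuming a local deployment on port 7860
# (matching the uvicorn.run() call added in this commit).
import requests

BASE_URL = "http://localhost:7860"  # assumption: local server

# The /health endpoint now reports GPU availability alongside the status
print(requests.get(f"{BASE_URL}/health").json())

# Text generation via the inference endpoint; top_p was added to
# InferenceRequest in this commit with a default of 0.95
payload = {
    "text": "Once upon a time",
    "max_length": 100,
    "temperature": 0.9,
    "top_p": 0.95,
}
response = requests.post(f"{BASE_URL}/inference/tinny-llama", json=payload)
response.raise_for_status()
print(response.json()["result"])

Note that because pipelines are loaded lazily and cached in app.state.pipelines, the first request to a given model is slow while the weights are fetched; subsequent requests reuse the cached pipeline.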