Lyon28 committed on
Commit 7055a09 · verified · 1 Parent(s): d3d3301

Update app.py

Files changed (1)
  1. app.py +66 -19
app.py CHANGED
@@ -3,7 +3,7 @@ from pydantic import BaseModel
 from transformers import pipeline
 import torch
 from fastapi.middleware.cors import CORSMiddleware
-from typing import Dict, Any
+from typing import Dict, Any, Optional
 import os  # Import os module
 
 # Initialize the FastAPI application
@@ -45,6 +45,7 @@ TASK_MAP = {
 
 class InferenceRequest(BaseModel):
     text: str
+    model_id: Optional[str] = "gpt-2"  # Default model
     max_length: int = 100
     temperature: float = 0.9
     top_p: float = 0.95
@@ -66,7 +67,6 @@ async def load_models():
     os.environ['HF_HOME'] = '/tmp/.cache/huggingface'
     os.makedirs(os.environ['HF_HOME'], exist_ok=True)
 
-
 # Main endpoint
 @app.get("/")
 async def root():
@@ -76,9 +76,14 @@ async def root():
             "documentation": "/docs",
             "model_list": "/models",
             "health_check": "/health",
-            "inference": "/inference/{model_id}"
+            "inference_with_model": "/inference/{model_id}",
+            "inference_general": "/inference"
         },
-        "total_models": len(MODEL_MAP)
+        "total_models": len(MODEL_MAP),
+        "usage_examples": {
+            "specific_model": "POST /inference/gpt-2 with JSON body",
+            "general_inference": "POST /inference with model_id in JSON body"
+        }
     }
 
 # Endpoint for listing models
@@ -98,18 +103,34 @@ async def health_check():
         "gpu_type": torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU-only"
     }
 
-# Main inference endpoint
+# NEW: General inference endpoint (handles POST /inference)
+@app.post("/inference")
+async def general_inference(request: InferenceRequest):
+    """
+    General inference endpoint that accepts model_id in the request body
+    """
+    return await process_inference(request.model_id, request)
+
+# Inference endpoint with model_id in the path
 @app.post("/inference/{model_id}")
 async def model_inference(model_id: str, request: InferenceRequest):
+    """
+    Specific model inference endpoint with model_id in path
+    """
+    return await process_inference(model_id, request)
+
+# Shared inference processing function
+async def process_inference(model_id: str, request: InferenceRequest):
     try:
         # Make sure model_id is lowercase so it matches MODEL_MAP
         model_id = model_id.lower()
 
         # Validate the model ID
         if model_id not in MODEL_MAP:
+            available_models = ", ".join(MODEL_MAP.keys())
             raise HTTPException(
                 status_code=404,
-                detail=f"Model '{model_id}' not found. Check /models for the list of available models."
+                detail=f"Model '{model_id}' not found. Available models: {available_models}"
             )
 
         # Get the matching task
@@ -124,13 +145,20 @@ async def model_inference(model_id: str, request: InferenceRequest):
         # Adjust the dtype based on the device
        dtype_to_use = torch.float16 if torch.cuda.is_available() else torch.float32
 
-            app.state.pipelines[model_id] = pipeline(
-                task=task,
-                model=MODEL_MAP[model_id],
-                device=device_to_use,
-                torch_dtype=dtype_to_use
-            )
-            print(f"✅ Model {model_id} loaded successfully!")
+            try:
+                app.state.pipelines[model_id] = pipeline(
+                    task=task,
+                    model=MODEL_MAP[model_id],
+                    device=device_to_use,
+                    torch_dtype=dtype_to_use
+                )
+                print(f"✅ Model {model_id} loaded successfully!")
+            except Exception as load_error:
+                print(f"❌ Failed to load model {model_id}: {load_error}")
+                raise HTTPException(
+                    status_code=503,
+                    detail=f"Failed to load model {model_id}. Please try again later."
+                )
 
         pipe = app.state.pipelines[model_id]
 
@@ -140,7 +168,8 @@ async def model_inference(model_id: str, request: InferenceRequest):
                 request.text,
                 max_length=request.max_length,
                 temperature=request.temperature,
-                top_p=request.top_p
+                top_p=request.top_p,
+                do_sample=True
             )[0]['generated_text']
 
         elif task == "text-classification":
@@ -165,8 +194,16 @@ async def model_inference(model_id: str, request: InferenceRequest):
                 detail=f"Task ({task}) for model {model_id} is not supported or not recognized."
             )
 
-        return {"result": result}
+        return {
+            "result": result,
+            "model_used": model_id,
+            "task": task,
+            "status": "success"
+        }
 
+    except HTTPException as he:
+        # Re-raise HTTP exceptions
+        raise he
     except Exception as e:
         # Log the error in more detail for debugging
        print(f"‼️ Error while processing model {model_id}: {e}")
@@ -178,7 +215,17 @@ async def model_inference(model_id: str, request: InferenceRequest):
             detail=f"Error processing request: {str(e)}. Check the server logs for details."
         )
 
-# This does not need to be run directly on Hugging Face Spaces because Uvicorn runs it
-# if __name__ == "__main__":
-#     import uvicorn
-#     uvicorn.run(app, host="0.0.0.0", port=7860)
+# Error handler for 404
+@app.exception_handler(404)
+async def not_found_handler(request, exc):
+    return {
+        "error": "Endpoint not found",
+        "available_endpoints": [
+            "GET /",
+            "GET /models",
+            "GET /health",
+            "POST /inference",
+            "POST /inference/{model_id}"
+        ],
+        "tip": "Use /docs for full documentation"
+    }
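
For reference, a quick sketch of how a client could exercise the two inference routes after this commit. The base URL is a placeholder for wherever the Space is served, and "gpt-2" is assumed to be one of the keys in MODEL_MAP:

import requests

BASE_URL = "http://localhost:7860"  # placeholder; substitute the Space URL

# General endpoint: model_id travels in the JSON body
resp = requests.post(
    f"{BASE_URL}/inference",
    json={"text": "Hello", "model_id": "gpt-2", "max_length": 50},
)
print(resp.json())  # {"result": ..., "model_used": "gpt-2", "task": ..., "status": "success"}

# Path endpoint: model_id travels in the URL instead
resp = requests.post(f"{BASE_URL}/inference/gpt-2", json={"text": "Hello", "max_length": 50})
print(resp.json())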
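
The try/except around pipeline() formalizes a lazy load-and-cache pattern: the first request for a model pays the load cost (or gets a 503), and later requests reuse the cached pipeline. This relies on app.state.pipelines being initialized at startup; that initialization is not visible in this diff, so the following is only a hypothetical sketch of what load_models in app.py presumably does (only the HF_HOME lines are confirmed by the hunk above):

@app.on_event("startup")
async def load_models():
    app.state.pipelines = {}  # assumed: lazy pipeline cache, keyed by model_id
    os.environ['HF_HOME'] = '/tmp/.cache/huggingface'
    os.makedirs(os.environ['HF_HOME'], exist_ok=True)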
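
On the do_sample=True addition: transformers treats temperature and top_p as sampling parameters and ignores them (with a warning) under the default greedy decoding, so enabling sampling is what makes those request fields actually take effect. A standalone illustration against the public gpt2 checkpoint (used here purely for illustration):

from transformers import pipeline

pipe = pipeline("text-generation", model="gpt2")

# With do_sample=True, temperature/top_p shape the randomness of generation;
# without it, decoding is greedy and both arguments are ignored.
out = pipe("Hello", max_length=50, temperature=0.9, top_p=0.95, do_sample=True)
print(out[0]["generated_text"])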
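
One caveat worth flagging on the new 404 handler: Starlette expects an exception handler to return a Response object, so returning a plain dict fails at runtime (the dict is not callable as an ASGI response). Also, registering a handler for status code 404 intercepts the HTTPException(404) raised for unknown models as well, hiding its detail message. A minimal sketch of a Response-returning version for app.py:

from fastapi.responses import JSONResponse

@app.exception_handler(404)
async def not_found_handler(request, exc):
    # Exception handlers must return a Response, not a bare dict.
    return JSONResponse(
        status_code=404,
        content={
            "error": "Endpoint not found",
            "available_endpoints": [
                "GET /", "GET /models", "GET /health",
                "POST /inference", "POST /inference/{model_id}",
            ],
            "tip": "Use /docs for full documentation",
        },
    )

If the per-model detail message should survive, the payload can additionally include getattr(exc, "detail", None).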