from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from transformers import pipeline
from typing import Optional
import torch
import traceback
import os
# Initialize the FastAPI app
app = FastAPI(
    title="LyonPoy Model Inference API",
    description="API for accessing 11 machine learning models",
    version="1.0.0"
)
# CORS configuration for external frontends.
# NOTE: browsers reject the wildcard origin when credentials are allowed;
# list explicit origins here if the frontend needs cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Model configuration: short alias -> Hugging Face Hub repo ID
MODEL_MAP = {
    "tinny-llama": "Lyon28/Tinny-Llama",
    "pythia": "Lyon28/Pythia",
    "bert-tinny": "Lyon28/Bert-Tinny",
    "albert-base-v2": "Lyon28/Albert-Base-V2",
    "t5-small": "Lyon28/T5-Small",
    "gpt-2": "Lyon28/GPT-2",
    "gpt-neo": "Lyon28/GPT-Neo",
    "distilbert-base-uncased": "Lyon28/Distilbert-Base-Uncased",
    "distil-gpt-2": "Lyon28/Distil_GPT-2",
    "gpt-2-tinny": "Lyon28/GPT-2-Tinny",
    "electra-small": "Lyon28/Electra-Small"
}
TASK_MAP = {
    "text-generation": ["gpt-2", "gpt-neo", "distil-gpt-2", "gpt-2-tinny", "tinny-llama", "pythia"],
    "text-classification": ["bert-tinny", "albert-base-v2", "distilbert-base-uncased", "electra-small"],
    "text2text-generation": ["t5-small"]
}
class InferenceRequest(BaseModel):
    text: str
    model_id: Optional[str] = "gpt-2"  # Default model
    max_length: int = 100
    temperature: float = 0.9
    top_p: float = 0.95
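# A request body for POST /inference might look like this (illustrative values):
# {
#   "text": "Once upon a time",
#   "model_id": "gpt-2",
#   "max_length": 80,
#   "temperature": 0.8,
#   "top_p": 0.95
# }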
# Helper functions
def get_task(model_id: str) -> str:
    for task, models in TASK_MAP.items():
        if model_id in models:
            return task
    # Default to text-generation if not found (or raise an error)
    return "text-generation"
# Startup event: prepare the HF cache directory and the pipeline registry
@app.on_event("startup")
async def load_models():
    # Set HF_HOME to a writable directory to avoid cache permission issues
    os.environ['HF_HOME'] = '/tmp/.cache/huggingface'
    os.makedirs(os.environ['HF_HOME'], exist_ok=True)
    # Pipelines are loaded lazily on first request
    app.state.pipelines = {}
    print("🟢 API ready; models are loaded on demand!")
# Main endpoint
@app.get("/")
async def root():
    return {
        "message": "Welcome to the Lyon28 Model API",
        "endpoints": {
            "documentation": "/docs",
            "model_list": "/models",
            "health_check": "/health",
            "inference_with_model": "/inference/{model_id}",
            "inference_general": "/inference"
        },
        "total_models": len(MODEL_MAP),
        "usage_examples": {
            "specific_model": "POST /inference/gpt-2 with JSON body",
            "general_inference": "POST /inference with model_id in JSON body"
        }
    }
# Endpoint listing the available models
@app.get("/models")
async def list_models():
    return {
        "available_models": list(MODEL_MAP.keys()),
        "total_models": len(MODEL_MAP)
    }
# Health-check endpoint
@app.get("/health")
async def health_check():
    return {
        "status": "healthy",
        "gpu_available": torch.cuda.is_available(),
        "gpu_type": torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU-only"
    }
# General inference endpoint (handles POST /inference)
@app.post("/inference")
async def general_inference(request: InferenceRequest):
    """
    General inference endpoint that accepts model_id in the request body.
    """
    return await process_inference(request.model_id, request)
# Inference endpoint with model_id in the path
@app.post("/inference/{model_id}")
async def model_inference(model_id: str, request: InferenceRequest):
    """
    Specific model inference endpoint with model_id in the path.
    """
    return await process_inference(model_id, request)
# Shared inference processing function
async def process_inference(model_id: str, request: InferenceRequest):
    try:
        # Lowercase model_id so it matches the MODEL_MAP keys
        model_id = model_id.lower()

        # Validate the model ID
        if model_id not in MODEL_MAP:
            available_models = ", ".join(MODEL_MAP.keys())
            raise HTTPException(
                status_code=404,
                detail=f"Model '{model_id}' not found. Available models: {available_models}"
            )

        # Determine the matching task
        task = get_task(model_id)

        # Load the model if it is not in memory yet
        if model_id not in app.state.pipelines:
            print(f"⏳ Loading model {model_id} for task {task}...")
            # device=-1 (CPU) is the safe default; use device=0 when a GPU is available
            device_to_use = 0 if torch.cuda.is_available() else -1
            # Match the dtype to the device
            dtype_to_use = torch.float16 if torch.cuda.is_available() else torch.float32
            try:
                app.state.pipelines[model_id] = pipeline(
                    task=task,
                    model=MODEL_MAP[model_id],
                    device=device_to_use,
                    torch_dtype=dtype_to_use
                )
                print(f"✅ Model {model_id} loaded successfully!")
            except Exception as load_error:
                print(f"❌ Failed to load model {model_id}: {load_error}")
                raise HTTPException(
                    status_code=503,
                    detail=f"Failed to load model {model_id}. Try again later."
                )

        pipe = app.state.pipelines[model_id]

        # Dispatch on the task type
        if task == "text-generation":
            result = pipe(
                request.text,
                max_length=request.max_length,
                temperature=request.temperature,
                top_p=request.top_p,
                do_sample=True
            )[0]['generated_text']
        elif task == "text-classification":
            # text-classification returns a list of dicts; take the first entry
            output = pipe(request.text)[0]
            result = {
                "label": output['label'],
                "confidence": round(output['score'], 4)
            }
        elif task == "text2text-generation":
            # text2text-generation also returns a list of dicts
            result = pipe(
                request.text,
                max_length=request.max_length
            )[0]['generated_text']
        else:
            # Fallback for unexpected tasks, although get_task should prevent this
            raise HTTPException(
                status_code=500,
                detail=f"Task ({task}) for model {model_id} is unsupported or unrecognized."
            )

        return {
            "result": result,
            "model_used": model_id,
            "task": task,
            "status": "success"
        }
    except HTTPException:
        # Re-raise HTTP exceptions unchanged
        raise
    except Exception as e:
        # Log the full traceback for debugging
        print(f"‼️ Error while processing model {model_id}: {e}")
        traceback.print_exc()
        raise HTTPException(
            status_code=500,
            detail=f"Error processing request: {str(e)}. Check the server logs for details."
        )
# Error handler for 404s
@app.exception_handler(404)
async def not_found_handler(request, exc):
    # Exception handlers must return a Response, not a plain dict
    return JSONResponse(
        status_code=404,
        content={
            "error": "Endpoint not found",
            "available_endpoints": [
                "GET /",
                "GET /models",
                "GET /health",
                "POST /inference",
                "POST /inference/{model_id}"
            ],
            "tip": "Use /docs for the full documentation"
        }
    )
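# A minimal local entry point, assuming uvicorn is installed; Hugging Face
# Spaces conventionally serves on port 7860 (an assumption here).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example call (illustrative):
#   curl -X POST http://localhost:7860/inference \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Hello world", "model_id": "gpt-2"}'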