# main.py: Lyon28 model inference API.
# (Hugging Face file-page residue: author Lyon28, last commit "Update main.py", 621ada9, 4.57 kB.)
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline
import torch
from fastapi.middleware.cors import CORSMiddleware
from typing import Dict, Any
# Create the FastAPI application object used by every route below.
app = FastAPI(
    version="1.0.0",
    title="Lyon28 Model Inference API",
    description="API untuk mengakses 11 model machine learning",
)

# CORS: accept requests from any external frontend, with any method/header.
# NOTE(review): browsers refuse a literal "*" Access-Control-Allow-Origin on
# credentialed requests; confirm whether allow_credentials is really needed
# together with a wildcard origin list.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
# Model registry: public model id (used in /inference/{model_id}) -> repo id
# handed to transformers.pipeline(model=...). Insertion order is what /models lists.
MODEL_MAP: Dict[str, str] = {
    "tinny-llama": "Lyon28/Tinny-Llama",
    "pythia": "Lyon28/Pythia",
    "bert-tinny": "Lyon28/Bert-Tinny",
    "albert-base-v2": "Lyon28/Albert-Base-V2",
    "t5-small": "Lyon28/T5-Small",
    "gpt-2": "Lyon28/GPT-2",
    "gpt-neo": "Lyon28/GPT-Neo",
    "distilbert-base-uncased": "Lyon28/Distilbert-Base-Uncased",
    "distil-gpt-2": "Lyon28/Distil_GPT-2",
    "gpt-2-tinny": "Lyon28/GPT-2-Tinny",
    "electra-small": "Lyon28/Electra-Small",
}

# Task registry: transformers pipeline task name -> model ids served with it.
# get_task() returns the first task whose list contains a given id.
TASK_MAP: Dict[str, list] = {
    "text-generation": [
        "gpt-2",
        "gpt-neo",
        "distil-gpt-2",
        "gpt-2-tinny",
        "tinny-llama",
        "pythia",
    ],
    "text-classification": [
        "bert-tinny",
        "albert-base-v2",
        "distilbert-base-uncased",
        "electra-small",
    ],
    "text2text-generation": ["t5-small"],
}
# Request body for POST /inference/{model_id}.
# (Kept as `#` comments: a class docstring would change the generated OpenAPI schema.)
class InferenceRequest(BaseModel):
    text: str  # input text / prompt handed to the pipeline
    # Forwarded as max_length to the text-generation and text2text-generation
    # pipelines; unused for text-classification.
    # NOTE(review): for decoder-only generation, transformers' max_length is
    # believed to include the prompt tokens; confirm vs. max_new_tokens.
    max_length: int = 100
    temperature: float = 0.9  # forwarded only for text-generation
    top_p: float = 0.95  # forwarded only for text-generation
# Helper functions
def get_task(model_id: str) -> str:
    """Return the pipeline task name that *model_id* is registered under.

    Scans TASK_MAP in insertion order and returns the first task whose model
    list contains *model_id*. Ids listed under no task fall back to
    ``"text-generation"``.
    """
    matches = (
        task_name
        for task_name, members in TASK_MAP.items()
        if model_id in members
    )
    return next(matches, "text-generation")
# Startup hook: prepare the model cache.
# NOTE(review): FastAPI documents lifespan handlers as the replacement for
# on_event; consider migrating.
@app.on_event("startup")
async def load_models():
    """Create an empty per-process pipeline cache on ``app.state``.

    No model is loaded here. The inference endpoint fills
    ``app.state.pipelines`` lazily, the first time each model is requested.
    """
    app.state.pipelines = dict()
    # The log text claims all models are ready, but none are preloaded (see docstring).
    print("🟢 Semua model siap digunakan!")
# Root endpoint: a small index of the API.
@app.get("/")
async def root():
    """Return a welcome message, the main endpoint paths, and the model count."""
    endpoints = {
        "documentation": "/docs",
        "model_list": "/models",
        "health_check": "/health",
        "inference": "/inference/{model_id}",
    }
    payload = {"message": "Selamat datang di Lyon28 Model API"}
    payload["endpoints"] = endpoints
    payload["total_models"] = len(MODEL_MAP)
    return payload
# Model listing endpoint.
@app.get("/models")
async def list_models():
    """List every model id accepted by /inference/{model_id}, plus the count."""
    model_ids = [*MODEL_MAP]
    return {"available_models": model_ids, "total_models": len(model_ids)}
# Health check endpoint.
@app.get("/health")
async def health_check():
    """Report liveness, whether CUDA is available and, if so, device 0's name."""
    has_cuda = torch.cuda.is_available()
    if has_cuda:
        device_name = torch.cuda.get_device_name(0)
    else:
        device_name = "CPU-only"
    return {"status": "healthy", "gpu_available": has_cuda, "gpu_type": device_name}
# Main inference endpoint and its helpers.
def _get_pipeline(model_id: str, task: str):
    """Return the cached pipeline for *model_id*, building it on first use.

    Pipelines live in ``app.state.pipelines`` (created by the startup hook), so
    each model is loaded at most once per process. Half precision is used only
    when a CUDA device is available.
    """
    pipelines = app.state.pipelines
    if model_id not in pipelines:
        use_gpu = torch.cuda.is_available()
        pipelines[model_id] = pipeline(
            task=task,
            model=MODEL_MAP[model_id],
            device=0 if use_gpu else -1,
            torch_dtype=torch.float16 if use_gpu else torch.float32,
        )
        print(f"✅ Model {model_id} berhasil dimuat!")
    return pipelines[model_id]


def _run_task(pipe, task: str, request: InferenceRequest) -> Any:
    """Run *pipe* on ``request.text`` and shape the output for *task*.

    Returns the generated string for the generation tasks, or a
    ``{"label", "confidence"}`` dict for text-classification.

    Raises:
        ValueError: if *task* is not one of the tasks this endpoint supports.
    """
    if task == "text-generation":
        gen_kwargs = {"max_length": request.max_length}
        if request.temperature > 0:
            # temperature/top_p only take effect when sampling; without an
            # explicit do_sample=True transformers may decode greedily and
            # silently ignore them.
            gen_kwargs.update(
                do_sample=True,
                temperature=request.temperature,
                top_p=request.top_p,
            )
        else:
            # A non-positive temperature is invalid for sampling, so treat it
            # as a request for greedy decoding.
            gen_kwargs["do_sample"] = False
        return pipe(request.text, **gen_kwargs)[0]["generated_text"]
    if task == "text-classification":
        output = pipe(request.text)[0]
        return {
            "label": output["label"],
            "confidence": round(output["score"], 4),
        }
    if task == "text2text-generation":
        return pipe(request.text, max_length=request.max_length)[0]["generated_text"]
    # Previously an unknown task would surface as UnboundLocalError.
    raise ValueError(f"Unsupported task: {task}")


@app.post("/inference/{model_id}")
async def model_inference(model_id: str, request: InferenceRequest):
    """Run inference on ``request.text`` with the model registered as *model_id*.

    The model's pipeline is loaded lazily on first use and then cached.

    Returns:
        ``{"result": ...}``: a string for generation models, or
        ``{"label", "confidence"}`` for classification models.

    Raises:
        HTTPException: 404 if *model_id* is not in MODEL_MAP; 500 if loading
            the model or running inference fails.
    """
    # Validate before the try block. Otherwise the generic handler below
    # catches this HTTPException and re-reports the 404 as a 500.
    if model_id not in MODEL_MAP:
        raise HTTPException(
            status_code=404,
            detail=f"Model {model_id} tidak ditemukan. Cek /models untuk list model yang tersedia."
        )
    task = get_task(model_id)
    # NOTE(review): model loading and inference are synchronous calls made
    # directly in this coroutine, so they block the event loop (including
    # /health) while running. Consider offloading them to a worker thread.
    try:
        pipe = _get_pipeline(model_id, task)
        result = _run_task(pipe, task, request)
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error processing request: {str(e)}"
        ) from e
    return {"result": result}
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)