from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline
import torch
from fastapi.middleware.cors import CORSMiddleware

# Initialize the FastAPI application
app = FastAPI(
    title="Lyon28 Model Inference API",
    description="API for accessing 11 machine learning models",
    version="1.0.0"
)

# CORS configuration for external frontends
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Model configuration
MODEL_MAP = {
    "tinny-llama": "Lyon28/Tinny-Llama",
    "pythia": "Lyon28/Pythia",
    "bert-tinny": "Lyon28/Bert-Tinny",
    "albert-base-v2": "Lyon28/Albert-Base-V2",
    "t5-small": "Lyon28/T5-Small",
    "gpt-2": "Lyon28/GPT-2",
    "gpt-neo": "Lyon28/GPT-Neo",
    "distilbert-base-uncased": "Lyon28/Distilbert-Base-Uncased",
    "distil-gpt-2": "Lyon28/Distil_GPT-2",
    "gpt-2-tinny": "Lyon28/GPT-2-Tinny",
    "electra-small": "Lyon28/Electra-Small"
}

TASK_MAP = {
    "text-generation": ["gpt-2", "gpt-neo", "distil-gpt-2", "gpt-2-tinny", "tinny-llama", "pythia"],
    "text-classification": ["bert-tinny", "albert-base-v2", "distilbert-base-uncased", "electra-small"],
    "text2text-generation": ["t5-small"]
}

class InferenceRequest(BaseModel):
    text: str
    max_length: int = 100
    temperature: float = 0.9
    top_p: float = 0.95

# Helper functions
def get_task(model_id: str) -> str:
    """Look up the pipeline task for a model ID; default to text-generation."""
    for task, models in TASK_MAP.items():
        if model_id in models:
            return task
    return "text-generation"

# Startup event: initialize an empty pipeline cache; models are loaded
# lazily on first request, not here
@app.on_event("startup")
async def load_models():
    app.state.pipelines = {}
    print("🟢 Pipeline cache ready; models load on first request.")

# Root endpoint
@app.get("/")
async def root():
    return {
        "message": "Welcome to the Lyon28 Model API",
        "endpoints": {
            "documentation": "/docs",
            "model_list": "/models",
            "health_check": "/health",
            "inference": "/inference/{model_id}"
        },
        "total_models": len(MODEL_MAP)
    }

# Model-listing endpoint
@app.get("/models")
async def list_models():
    return {
        "available_models": list(MODEL_MAP.keys()),
        "total_models": len(MODEL_MAP)
    }

# Health-check endpoint
@app.get("/health")
async def health_check():
    return {
        "status": "healthy",
        "gpu_available": torch.cuda.is_available(),
        "gpu_type": torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU-only"
    }

# Main inference endpoint
@app.post("/inference/{model_id}")
async def model_inference(model_id: str, request: InferenceRequest):
    # Validate the model ID *before* the try block, so the 404 is not
    # caught by the generic handler below and re-raised as a 500
    if model_id not in MODEL_MAP:
        raise HTTPException(
            status_code=404,
            detail=f"Model {model_id} not found. See /models for the list of available models."
        )
    try:
        # Resolve the matching task
        task = get_task(model_id)

        # Lazy-load the model if it is not in memory yet
        if model_id not in app.state.pipelines:
            app.state.pipelines[model_id] = pipeline(
                task=task,
                model=MODEL_MAP[model_id],
                device=0 if torch.cuda.is_available() else -1,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
            )
            print(f"✅ Model {model_id} loaded successfully!")

        pipe = app.state.pipelines[model_id]

        # Dispatch on task type
        if task == "text-generation":
            result = pipe(
                request.text,
                max_length=request.max_length,
                do_sample=True,  # required: temperature/top_p are ignored under greedy decoding
                temperature=request.temperature,
                top_p=request.top_p
            )[0]['generated_text']
        elif task == "text-classification":
            output = pipe(request.text)[0]
            result = {
                "label": output['label'],
                "confidence": round(output['score'], 4)
            }
        else:  # text2text-generation
            result = pipe(
                request.text,
                max_length=request.max_length
            )[0]['generated_text']

        return {"result": result}
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Error processing request: {str(e)}"
        )

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
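
# --- Example client call (a minimal sketch, not part of the server) ---
# Assumes the server is running locally on port 7860 and the `requests`
# package is installed; model ID and prompt are illustrative.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:7860/inference/gpt-2",
#       json={
#           "text": "Once upon a time",
#           "max_length": 50,
#           "temperature": 0.8,
#           "top_p": 0.95,
#       },
#       timeout=120,  # first call may be slow: the model downloads and loads lazily
#   )
#   resp.raise_for_status()
#   print(resp.json()["result"])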