Lyon28's picture
Rename app.py to main.py
9815afa verified
raw
history blame
4.43 kB
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from typing import Dict, Any
# Inisialisasi API
app = FastAPI(
title="Lyon28 Multi-Model API",
description="API serbaguna untuk 11 model Lyon28"
)
# --- Daftar model dan tugasnya ---
# Kita buat kamus (dictionary) agar mudah dipanggil.
# Ini juga membantu kita tahu pipeline apa yang harus digunakan untuk setiap model.
MODEL_MAPPING = {
# Generative Models (Text Generation)
"Tinny-Llama": {"id": "Lyon28/Tinny-Llama", "task": "text-generation"},
"Pythia": {"id": "Lyon28/Pythia", "task": "text-generation"},
"GPT-2": {"id": "Lyon28/GPT-2", "task": "text-generation"},
"GPT-Neo": {"id": "Lyon28/GPT-Neo", "task": "text-generation"},
"Distil_GPT-2": {"id": "Lyon28/Distil_GPT-2", "task": "text-generation"},
"GPT-2-Tinny": {"id": "Lyon28/GPT-2-Tinny", "task": "text-generation"},
# Text-to-Text Model
"T5-Small": {"id": "Lyon28/T5-Small", "task": "text2text-generation"},
# Fill-Mask Models
"Bert-Tinny": {"id": "Lyon28/Bert-Tinny", "task": "fill-mask"},
"Albert-Base-V2": {"id": "Lyon28/Albert-Base-V2", "task": "fill-mask"},
"Distilbert-Base-Uncased": {"id": "Lyon28/Distilbert-Base-Uncased", "task": "fill-mask"},
"Electra-Small": {"id": "Lyon28/Electra-Small", "task": "fill-mask"},
}
# --- Cache untuk menyimpan model yang sudah dimuat ---
# Ini penting! Kita tidak mau memuat model yang sama berulang-ulang.
# Ini akan menghemat waktu dan memori.
PIPELINE_CACHE = {}
def get_pipeline(model_name: str):
"""Fungsi untuk memuat model dari cache atau dari Hub jika belum ada."""
if model_name in PIPELINE_CACHE:
print(f"Mengambil model '{model_name}' dari cache.")
return PIPELINE_CACHE[model_name]
if model_name not in MODEL_MAPPING:
raise HTTPException(status_code=404, detail=f"Model '{model_name}' tidak ditemukan.")
model_info = MODEL_MAPPING[model_name]
model_id = model_info["id"]
task = model_info["task"]
print(f"Memuat model '{model_name}' ({model_id}) untuk tugas '{task}'...")
try:
# device_map="auto" menggunakan accelerate untuk menempatkan model secara efisien
pipe = pipeline(task, model=model_id, device_map="auto")
PIPELINE_CACHE[model_name] = pipe
print(f"Model '{model_name}' berhasil dimuat dan disimpan di cache.")
return pipe
except Exception as e:
raise HTTPException(status_code=500, detail=f"Gagal memuat model '{model_name}': {str(e)}")
# --- Definisikan struktur request dari user ---
class InferenceRequest(BaseModel):
model_name: str # Nama kunci dari MODEL_MAPPING, misal: "Tinny-Llama"
prompt: str
parameters: Dict[str, Any] = {} # Parameter tambahan seperti max_length, temperature, dll.
@app.get("/")
def read_root():
"""Endpoint untuk mengecek status API dan daftar model yang tersedia."""
return {
"status": "API is running!",
"available_models": list(MODEL_MAPPING.keys())
}
@app.post("/invoke")
def invoke_model(request: InferenceRequest):
"""Endpoint utama untuk melakukan inferensi pada model yang dipilih."""
try:
# Ambil atau muat pipeline model
pipe = get_pipeline(request.model_name)
# Gabungkan prompt dengan parameter tambahan
# Ini membuat API kita sangat fleksibel!
result = pipe(request.prompt, **request.parameters)
return {
"model_used": request.model_name,
"prompt": request.prompt,
"parameters": request.parameters,
"result": result
}
except HTTPException as e:
# Meneruskan error yang sudah kita definisikan
raise e
except Exception as e:
# Menangkap error lain yang mungkin terjadi saat inferensi
raise HTTPException(status_code=500, detail=f"Terjadi error saat inferensi: {str(e)}")
# Saat aplikasi pertama kali dijalankan, kita bisa coba muat satu model populer
# untuk menghangatkan sistem (warm-up). Ini opsional.
@app.on_event("startup")
async def startup_event():
print("API startup: Melakukan warm-up dengan memuat satu model awal...")
try:
get_pipeline("GPT-2-Tinny") # Pilih model yang kecil dan cepat
except Exception as e:
print(f"Gagal melakukan warm-up: {e}")