Spaces:
Running
Running
import torch | |
from fastapi import FastAPI, HTTPException | |
from pydantic import BaseModel | |
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM | |
from typing import Dict, Any | |
# Inisialisasi API | |
app = FastAPI( | |
title="Lyon28 Multi-Model API", | |
description="API serbaguna untuk 11 model Lyon28" | |
) | |
# --- Daftar model dan tugasnya --- | |
# Kita buat kamus (dictionary) agar mudah dipanggil. | |
# Ini juga membantu kita tahu pipeline apa yang harus digunakan untuk setiap model. | |
MODEL_MAPPING = { | |
# Generative Models (Text Generation) | |
"Tinny-Llama": {"id": "Lyon28/Tinny-Llama", "task": "text-generation"}, | |
"Pythia": {"id": "Lyon28/Pythia", "task": "text-generation"}, | |
"GPT-2": {"id": "Lyon28/GPT-2", "task": "text-generation"}, | |
"GPT-Neo": {"id": "Lyon28/GPT-Neo", "task": "text-generation"}, | |
"Distil_GPT-2": {"id": "Lyon28/Distil_GPT-2", "task": "text-generation"}, | |
"GPT-2-Tinny": {"id": "Lyon28/GPT-2-Tinny", "task": "text-generation"}, | |
# Text-to-Text Model | |
"T5-Small": {"id": "Lyon28/T5-Small", "task": "text2text-generation"}, | |
# Fill-Mask Models | |
"Bert-Tinny": {"id": "Lyon28/Bert-Tinny", "task": "fill-mask"}, | |
"Albert-Base-V2": {"id": "Lyon28/Albert-Base-V2", "task": "fill-mask"}, | |
"Distilbert-Base-Uncased": {"id": "Lyon28/Distilbert-Base-Uncased", "task": "fill-mask"}, | |
"Electra-Small": {"id": "Lyon28/Electra-Small", "task": "fill-mask"}, | |
} | |
# --- Cache untuk menyimpan model yang sudah dimuat --- | |
# Ini penting! Kita tidak mau memuat model yang sama berulang-ulang. | |
# Ini akan menghemat waktu dan memori. | |
PIPELINE_CACHE = {} | |
def get_pipeline(model_name: str): | |
"""Fungsi untuk memuat model dari cache atau dari Hub jika belum ada.""" | |
if model_name in PIPELINE_CACHE: | |
print(f"Mengambil model '{model_name}' dari cache.") | |
return PIPELINE_CACHE[model_name] | |
if model_name not in MODEL_MAPPING: | |
raise HTTPException(status_code=404, detail=f"Model '{model_name}' tidak ditemukan.") | |
model_info = MODEL_MAPPING[model_name] | |
model_id = model_info["id"] | |
task = model_info["task"] | |
print(f"Memuat model '{model_name}' ({model_id}) untuk tugas '{task}'...") | |
try: | |
# device_map="auto" menggunakan accelerate untuk menempatkan model secara efisien | |
pipe = pipeline(task, model=model_id, device_map="auto") | |
PIPELINE_CACHE[model_name] = pipe | |
print(f"Model '{model_name}' berhasil dimuat dan disimpan di cache.") | |
return pipe | |
except Exception as e: | |
raise HTTPException(status_code=500, detail=f"Gagal memuat model '{model_name}': {str(e)}") | |
# --- Definisikan struktur request dari user --- | |
class InferenceRequest(BaseModel): | |
model_name: str # Nama kunci dari MODEL_MAPPING, misal: "Tinny-Llama" | |
prompt: str | |
parameters: Dict[str, Any] = {} # Parameter tambahan seperti max_length, temperature, dll. | |
def read_root(): | |
"""Endpoint untuk mengecek status API dan daftar model yang tersedia.""" | |
return { | |
"status": "API is running!", | |
"available_models": list(MODEL_MAPPING.keys()) | |
} | |
def invoke_model(request: InferenceRequest): | |
"""Endpoint utama untuk melakukan inferensi pada model yang dipilih.""" | |
try: | |
# Ambil atau muat pipeline model | |
pipe = get_pipeline(request.model_name) | |
# Gabungkan prompt dengan parameter tambahan | |
# Ini membuat API kita sangat fleksibel! | |
result = pipe(request.prompt, **request.parameters) | |
return { | |
"model_used": request.model_name, | |
"prompt": request.prompt, | |
"parameters": request.parameters, | |
"result": result | |
} | |
except HTTPException as e: | |
# Meneruskan error yang sudah kita definisikan | |
raise e | |
except Exception as e: | |
# Menangkap error lain yang mungkin terjadi saat inferensi | |
raise HTTPException(status_code=500, detail=f"Terjadi error saat inferensi: {str(e)}") | |
# Saat aplikasi pertama kali dijalankan, kita bisa coba muat satu model populer | |
# untuk menghangatkan sistem (warm-up). Ini opsional. | |
async def startup_event(): | |
print("API startup: Melakukan warm-up dengan memuat satu model awal...") | |
try: | |
get_pipeline("GPT-2-Tinny") # Pilih model yang kecil dan cepat | |
except Exception as e: | |
print(f"Gagal melakukan warm-up: {e}") |