Spaces:
Sleeping
Sleeping
File size: 4,432 Bytes
9815afa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 |
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from typing import Dict, Any
# Inisialisasi API
app = FastAPI(
title="Lyon28 Multi-Model API",
description="API serbaguna untuk 11 model Lyon28"
)
# --- Daftar model dan tugasnya ---
# Kita buat kamus (dictionary) agar mudah dipanggil.
# Ini juga membantu kita tahu pipeline apa yang harus digunakan untuk setiap model.
MODEL_MAPPING = {
# Generative Models (Text Generation)
"Tinny-Llama": {"id": "Lyon28/Tinny-Llama", "task": "text-generation"},
"Pythia": {"id": "Lyon28/Pythia", "task": "text-generation"},
"GPT-2": {"id": "Lyon28/GPT-2", "task": "text-generation"},
"GPT-Neo": {"id": "Lyon28/GPT-Neo", "task": "text-generation"},
"Distil_GPT-2": {"id": "Lyon28/Distil_GPT-2", "task": "text-generation"},
"GPT-2-Tinny": {"id": "Lyon28/GPT-2-Tinny", "task": "text-generation"},
# Text-to-Text Model
"T5-Small": {"id": "Lyon28/T5-Small", "task": "text2text-generation"},
# Fill-Mask Models
"Bert-Tinny": {"id": "Lyon28/Bert-Tinny", "task": "fill-mask"},
"Albert-Base-V2": {"id": "Lyon28/Albert-Base-V2", "task": "fill-mask"},
"Distilbert-Base-Uncased": {"id": "Lyon28/Distilbert-Base-Uncased", "task": "fill-mask"},
"Electra-Small": {"id": "Lyon28/Electra-Small", "task": "fill-mask"},
}
# --- Cache untuk menyimpan model yang sudah dimuat ---
# Ini penting! Kita tidak mau memuat model yang sama berulang-ulang.
# Ini akan menghemat waktu dan memori.
PIPELINE_CACHE = {}
def get_pipeline(model_name: str):
"""Fungsi untuk memuat model dari cache atau dari Hub jika belum ada."""
if model_name in PIPELINE_CACHE:
print(f"Mengambil model '{model_name}' dari cache.")
return PIPELINE_CACHE[model_name]
if model_name not in MODEL_MAPPING:
raise HTTPException(status_code=404, detail=f"Model '{model_name}' tidak ditemukan.")
model_info = MODEL_MAPPING[model_name]
model_id = model_info["id"]
task = model_info["task"]
print(f"Memuat model '{model_name}' ({model_id}) untuk tugas '{task}'...")
try:
# device_map="auto" menggunakan accelerate untuk menempatkan model secara efisien
pipe = pipeline(task, model=model_id, device_map="auto")
PIPELINE_CACHE[model_name] = pipe
print(f"Model '{model_name}' berhasil dimuat dan disimpan di cache.")
return pipe
except Exception as e:
raise HTTPException(status_code=500, detail=f"Gagal memuat model '{model_name}': {str(e)}")
# --- Definisikan struktur request dari user ---
class InferenceRequest(BaseModel):
model_name: str # Nama kunci dari MODEL_MAPPING, misal: "Tinny-Llama"
prompt: str
parameters: Dict[str, Any] = {} # Parameter tambahan seperti max_length, temperature, dll.
@app.get("/")
def read_root():
"""Endpoint untuk mengecek status API dan daftar model yang tersedia."""
return {
"status": "API is running!",
"available_models": list(MODEL_MAPPING.keys())
}
@app.post("/invoke")
def invoke_model(request: InferenceRequest):
"""Endpoint utama untuk melakukan inferensi pada model yang dipilih."""
try:
# Ambil atau muat pipeline model
pipe = get_pipeline(request.model_name)
# Gabungkan prompt dengan parameter tambahan
# Ini membuat API kita sangat fleksibel!
result = pipe(request.prompt, **request.parameters)
return {
"model_used": request.model_name,
"prompt": request.prompt,
"parameters": request.parameters,
"result": result
}
except HTTPException as e:
# Meneruskan error yang sudah kita definisikan
raise e
except Exception as e:
# Menangkap error lain yang mungkin terjadi saat inferensi
raise HTTPException(status_code=500, detail=f"Terjadi error saat inferensi: {str(e)}")
# Saat aplikasi pertama kali dijalankan, kita bisa coba muat satu model populer
# untuk menghangatkan sistem (warm-up). Ini opsional.
@app.on_event("startup")
async def startup_event():
print("API startup: Melakukan warm-up dengan memuat satu model awal...")
try:
get_pipeline("GPT-2-Tinny") # Pilih model yang kecil dan cepat
except Exception as e:
print(f"Gagal melakukan warm-up: {e}") |