File size: 4,432 Bytes
9815afa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from typing import Dict, Any

# Inisialisasi API
app = FastAPI(
    title="Lyon28 Multi-Model API",
    description="API serbaguna untuk 11 model Lyon28"
)

# --- Daftar model dan tugasnya ---
# Kita buat kamus (dictionary) agar mudah dipanggil.
# Ini juga membantu kita tahu pipeline apa yang harus digunakan untuk setiap model.
MODEL_MAPPING = {
    # Generative Models (Text Generation)
    "Tinny-Llama": {"id": "Lyon28/Tinny-Llama", "task": "text-generation"},
    "Pythia": {"id": "Lyon28/Pythia", "task": "text-generation"},
    "GPT-2": {"id": "Lyon28/GPT-2", "task": "text-generation"},
    "GPT-Neo": {"id": "Lyon28/GPT-Neo", "task": "text-generation"},
    "Distil_GPT-2": {"id": "Lyon28/Distil_GPT-2", "task": "text-generation"},
    "GPT-2-Tinny": {"id": "Lyon28/GPT-2-Tinny", "task": "text-generation"},
    
    # Text-to-Text Model
    "T5-Small": {"id": "Lyon28/T5-Small", "task": "text2text-generation"},
    
    # Fill-Mask Models
    "Bert-Tinny": {"id": "Lyon28/Bert-Tinny", "task": "fill-mask"},
    "Albert-Base-V2": {"id": "Lyon28/Albert-Base-V2", "task": "fill-mask"},
    "Distilbert-Base-Uncased": {"id": "Lyon28/Distilbert-Base-Uncased", "task": "fill-mask"},
    "Electra-Small": {"id": "Lyon28/Electra-Small", "task": "fill-mask"},
}

# --- Cache untuk menyimpan model yang sudah dimuat ---
# Ini penting! Kita tidak mau memuat model yang sama berulang-ulang.
# Ini akan menghemat waktu dan memori.
PIPELINE_CACHE = {}

def get_pipeline(model_name: str):
    """Fungsi untuk memuat model dari cache atau dari Hub jika belum ada."""
    if model_name in PIPELINE_CACHE:
        print(f"Mengambil model '{model_name}' dari cache.")
        return PIPELINE_CACHE[model_name]
    
    if model_name not in MODEL_MAPPING:
        raise HTTPException(status_code=404, detail=f"Model '{model_name}' tidak ditemukan.")
    
    model_info = MODEL_MAPPING[model_name]
    model_id = model_info["id"]
    task = model_info["task"]
    
    print(f"Memuat model '{model_name}' ({model_id}) untuk tugas '{task}'...")
    try:
        # device_map="auto" menggunakan accelerate untuk menempatkan model secara efisien
        pipe = pipeline(task, model=model_id, device_map="auto")
        PIPELINE_CACHE[model_name] = pipe
        print(f"Model '{model_name}' berhasil dimuat dan disimpan di cache.")
        return pipe
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Gagal memuat model '{model_name}': {str(e)}")


# --- Definisikan struktur request dari user ---
class InferenceRequest(BaseModel):
    model_name: str  # Nama kunci dari MODEL_MAPPING, misal: "Tinny-Llama"
    prompt: str
    parameters: Dict[str, Any] = {} # Parameter tambahan seperti max_length, temperature, dll.

@app.get("/")
def read_root():
    """Endpoint untuk mengecek status API dan daftar model yang tersedia."""
    return {
        "status": "API is running!",
        "available_models": list(MODEL_MAPPING.keys())
    }

@app.post("/invoke")
def invoke_model(request: InferenceRequest):
    """Endpoint utama untuk melakukan inferensi pada model yang dipilih."""
    try:
        # Ambil atau muat pipeline model
        pipe = get_pipeline(request.model_name)
        
        # Gabungkan prompt dengan parameter tambahan
        # Ini membuat API kita sangat fleksibel!
        result = pipe(request.prompt, **request.parameters)
        
        return {
            "model_used": request.model_name,
            "prompt": request.prompt,
            "parameters": request.parameters,
            "result": result
        }
    except HTTPException as e:
        # Meneruskan error yang sudah kita definisikan
        raise e
    except Exception as e:
        # Menangkap error lain yang mungkin terjadi saat inferensi
        raise HTTPException(status_code=500, detail=f"Terjadi error saat inferensi: {str(e)}")

# Saat aplikasi pertama kali dijalankan, kita bisa coba muat satu model populer
# untuk menghangatkan sistem (warm-up). Ini opsional.
@app.on_event("startup")
async def startup_event():
    print("API startup: Melakukan warm-up dengan memuat satu model awal...")
    try:
        get_pipeline("GPT-2-Tinny") # Pilih model yang kecil dan cepat
    except Exception as e:
        print(f"Gagal melakukan warm-up: {e}")