from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline

# 1. Define the application's "brain" (exactly the same as in the Gradio version)
MODEL_CONFIG = {
    # (Copy the complete MODEL_CONFIG from above here)
    "Lyon28/GPT-2-Tinny": {"task": "text-generation", "display_name": "GPT-2 (Tiny)"},
    "Lyon28/GPT-2": {"task": "text-generation", "display_name": "GPT-2"},
    "Lyon28/Distil_GPT-2": {"task": "text-generation", "display_name": "DistilGPT-2"},
    "Lyon28/GPT-Neo": {"task": "text-generation", "display_name": "GPT-Neo"},
    "Lyon28/Pythia": {"task": "text-generation", "display_name": "Pythia"},
    "Lyon28/Tinny-Llama": {"task": "text-generation", "display_name": "Tinny-Llama"},
    "Lyon28/Bert-Tinny": {"task": "fill-mask", "display_name": "BERT (Tiny)"},
    "Lyon28/Distilbert-Base-Uncased": {"task": "fill-mask", "display_name": "DistilBERT"},
    "Lyon28/Albert-Base-V2": {"task": "fill-mask", "display_name": "Albert v2"},
    "Lyon28/Electra-Small": {"task": "fill-mask", "display_name": "Electra (Small)"},
    "Lyon28/T5-Small": {"task": "text2text-generation", "display_name": "T5 (Small)"},
}
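
# Note: the prompt format depends on the task. The fill-mask models above
# expect a mask token in the prompt (e.g. "[MASK]" for the BERT-style models),
# while the text-generation models accept free-form text.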

# 2. Create the model "warehouse" (exactly the same)
loaded_pipelines = {}

# 3. Define the accepted request format
class InferenceRequest(BaseModel):
    model_id: str
    prompt: str
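
# Example request body for the /inference endpoint (illustrative values):
# {"model_id": "Lyon28/GPT-2", "prompt": "Once upon a time"}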

app = FastAPI()

@app.get("/")
def read_root():
    return {"message": "Smart Inference API is running. Use the /inference endpoint."}

@app.post("/inference")
def smart_inference(request: InferenceRequest):
    model_id = request.model_id

    # Validation: check whether model_id exists in our config
    if model_id not in MODEL_CONFIG:
        raise HTTPException(status_code=400, detail=f"Model '{model_id}' is invalid or not supported.")

    task = MODEL_CONFIG[model_id]["task"]

    # Check the "warehouse" (same caching logic as before)
    if model_id not in loaded_pipelines:
        print(f"Loading model: {model_id} for task: {task}...")
        try:
            pipe = pipeline(task, model=model_id, device=-1)  # device=-1 forces CPU inference
            loaded_pipelines[model_id] = pipe
            print("Model loaded successfully.")
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Failed to load model: {str(e)}")

    pipe = loaded_pipelines[model_id]

    # Run inference
    try:
        result = pipe(request.prompt)
        return {"model_used": model_id, "task": task, "input_prompt": request.prompt, "output": result}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Inference failed: {str(e)}")
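
To try the endpoint, here is a minimal client sketch. It assumes the code above is saved as app.py, the server is started with "uvicorn app:app --host 0.0.0.0 --port 8000", and the requests package is installed; the model_id and prompt values are illustrative.

# client.py -- minimal sketch for calling the /inference endpoint locally
import requests

response = requests.post(
    "http://127.0.0.1:8000/inference",
    json={"model_id": "Lyon28/GPT-2", "prompt": "Once upon a time"},
)
response.raise_for_status()
print(response.json())  # {"model_used": ..., "task": ..., "input_prompt": ..., "output": [...]}

The first request for a given model will be slow because the pipeline is downloaded and loaded on demand; subsequent requests hit the loaded_pipelines cache and return much faster.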