File size: 9,413 Bytes
d2346a3
 
90260b6
 
 
 
d2346a3
90260b6
 
d2346a3
90260b6
d2346a3
 
90260b6
 
d2346a3
 
 
90260b6
 
 
 
 
 
 
 
 
d2346a3
90260b6
 
 
 
d2346a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90260b6
d2346a3
 
 
 
90260b6
d2346a3
 
 
 
90260b6
d2346a3
90260b6
d2346a3
 
 
90260b6
d2346a3
90260b6
 
 
d2346a3
90260b6
d2346a3
90260b6
d2346a3
 
 
90260b6
 
 
d2346a3
 
 
 
 
 
90260b6
d2346a3
90260b6
d2346a3
 
 
90260b6
d2346a3
 
 
90260b6
 
d2346a3
 
 
 
 
 
 
 
 
 
 
90260b6
d2346a3
 
 
90260b6
 
d2346a3
 
90260b6
d2346a3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90260b6
d2346a3
 
 
 
 
 
 
 
 
 
 
 
 
 
90260b6
d2346a3
 
 
90260b6
 
 
d2346a3
 
90260b6
d2346a3
90260b6
d2346a3
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
# app.py
from flask import Flask, jsonify, request
from flask_cors import CORS
from transformers import pipeline
import logging
import torch
import os # Untuk mendapatkan environment variables, misalnya di Hugging Face Spaces

app = Flask(__name__)
CORS(app) # Enable CORS so the browser frontend can call this API cross-origin

# --- Logging setup ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# --- Model configuration and metadata ---
# Add 'hf_model_name' when the model's name on Hugging Face differs from the
# local ID used here; when it matches, 'hf_model_name' is not needed
# (get_model_pipeline falls back to the 'Lyon28/<id>' namespace).
model_info = {
    "Albert-Base-V2": {"task": "fill-mask", "description": "BERT-based model for masked language modeling"},
    "GPT-2": {"task": "text-generation", "description": "GPT-2 model for text generation"},
    "Tinny-Llama": {"task": "text-generation", "description": "Lightweight LLaMA model"},
    "Electra-Small": {"task": "fill-mask", "description": "Small ELECTRA model"},
    "GPT-2-Tinny": {"task": "text-generation", "description": "Tiny GPT-2 variant"},
    "Bert-Tinny": {"task": "fill-mask", "description": "Tiny BERT model"},
    "Distilbert-Base-Uncased": {"task": "fill-mask", "description": "Distilled BERT model"},
    "Pythia": {"task": "text-generation", "description": "Pythia language model"},
    "T5-Small": {"task": "text2text-generation", "description": "Small T5 model", "hf_model_name": "t5-small"},
    "GPT-Neo": {"task": "text-generation", "description": "GPT-Neo model"},
    "Distil-GPT-2": {"task": "text-generation", "description": "Distilled GPT-2 model"}
}

# --- Global model cache (for lazy loading): model_id -> loaded pipeline ---
models = {}

# --- Utility: lazily load and cache model pipelines ---
def get_model_pipeline(model_name):
    """
    Return the inference pipeline for *model_name*, loading it on first use.

    The pipeline is created once and cached in the module-level ``models``
    dict; later calls return the cached instance.

    Raises:
        ValueError: if *model_name* is not configured in ``model_info``.
        RuntimeError: if the Hugging Face pipeline fails to load.
    """
    # Fast path: pipeline already loaded and cached.
    if model_name in models:
        return models[model_name]

    logger.info(f"Model '{model_name}' belum dimuat. Memuat sekarang...")
    if model_name not in model_info:
        logger.error(f"Informasi model '{model_name}' tidak ditemukan di model_info.")
        raise ValueError(f"Model '{model_name}' tidak dikenal.")

    config = model_info[model_name]
    # Prefer an explicit 'hf_model_name' from the config; otherwise fall back
    # to the 'Lyon28/' namespace on Hugging Face.
    hf_model_path = config.get("hf_model_name", f"Lyon28/{model_name}")
    try:
        loaded = pipeline(
            config["task"],
            model=hf_model_path,
            device="cpu",  # CPU-only environment: keep the device pinned to "cpu"
            torch_dtype=torch.float32  # float32 is the best-performing dtype on CPU
        )
        logger.info(f"βœ… Model '{model_name}' (Path: {hf_model_path}) berhasil dimuat.")
    except Exception as e:
        logger.error(f"❌ Gagal memuat model '{model_name}' (Path: {hf_model_path}): {str(e)}", exc_info=True)
        raise RuntimeError(f"Gagal memuat model: {model_name}. Detail: {str(e)}") from e
    models[model_name] = loaded
    return loaded

# --- API routes ---

@app.route('/')
def home():
    """Root endpoint reporting API status and basic usage pointers."""
    status_payload = {
        "message": "Flask API untuk Model Hugging Face",
        "status": "online",
        # Count of pipelines instantiated so far via lazy loading.
        "loaded_models_count": len(models),
        "available_model_configs": list(model_info.keys()),
        "info": "Gunakan /api/models untuk daftar model yang tersedia."
    }
    return jsonify(status_payload)

@app.route('/api/models', methods=['GET'])
def list_available_models():
    """Return every configured model, including whether it is loaded yet."""
    catalog = []
    for model_id, cfg in model_info.items():
        catalog.append({
            "id": model_id,
            "name": cfg["description"],
            "task": cfg["task"],
            # "loaded" once lazy loading has instantiated this pipeline.
            "status": "loaded" if model_id in models else "not_loaded",
            "endpoint": f"/api/{model_id}"
        })
    return jsonify({
        "total_configured_models": len(model_info),
        "currently_loaded_models": len(models),
        "models": catalog
    })

@app.route('/api/<model_id>', methods=['POST'])
def predict_with_model(model_id):
    """
    Main prediction endpoint.

    Expects a JSON body with 'inputs' (text) and an optional 'parameters'
    dictionary. Dispatches to the model's pipeline according to its task
    ('text-generation', 'fill-mask', 'text2text-generation', or fallback)
    and normalizes the output format per task.

    Returns:
        200 with {"model", "inputs", "outputs"} on success;
        400 on bad input/config, 404 for unknown model, 503 if the model
        fails to load, 500 on unexpected errors.
    """
    logger.info(f"Menerima permintaan untuk model: {model_id}")
    if model_id not in model_info:
        logger.warning(f"Permintaan untuk model tidak dikenal: {model_id}")
        return jsonify({"error": f"Model '{model_id}' tidak dikenal. Lihat /api/models untuk daftar yang tersedia."}), 404

    try:
        model_pipeline = get_model_pipeline(model_id)  # Lazily loads the model on first use
        model_task = model_info[model_id]["task"]

        # FIX: request.json raises/returns None for a missing or non-JSON body,
        # which made data.get(...) crash with an AttributeError (HTTP 500).
        # get_json(silent=True) yields None instead, which we default to {} so
        # the empty-input check below produces a clean 400.
        data = request.get_json(silent=True) or {}
        inputs = data.get('inputs', '')
        parameters = data.get('parameters', {})  # Default to empty dict when absent

        if not inputs:
            return jsonify({"error": "Input 'inputs' tidak boleh kosong."}), 400

        logger.info(f"Inferensi: Model='{model_id}', Task='{model_task}', Input='{inputs[:100]}...', Params='{parameters}'")

        result = []
        # --- Parameter handling and inference per task type ---
        if model_task == "text-generation":
            # Defaults tuned for chat/roleplay-style generation.
            gen_params = {
                "max_new_tokens": parameters.get("max_new_tokens", 150),  # Longer outputs for roleplay
                "temperature": parameters.get("temperature", 0.7),
                "do_sample": parameters.get("do_sample", True),
                "return_full_text": parameters.get("return_full_text", False),  # Essential for chatbots: drop the prompt
                "num_return_sequences": parameters.get("num_return_sequences", 1),
                "top_k": parameters.get("top_k", 50),
                "top_p": parameters.get("top_p", 0.95),
                "repetition_penalty": parameters.get("repetition_penalty", 1.2),  # Discourage repetition
            }
            result = model_pipeline(inputs, **gen_params)

        elif model_task == "fill-mask":
            mask_params = {
                "top_k": parameters.get("top_k", 5)
            }
            result = model_pipeline(inputs, **mask_params)

        elif model_task == "text2text-generation":  # e.g. T5
            t2t_params = {
                "max_new_tokens": parameters.get("max_new_tokens", 150),
                "temperature": parameters.get("temperature", 0.7),
                "do_sample": parameters.get("do_sample", True),
            }
            result = model_pipeline(inputs, **t2t_params)

        else:
            # Fallback for other tasks: pass caller parameters through as-is.
            result = model_pipeline(inputs, **parameters)

        # --- Normalize output format per task ---
        response_output = {}
        if model_task == "text-generation" or model_task == "text2text-generation":
            if result and len(result) > 0 and 'generated_text' in result[0]:
                response_output['text'] = result[0]['generated_text'].strip()
            else:
                response_output['text'] = "[Tidak ada teks yang dihasilkan atau format tidak sesuai.]"
        elif model_task == "fill-mask":
            response_output['predictions'] = [
                {"sequence": p.get('sequence', ''), "score": p.get('score', 0.0), "token_str": p.get('token_str', '')}
                for p in result
            ]
        else:
            # Other task types: return the raw pipeline result.
            response_output = result

        logger.info(f"Inferensi berhasil untuk '{model_id}'. Output singkat: '{str(response_output)[:200]}'")
        return jsonify({"model": model_id, "inputs": inputs, "outputs": response_output})

    except ValueError as ve:
        # Raised by get_model_pipeline or input validation.
        logger.error(f"Validasi atau konfigurasi error untuk model '{model_id}': {str(ve)}")
        return jsonify({"error": str(ve), "message": "Kesalahan konfigurasi atau input model."}), 400
    except RuntimeError as rte:
        # Model failed to load. (Renamed from 're' to avoid shadowing the stdlib module name.)
        logger.error(f"Error runtime saat memuat model '{model_id}': {str(rte)}")
        return jsonify({"error": str(rte), "message": "Model gagal dimuat."}), 503  # Service Unavailable
    except Exception as e:
        # Catch-all for unexpected errors during prediction.
        logger.error(f"Terjadi kesalahan tak terduga saat memprediksi dengan model '{model_id}': {str(e)}", exc_info=True)
        return jsonify({"error": str(e), "message": "Terjadi kesalahan internal server."}), 500

@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe: reports status and the number of loaded models."""
    probe = {"status": "healthy", "loaded_models_count": len(models), "message": "API berfungsi normal."}
    return jsonify(probe)

# --- Application entry point ---
if __name__ == '__main__':
    # Hugging Face Spaces conventionally serves on port 7860.
    bind_host = os.getenv('HOST', '0.0.0.0')  # Default binds all interfaces
    bind_port = int(os.getenv('PORT', 7860))
    # debug=False for production use.
    app.run(host=bind_host, port=bind_port, debug=False)