Lyon28 committed
Commit ca2fcb8 · verified · 1 Parent(s): e922ac5

Update app.py

Files changed (1)
  1. app.py +47 −23
app.py CHANGED
@@ -27,7 +27,33 @@ model_info = {
     "Pythia": {"task": "text-generation", "description": "Pythia language model"},
     "T5-Small": {"task": "text2text-generation", "description": "Small T5 model", "hf_model_name": "t5-small"},
     "GPT-Neo": {"task": "text-generation", "description": "GPT-Neo model"},
-    "Distil-GPT-2": {"task": "text-generation", "description": "Distilled GPT-2 model"}
+    "Distil-GPT-2": {"task": "text-generation", "description": "Distilled GPT-2 model"},
+    # --- EXTERNAL MODELS ---
+    "Gemma-2B-IT": {  # The ID you want to expose in your API
+        "task": "text-generation",
+        "description": "Google's Gemma 2B Instruct model",
+        "hf_model_name": "google/gemma-2b-it"
+    },
+    "Mistral-7B-Instruct": {
+        "task": "text-generation",
+        "description": "Mistral AI's Mistral 7B Instruct model",
+        "hf_model_name": "mistralai/Mistral-7B-Instruct-v0.3",
+    },
+    "Qwen3-4B-RPG": {
+        "task": "text-generation",
+        "description": "Chun121's Qwen 4B RPG Roleplay model (Uncensored)",
+        "hf_model_name": "Chun121/qwen3-4B-rpg-roleplay"
+    },
+    "Llama-3.2-Uncensored-3B": {
+        "task": "text-generation",
+        "description": "Dhirajlochib's Llama 3.2 Uncensored 3B",
+        "hf_model_name": "dhirajlochib/llama-3.2-unsensored-3b"
+    },
+    "TinyLLama-NSFW-Chatbot": {
+        "task": "text-generation",
+        "description": "BilalRahib's TinyLLama NSFW Chatbot",
+        "hf_model_name": "bilalRahib/TinyLLama-NSFW-Chatbot"
+    }
 }
 
 # --- Global model store (for lazy loading) ---
@@ -99,7 +125,7 @@ def list_available_models():
 def predict_with_model(model_id):
     """
     Main endpoint for model prediction.
-    Accepts 'inputs' (text) and an optional 'parameters' dictionary.
+    Accepts 'inputs' (pre-formatted text) and an optional 'parameters' dictionary.
     """
     logger.info(f"Received request for model: {model_id}")
     if model_id not in model_info:
@@ -111,49 +137,51 @@ def predict_with_model(model_id):
         model_task = model_info[model_id]["task"]
 
         data = request.json
-        inputs = data.get('inputs', '')
-        parameters = data.get('parameters', {})  # Default to an empty dictionary if absent
+        # The input is now expected to be the fullPromptString from the frontend
+        full_prompt_string_from_frontend = data.get('inputs', '')
+        parameters = data.get('parameters', {})
 
-        if not inputs:
-            return jsonify({"error": "Input 'inputs' must not be empty."}), 400
+        if not full_prompt_string_from_frontend:
+            return jsonify({"error": "Input 'inputs' (full prompt string) must not be empty."}), 400
 
-        logger.info(f"Inference: Model='{model_id}', Task='{model_task}', Input='{inputs[:100]}...', Params='{parameters}'")
+        logger.info(f"Inference: Model='{model_id}', Task='{model_task}', Full Prompt='{full_prompt_string_from_frontend[:200]}...', Params='{parameters}'")
 
         result = []
         # --- Parameter handling and inference by task type ---
         if model_task == "text-generation":
-            # Default parameters for text-generation
             gen_params = {
-                "max_new_tokens": parameters.get("max_new_tokens", 150),  # More tokens for roleplay
+                "max_new_tokens": parameters.get("max_new_tokens", 150),
                 "temperature": parameters.get("temperature", 0.7),
                 "do_sample": parameters.get("do_sample", True),
                 "return_full_text": parameters.get("return_full_text", False),  # Very important for chatbots
                 "num_return_sequences": parameters.get("num_return_sequences", 1),
                 "top_k": parameters.get("top_k", 50),
                 "top_p": parameters.get("top_p", 0.95),
-                "repetition_penalty": parameters.get("repetition_penalty", 1.2),  # Prevents repetition
+                "repetition_penalty": parameters.get("repetition_penalty", 1.2),
             }
-            result = model_pipeline(inputs, **gen_params)
+            # Pass full_prompt_string_from_frontend directly to the pipeline
+            result = model_pipeline(full_prompt_string_from_frontend, **gen_params)
 
         elif model_task == "fill-mask":
             mask_params = {
                 "top_k": parameters.get("top_k", 5)
             }
-            result = model_pipeline(inputs, **mask_params)
+            # For fill-mask the input must be a plain string, not a complex prompt;
+            # make sure the frontend does not send complex prompts to fill-mask models
+            result = model_pipeline(full_prompt_string_from_frontend, **mask_params)
 
-        elif model_task == "text2text-generation":  # e.g. for T5
+        elif model_task == "text2text-generation":
             t2t_params = {
                 "max_new_tokens": parameters.get("max_new_tokens", 150),
                 "temperature": parameters.get("temperature", 0.7),
                 "do_sample": parameters.get("do_sample", True),
            }
-            result = model_pipeline(inputs, **t2t_params)
+            result = model_pipeline(full_prompt_string_from_frontend, **t2t_params)
 
         else:
-            # Fallback for other tasks or if no specific parameters are needed
-            result = model_pipeline(inputs, **parameters)
+            result = model_pipeline(full_prompt_string_from_frontend, **parameters)
 
-        # --- Output format consistency ---
+        # --- Output format consistency (unchanged from the previous update) ---
         response_output = {}
         if model_task == "text-generation" or model_task == "text2text-generation":
             if result and len(result) > 0 and 'generated_text' in result[0]:
@@ -166,22 +194,18 @@ def predict_with_model(model_id):
                     for p in result
                 ]
             else:
-                # For other task types, return the raw result
                 response_output = result
 
         logger.info(f"Inference succeeded for '{model_id}'. Output excerpt: '{str(response_output)[:200]}'")
-        return jsonify({"model": model_id, "inputs": inputs, "outputs": response_output})
+        return jsonify({"model": model_id, "inputs": full_prompt_string_from_frontend, "outputs": response_output})
 
     except ValueError as ve:
-        # Errors raised by get_model_pipeline or input validation
         logger.error(f"Validation or configuration error for model '{model_id}': {str(ve)}")
         return jsonify({"error": str(ve), "message": "Model configuration or input error."}), 400
     except RuntimeError as re:
-        # Errors while loading the model
         logger.error(f"Runtime error while loading model '{model_id}': {str(re)}")
-        return jsonify({"error": str(re), "message": "Model failed to load."}), 503  # Service Unavailable
+        return jsonify({"error": str(re), "message": "Model failed to load."}), 503
     except Exception as e:
-        # Catch all other unexpected errors during prediction
         logger.error(f"An unexpected error occurred while predicting with model '{model_id}': {str(e)}", exc_info=True)
         return jsonify({"error": str(e), "message": "An internal server error occurred."}), 500
 
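After this change, clients send the fully formatted prompt as a single 'inputs' string plus an optional 'parameters' dictionary. Below is a minimal sketch of a client call; the base URL, port, and route path are illustrative assumptions, since the Flask route decorator is outside this diff:

import requests

API_BASE = "http://localhost:7860"  # assumption: adjust to the actual host/port
MODEL_ID = "Gemma-2B-IT"

payload = {
    # One pre-formatted prompt string, built by the frontend.
    "inputs": "System: You are a helpful assistant.\nUser: Hello!\nAssistant:",
    # Optional overrides; omitted keys fall back to the server defaults
    # (max_new_tokens=150, temperature=0.7, top_k=50, top_p=0.95, ...).
    "parameters": {"max_new_tokens": 200, "temperature": 0.8},
}

# assumption: the route path is a guess, not taken from this diff
resp = requests.post(f"{API_BASE}/models/{MODEL_ID}/predict", json=payload)
resp.raise_for_status()
print(resp.json()["outputs"])

Because return_full_text defaults to False, "outputs" contains only the newly generated continuation rather than the echoed prompt, which is what a chat frontend appends to the conversation.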
 
 
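The error handling above relies on get_model_pipeline raising ValueError for configuration problems and RuntimeError when a model fails to load. That helper is not part of this diff; the following is a minimal lazy-loading sketch consistent with that error contract, in which the cache variable name and loading details are assumptions:

from transformers import pipeline

loaded_pipelines = {}  # hypothetical global cache; the real variable name is not shown here

def get_model_pipeline(model_id):
    """Load a pipeline on first use and cache it (lazy loading)."""
    if model_id in loaded_pipelines:
        return loaded_pipelines[model_id]
    if model_id not in model_info:
        raise ValueError(f"Unknown model: {model_id}")
    info = model_info[model_id]
    hf_name = info.get("hf_model_name", model_id)  # fall back to the key itself
    try:
        pipe = pipeline(info["task"], model=hf_name)
    except Exception as exc:
        # Surface load failures as RuntimeError so the endpoint can answer 503.
        raise RuntimeError(f"Failed to load '{hf_name}': {exc}") from exc
    loaded_pipelines[model_id] = pipe
    return pipe

Caching per model_id means each newly added external model is downloaded only on its first request. Note that the larger entries (Mistral-7B-Instruct in particular) need substantial memory, and google/gemma-2b-it is a gated model that requires accepting its license on the Hugging Face Hub.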