Lyon28 commited on
Commit
3f29839
·
verified ·
1 Parent(s): 1a39cb8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -21
app.py CHANGED
@@ -125,7 +125,7 @@ def list_available_models():
125
  def predict_with_model(model_id):
126
  """
127
  Endpoint utama untuk prediksi model.
128
- Menerima 'inputs' (teks pra-diformat) dan 'parameters' (dictionary) opsional.
129
  """
130
  logger.info(f"Menerima permintaan untuk model: {model_id}")
131
  if model_id not in model_info:
@@ -137,51 +137,49 @@ def predict_with_model(model_id):
137
  model_task = model_info[model_id]["task"]
138
 
139
  data = request.json
140
- # Input sekarang diharapkan sebagai fullPromptString dari frontend
141
- full_prompt_string_from_frontend = data.get('inputs', '')
142
- parameters = data.get('parameters', {})
143
 
144
- if not full_prompt_string_from_frontend:
145
- return jsonify({"error": "Input 'inputs' (full prompt string) tidak boleh kosong."}), 400
146
 
147
- logger.info(f"Inferensi: Model='{model_id}', Task='{model_task}', Full Prompt='{full_prompt_string_from_frontend[:200]}...', Params='{parameters}'")
148
 
149
  result = []
150
  # --- Penanganan Parameter dan Inferensi berdasarkan Tipe Tugas ---
151
  if model_task == "text-generation":
 
152
  gen_params = {
153
- "max_new_tokens": parameters.get("max_new_tokens", 150),
154
  "temperature": parameters.get("temperature", 0.7),
155
  "do_sample": parameters.get("do_sample", True),
156
  "return_full_text": parameters.get("return_full_text", False), # Sangat penting untuk chatbot
157
  "num_return_sequences": parameters.get("num_return_sequences", 1),
158
  "top_k": parameters.get("top_k", 50),
159
  "top_p": parameters.get("top_p", 0.95),
160
- "repetition_penalty": parameters.get("repetition_penalty", 1.2),
161
  }
162
- # Langsung berikan full_prompt_string_from_frontend ke pipeline
163
- result = model_pipeline(full_prompt_string_from_frontend, **gen_params)
164
 
165
  elif model_task == "fill-mask":
166
  mask_params = {
167
  "top_k": parameters.get("top_k", 5)
168
  }
169
- # Untuk fill-mask, input harus string biasa, bukan prompt yang kompleks
170
- # Anda perlu memastikan frontend tidak mengirim prompt kompleks ke fill-mask model
171
- result = model_pipeline(full_prompt_string_from_frontend, **mask_params)
172
 
173
- elif model_task == "text2text-generation":
174
  t2t_params = {
175
  "max_new_tokens": parameters.get("max_new_tokens", 150),
176
  "temperature": parameters.get("temperature", 0.7),
177
  "do_sample": parameters.get("do_sample", True),
178
  }
179
- result = model_pipeline(full_prompt_string_from_frontend, **t2t_params)
180
 
181
  else:
182
- result = model_pipeline(full_prompt_string_from_frontend, **parameters)
 
183
 
184
- # --- Konsistensi Format Output (tidak berubah dari update sebelumnya) ---
185
  response_output = {}
186
  if model_task == "text-generation" or model_task == "text2text-generation":
187
  if result and len(result) > 0 and 'generated_text' in result[0]:
@@ -194,18 +192,22 @@ def predict_with_model(model_id):
194
  for p in result
195
  ]
196
  else:
 
197
  response_output = result
198
 
199
  logger.info(f"Inferensi berhasil untuk '{model_id}'. Output singkat: '{str(response_output)[:200]}'")
200
- return jsonify({"model": model_id, "inputs": full_prompt_string_from_frontend, "outputs": response_output})
201
 
202
  except ValueError as ve:
 
203
  logger.error(f"Validasi atau konfigurasi error untuk model '{model_id}': {str(ve)}")
204
  return jsonify({"error": str(ve), "message": "Kesalahan konfigurasi atau input model."}), 400
205
  except RuntimeError as re:
 
206
  logger.error(f"Error runtime saat memuat model '{model_id}': {str(re)}")
207
- return jsonify({"error": str(re), "message": "Model gagal dimuat."}), 503
208
  except Exception as e:
 
209
  logger.error(f"Terjadi kesalahan tak terduga saat memprediksi dengan model '{model_id}': {str(e)}", exc_info=True)
210
  return jsonify({"error": str(e), "message": "Terjadi kesalahan internal server."}), 500
211
 
@@ -219,4 +221,4 @@ if __name__ == '__main__':
219
  # Untuk Hugging Face Spaces, port biasanya 7860
220
  # Menggunakan HOST dari environment variable jika tersedia, default ke 0.0.0.0
221
  # Debug=False untuk produksi
222
- app.run(host=os.getenv('HOST', '0.0.0.0'), port=int(os.getenv('PORT', 7860)), debug=False)
 
125
  def predict_with_model(model_id):
126
  """
127
  Endpoint utama untuk prediksi model.
128
+ Menerima 'inputs' (teks) dan 'parameters' (dictionary) opsional.
129
  """
130
  logger.info(f"Menerima permintaan untuk model: {model_id}")
131
  if model_id not in model_info:
 
137
  model_task = model_info[model_id]["task"]
138
 
139
  data = request.json
140
+ inputs = data.get('inputs', '')
141
+ parameters = data.get('parameters', {}) # Default ke dictionary kosong jika tidak ada
 
142
 
143
+ if not inputs:
144
+ return jsonify({"error": "Input 'inputs' tidak boleh kosong."}), 400
145
 
146
+ logger.info(f"Inferensi: Model='{model_id}', Task='{model_task}', Input='{inputs[:100]}...', Params='{parameters}'")
147
 
148
  result = []
149
  # --- Penanganan Parameter dan Inferensi berdasarkan Tipe Tugas ---
150
  if model_task == "text-generation":
151
+ # Default parameters for text-generation
152
  gen_params = {
153
+ "max_new_tokens": parameters.get("max_new_tokens", 150), # Lebih banyak token untuk roleplay
154
  "temperature": parameters.get("temperature", 0.7),
155
  "do_sample": parameters.get("do_sample", True),
156
  "return_full_text": parameters.get("return_full_text", False), # Sangat penting untuk chatbot
157
  "num_return_sequences": parameters.get("num_return_sequences", 1),
158
  "top_k": parameters.get("top_k", 50),
159
  "top_p": parameters.get("top_p", 0.95),
160
+ "repetition_penalty": parameters.get("repetition_penalty", 1.2), # Mencegah pengulangan
161
  }
162
+ result = model_pipeline(inputs, **gen_params)
 
163
 
164
  elif model_task == "fill-mask":
165
  mask_params = {
166
  "top_k": parameters.get("top_k", 5)
167
  }
168
+ result = model_pipeline(inputs, **mask_params)
 
 
169
 
170
+ elif model_task == "text2text-generation": # Misalnya untuk T5
171
  t2t_params = {
172
  "max_new_tokens": parameters.get("max_new_tokens", 150),
173
  "temperature": parameters.get("temperature", 0.7),
174
  "do_sample": parameters.get("do_sample", True),
175
  }
176
+ result = model_pipeline(inputs, **t2t_params)
177
 
178
  else:
179
+ # Fallback for other tasks or if no specific parameters are needed
180
+ result = model_pipeline(inputs, **parameters)
181
 
182
+ # --- Konsistensi Format Output ---
183
  response_output = {}
184
  if model_task == "text-generation" or model_task == "text2text-generation":
185
  if result and len(result) > 0 and 'generated_text' in result[0]:
 
192
  for p in result
193
  ]
194
  else:
195
+ # Untuk jenis tugas lain, kembalikan hasil mentah
196
  response_output = result
197
 
198
  logger.info(f"Inferensi berhasil untuk '{model_id}'. Output singkat: '{str(response_output)[:200]}'")
199
+ return jsonify({"model": model_id, "inputs": inputs, "outputs": response_output})
200
 
201
  except ValueError as ve:
202
+ # Error yang berasal dari get_model_pipeline atau validasi input
203
  logger.error(f"Validasi atau konfigurasi error untuk model '{model_id}': {str(ve)}")
204
  return jsonify({"error": str(ve), "message": "Kesalahan konfigurasi atau input model."}), 400
205
  except RuntimeError as re:
206
+ # Error saat memuat model
207
  logger.error(f"Error runtime saat memuat model '{model_id}': {str(re)}")
208
+ return jsonify({"error": str(re), "message": "Model gagal dimuat."}), 503 # Service Unavailable
209
  except Exception as e:
210
+ # Catch all other unexpected errors during prediction
211
  logger.error(f"Terjadi kesalahan tak terduga saat memprediksi dengan model '{model_id}': {str(e)}", exc_info=True)
212
  return jsonify({"error": str(e), "message": "Terjadi kesalahan internal server."}), 500
213
 
 
221
  # Untuk Hugging Face Spaces, port biasanya 7860
222
  # Menggunakan HOST dari environment variable jika tersedia, default ke 0.0.0.0
223
  # Debug=False untuk produksi
224
+ app.run(host=os.getenv('HOST', '0.0.0.0'), port=int(os.getenv('PORT', 7860)), debug=False)