chen666-666 commited on
Commit
4cd7a63
·
1 Parent(s): 1d3964d

add app.py and requirements.txt

Browse files
Files changed (1) hide show
  1. app.py +29 -30
app.py CHANGED
@@ -240,15 +240,15 @@ def convert_telegram_json_to_eval_format(path):
240
 
241
 
242
  def evaluate_ner_model(data, model_type):
243
- y_true, y_pred = [], []
244
- POS_TOLERANCE = 1 # 允许的位置误差
245
 
246
  for item in data:
247
  text = item["text"]
 
248
  gold_entities = []
249
  for e in item.get("entities", []):
250
  if "text" in e and "type" in e:
251
- # 标准化标签
252
  norm_type = LABEL_MAPPING.get(e["type"], e["type"])
253
  gold_entities.append({
254
  "text": e["text"],
@@ -257,39 +257,38 @@ def evaluate_ner_model(data, model_type):
257
  "end": e.get("end", -1)
258
  })
259
 
 
260
  pred_entities, _ = ner(text, model_type)
261
 
262
- # 构建对比集合
263
- all_entities = set()
264
- # 处理标注数据
265
- for g in gold_entities:
266
- key = f"{g['text']}|{g['type']}|{g['start']}|{g['end']}"
267
- all_entities.add(key)
268
-
269
- # 处理预测结果
270
- pred_set = set()
271
- for p in pred_entities:
272
- # 允许位置误差
273
- matched = False
274
- for g in gold_entities:
275
- if (p["text"] == g["text"] and
276
- p["type"] == g["type"] and
277
- abs(p["start"] - g["start"]) <= POS_TOLERANCE and
278
- abs(p["end"] - g["end"]) <= POS_TOLERANCE):
279
- matched = True
280
  break
281
- pred_set.add(matched)
282
 
283
- # 构建指标
284
- y_true.extend([1] * len(gold_entities))
285
- y_pred.extend([1 if m else 0 for m in pred_set])
 
286
 
287
- if not y_true:
288
- return "⚠️ 无有效标注数据"
 
 
289
 
290
- return (f"Precision: {precision_score(y_true, y_pred, zero_division=0):.2f}\n"
291
- f"Recall: {recall_score(y_true, y_pred, zero_division=0):.2f}\n"
292
- f"F1: {f1_score(y_true, y_pred, zero_division=0):.2f}")
293
 
294
  def auto_annotate(file, model_type):
295
  data = convert_telegram_json_to_eval_format(file.name)
 
240
 
241
 
242
  def evaluate_ner_model(data, model_type):
243
+ tp, fp, fn = 0, 0, 0
244
+ POS_TOLERANCE = 1
245
 
246
  for item in data:
247
  text = item["text"]
248
+ # 处理标注数据
249
  gold_entities = []
250
  for e in item.get("entities", []):
251
  if "text" in e and "type" in e:
 
252
  norm_type = LABEL_MAPPING.get(e["type"], e["type"])
253
  gold_entities.append({
254
  "text": e["text"],
 
257
  "end": e.get("end", -1)
258
  })
259
 
260
+ # 获取预测结果
261
  pred_entities, _ = ner(text, model_type)
262
 
263
+ # 初始化匹配状态
264
+ matched_gold = [False] * len(gold_entities)
265
+ matched_pred = [False] * len(pred_entities)
266
+
267
+ # 遍历预测实体寻找匹配
268
+ for p_idx, p in enumerate(pred_entities):
269
+ for g_idx, g in enumerate(gold_entities):
270
+ if not matched_gold[g_idx] and \
271
+ p["text"] == g["text"] and \
272
+ p["type"] == g["type"] and \
273
+ abs(p["start"] - g["start"]) <= POS_TOLERANCE and \
274
+ abs(p["end"] - g["end"]) <= POS_TOLERANCE:
275
+ matched_gold[g_idx] = True
276
+ matched_pred[p_idx] = True
 
 
 
 
277
  break
 
278
 
279
+ # 统计指标
280
+ tp += sum(matched_pred)
281
+ fp += len(pred_entities) - sum(matched_pred)
282
+ fn += len(gold_entities) - sum(matched_gold)
283
 
284
+ # 处理除零情况
285
+ precision = tp / (tp + fp) if (tp + fp) > 0 else 0
286
+ recall = tp / (tp + fn) if (tp + fn) > 0 else 0
287
+ f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
288
 
289
+ return (f"Precision: {precision:.2f}\n"
290
+ f"Recall: {recall:.2f}\n"
291
+ f"F1: {f1:.2f}")
292
 
293
  def auto_annotate(file, model_type):
294
  data = convert_telegram_json_to_eval_format(file.name)