Xylor commited on
Commit
dbbcf45
·
verified ·
1 Parent(s): 06e54c8

Changed clf_text signature and moved clf output transformation to its own function

Browse files
Files changed (1) hide show
  1. app.py +16 -3
app.py CHANGED
@@ -17,13 +17,26 @@ categories = [
17
  "Card Payments", "Banking", "Regulations", "Account Payments"
18
  ]
19
 
20
- def clf_text(txt: str):
 
 
 
 
 
 
 
 
 
21
  res = classifier(txt, categories, multi_label=True)
22
- items = sorted(zip(res["labels"], res["scores"]), key=lambda tpl: tpl[1], reverse=True)
 
 
 
 
23
  # d = dict(zip(res["labels"], res["scores"]))
24
  # output = [f"{lbl}:\t{score}" for lbl, score in items]
25
  # return "\n".join(output)
26
- return list(items)
27
  # classifier(sequence_to_classify, candidate_labels)
28
  #{'labels': ['travel', 'dancing', 'cooking'],
29
  # 'scores': [0.9938651323318481, 0.0032737774308770895, 0.002861034357920289],
 
17
  "Card Payments", "Banking", "Regulations", "Account Payments"
18
  ]
19
 
20
+ def transform_output(res: dict) -> list:
21
+ return list(
22
+ sorted(
23
+ zip(res["labels"], res["scores"]),
24
+ key=lambda tpl: tpl[1],
25
+ reverse=True
26
+ )
27
+ )
28
+
29
+ def clf_text(txt: str | list[str]):
30
  res = classifier(txt, categories, multi_label=True)
31
+ if isinstance(res, list):
32
+ return [ transform_output(dct) for dct in res ]
33
+ else:
34
+ return transform_output(res)
35
+ # items = sorted(zip(res["labels"], res["scores"]), key=lambda tpl: tpl[1], reverse=True)
36
  # d = dict(zip(res["labels"], res["scores"]))
37
  # output = [f"{lbl}:\t{score}" for lbl, score in items]
38
  # return "\n".join(output)
39
+ # return list(items)
40
  # classifier(sequence_to_classify, candidate_labels)
41
  #{'labels': ['travel', 'dancing', 'cooking'],
42
  # 'scores': [0.9938651323318481, 0.0032737774308770895, 0.002861034357920289],