Xylor committed on
Commit
ae695f0
·
verified ·
1 Parent(s): a5fcb71

Swap model to 'jinaai/jina-reranker-m0'

Browse files

Added a new classifier function that works with the "simple" AutoModel.compute_score instead of the zero-shot pipeline

Files changed (1) hide show
  1. app.py +33 -9
app.py CHANGED
@@ -6,9 +6,16 @@ from transformers import pipeline
6
  logger = logging.getLogger("gradio_test_001")
7
  logger.setLevel(logging.INFO)
8
  logging.debug("Starting logging for gradio_test_001.")
 
 
 
 
 
 
 
9
 
10
- classifier = pipeline("zero-shot-classification",
11
- model="facebook/bart-large-mnli")
12
 
13
  # sequence_to_classify = "one day I will see the world"
14
  # candidate_labels = ['travel', 'cooking', 'dancing']
@@ -16,12 +23,6 @@ classifier = pipeline("zero-shot-classification",
16
  # 'doc_type.Scheme', 'content_type.Alt', 'content_type.Krypto',
17
  # 'content_type.Karte', 'content_type.Banking', 'content_type.Reg',
18
  # 'content_type.Konto']
19
- categories = [
20
- "Legal", "Specification", "Facts and Figures",
21
- "Publication", "Payment Scheme",
22
- "Alternative Payment Systems", "Crypto Payments",
23
- "Card Payments", "Banking", "Regulations", "Account Payments"
24
- ]
25
 
26
  def transform_output(res: dict) -> list:
27
  return list(
@@ -53,13 +54,36 @@ def clf_text(txt: str | list[str]):
53
  # 'scores': [0.9938651323318481, 0.0032737774308770895, 0.002861034357920289],
54
  # 'sequence': 'one day I will see the world'}
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  def my_inference_function(name):
58
  return "Hello " + name + "!"
59
 
60
  gradio_interface = gradio.Interface(
61
  # fn = my_inference_function,
62
- fn = clf_text,
 
63
  inputs = "text",
64
  outputs = gradio.JSON()
65
  )
 
6
  logger = logging.getLogger("gradio_test_001")
7
  logger.setLevel(logging.INFO)
8
  logging.debug("Starting logging for gradio_test_001.")
9
+ categories = [
10
+ "Legal", "Specification", "Facts and Figures",
11
+ "Publication", "Payment Scheme",
12
+ "Alternative Payment Systems", "Crypto Payments",
13
+ "Card Payments", "Banking", "Regulations", "Account Payments"
14
+ ]
15
+ logging.debug("Categories to classify: " + repr(categories))
16
 
17
+ # classifier = pipeline("zero-shot-classification",
18
+ # model="facebook/bart-large-mnli")
19
 
20
  # sequence_to_classify = "one day I will see the world"
21
  # candidate_labels = ['travel', 'cooking', 'dancing']
 
23
  # 'doc_type.Scheme', 'content_type.Alt', 'content_type.Krypto',
24
  # 'content_type.Karte', 'content_type.Banking', 'content_type.Reg',
25
  # 'content_type.Konto']
 
 
 
 
 
 
26
 
27
  def transform_output(res: dict) -> list:
28
  return list(
 
54
  # 'scores': [0.9938651323318481, 0.0032737774308770895, 0.002861034357920289],
55
  # 'sequence': 'one day I will see the world'}
56
 
57
+ from transformers import AutoModel
58
+ # comment out the flash_attention_2 line if you don't have a compatible GPU
59
+ model = AutoModel.from_pretrained(
60
+ 'jinaai/jina-reranker-m0',
61
+ torch_dtype="auto",
62
+ trust_remote_code=True,
63
+ # attn_implementation="flash_attention_2"
64
+ )
65
+
66
+ def clf_jina(txt: str | list[str]):
67
+ # construct sentence pairs
68
+ # text_pairs = [[query, doc] for doc in documents]
69
+ text_pairs = [[cat, txt] for cat in categories]
70
+ scores = model.compute_score(text_pairs, max_length=1024, doc_type="text")
71
+ return list(
72
+ sorted(
73
+ zip(categories, scores),
74
+ key=lambda tpl: tpl[1],
75
+ reverse=True
76
+ )
77
+ )
78
+
79
 
80
  def my_inference_function(name):
81
  return "Hello " + name + "!"
82
 
83
  gradio_interface = gradio.Interface(
84
  # fn = my_inference_function,
85
+ # fn = clf_text,
86
+ clf_jina,
87
  inputs = "text",
88
  outputs = gradio.JSON()
89
  )