File size: 3,009 Bytes
b87d52c
131144c
 
0c97a7e
131144c
a5fcb71
 
131144c
ae695f0
 
 
 
 
 
 
131144c
ae695f0
 
0c97a7e
 
 
 
 
 
 
 
dbbcf45
 
 
 
 
 
 
 
 
 
a5fcb71
131144c
9d20c0e
131144c
a5fcb71
 
dbbcf45
 
 
 
 
5fe3cd7
5e3ae64
 
dbbcf45
0c97a7e
 
 
 
 
ae695f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0c97a7e
b87d52c
 
 
 
0c97a7e
ae695f0
 
b87d52c
cbc4beb
b87d52c
131144c
b87d52c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import gradio
import logging
import time
from transformers import pipeline

# Module logger for this demo app; level INFO means DEBUG records are dropped.
logger = logging.getLogger("gradio_test_001")
logger.setLevel(logging.INFO)
# BUG FIX: the original called logging.debug(...), which goes to the ROOT
# logger (default level WARNING) instead of the configured named logger,
# so these messages were silently dropped twice over. Route them through
# `logger` with lazy %-style args (no string building unless emitted).
logger.debug("Starting logging for gradio_test_001.")
# Candidate labels for zero-shot classification / reranking.
categories = [
    "Legal", "Specification", "Facts and Figures",
    "Publication", "Payment Scheme",
    "Alternative Payment Systems", "Crypto Payments",
    "Card Payments", "Banking", "Regulations", "Account Payments"
]
logger.debug("Categories to classify: %r", categories)

# classifier = pipeline("zero-shot-classification",
#                       model="facebook/bart-large-mnli")

# sequence_to_classify = "one day I will see the world"
# candidate_labels = ['travel', 'cooking', 'dancing']
# CATEGORIES = ['doc_type.jur', 'doc_type.Spec', 'doc_type.ZDF', 'doc_type.Publ',
#        'doc_type.Scheme', 'content_type.Alt', 'content_type.Krypto',
#        'content_type.Karte', 'content_type.Banking', 'content_type.Reg',
#        'content_type.Konto']

def transform_output(res: dict) -> list:
    """Flatten a pipeline result dict into (label, score) pairs, best first.

    Args:
        res: mapping with parallel "labels" and "scores" sequences, as
            returned by a transformers zero-shot-classification pipeline.

    Returns:
        List of (label, score) tuples sorted by score, descending.
    """
    # sorted() already returns a list — the original wrapped it in a
    # redundant list() call.
    return sorted(
        zip(res["labels"], res["scores"]),
        key=lambda tpl: tpl[1],
        reverse=True,
    )

def clf_text(txt: str | list[str]):
    """Classify text against the module-level `categories` labels.

    NOTE(review): this calls the module-level `classifier`, whose pipeline
    construction is commented out above — as the file stands this function
    raises NameError at call time. Re-enable the pipeline before using it.

    Args:
        txt: a single text or a list of texts.

    Returns:
        For a single text: a list of (label, score) tuples, best first.
        For a list of texts: one such list per input text.
    """
    logger.info("Classify: %r", txt)
    t0 = time.time()
    res = classifier(txt, categories, multi_label=True)
    elapsed = time.time() - t0
    logger.info("Done. %.02fs", elapsed)
    logger.info("Result(s): %r", res)
    # The pipeline returns a list of dicts for batched input, a single
    # dict otherwise; normalize both through transform_output.
    if isinstance(res, list):
        return [transform_output(dct) for dct in res]
    return transform_output(res)
# classifier(sequence_to_classify, candidate_labels)
#{'labels': ['travel', 'dancing', 'cooking'],
# 'scores': [0.9938651323318481, 0.0032737774308770895, 0.002861034357920289],
# 'sequence': 'one day I will see the world'}

from transformers import AutoModel
# Load the Jina multimodal reranker used by clf_jina below.
# NOTE(review): trust_remote_code=True executes model-repo code on load —
# acceptable for a local demo, but worth pinning a revision in production.
# comment out the flash_attention_2 line if you don't have a compatible GPU
model = AutoModel.from_pretrained(
    'jinaai/jina-reranker-m0',
    torch_dtype="auto",  # pick dtype from the checkpoint/hardware
    trust_remote_code=True,
    # attn_implementation="flash_attention_2"
)

def clf_jina(txt: str | list[str]):
    """Score `txt` against every label in `categories` with the Jina reranker.

    Each category is paired with the input as (query, document) and scored
    by the module-level `model`.

    Args:
        txt: the text to classify. NOTE(review): despite the annotation,
            only a single string is handled here — a list would be passed
            through as the "document" of every pair unchanged; confirm
            intended batching behavior.

    Returns:
        List of (category, score) tuples sorted by score, descending.
    """
    # One (category, text) pair per candidate label.
    text_pairs = [[cat, txt] for cat in categories]
    scores = model.compute_score(text_pairs, max_length=1024, doc_type="text")
    # sorted() already returns a list; no list() wrapper needed.
    return sorted(
        zip(categories, scores),
        key=lambda tpl: tpl[1],
        reverse=True,
    )


def my_inference_function(name):
    """Return a friendly greeting for *name*."""
    return f"Hello {name}!"

# Single-textbox UI: raw text in, ranked (label, score) pairs out as JSON.
gradio_interface = gradio.Interface(
    clf_jina,
    inputs="text",
    outputs=gradio.JSON(),
)
# BUG FIX: this was logger.debug(...), which is suppressed because the
# logger level is set to INFO above — the message never appeared.
logger.info("Launch app.")
gradio_interface.launch()