Spaces:
Running
Running
import gradio | |
import logging | |
import time | |
from transformers import pipeline | |
# Install the default root handler explicitly. Without any root handler,
# INFO records from ``logger`` would fall through to logging's last-resort
# handler, which only emits WARNING and above.
logging.basicConfig()

logger = logging.getLogger("gradio_test_001")
logger.setLevel(logging.INFO)
# NOTE: the logger level is INFO, so DEBUG records such as the two below are
# filtered out unless the level is lowered.
logger.debug("Starting logging for gradio_test_001.")

# Candidate labels every input text is scored against.
categories = [
    "Legal", "Specification", "Facts and Figures",
    "Publication", "Payment Scheme",
    "Alternative Payment Systems", "Crypto Payments",
    "Card Payments", "Banking", "Regulations", "Account Payments",
]
logger.debug("Categories to classify: %r", categories)
# classifier = pipeline("zero-shot-classification", | |
# model="facebook/bart-large-mnli") | |
# sequence_to_classify = "one day I will see the world" | |
# candidate_labels = ['travel', 'cooking', 'dancing'] | |
# CATEGORIES = ['doc_type.jur', 'doc_type.Spec', 'doc_type.ZDF', 'doc_type.Publ', | |
# 'doc_type.Scheme', 'content_type.Alt', 'content_type.Krypto', | |
# 'content_type.Karte', 'content_type.Banking', 'content_type.Reg', | |
# 'content_type.Konto'] | |
def transform_output(res: dict) -> list[tuple]:
    """Pair each label with its score and order the pairs by score, highest first.

    Args:
        res: Mapping with a ``"labels"`` sequence and a ``"scores"`` sequence.
            Entries are paired positionally; if the sequences differ in
            length, the surplus entries of the longer one are ignored.

    Returns:
        A list of ``(label, score)`` tuples sorted by descending score.
        Pairs with equal scores keep their original relative order.

    Raises:
        KeyError: If ``res`` lacks the ``"labels"`` or ``"scores"`` key.
    """
    # sorted() already returns a new list, so no extra list() wrapper is needed.
    return sorted(
        zip(res["labels"], res["scores"]),
        key=lambda pair: pair[1],
        reverse=True,
    )
def _get_classifier():
    """Return the zero-shot classification pipeline, creating it on first use.

    A module-level ``classifier`` is reused when one already exists;
    otherwise the pipeline is built once and stored under that name, so the
    model is only loaded if ``clf_text`` is actually called.
    """
    global classifier
    try:
        return classifier
    except NameError:
        # Local import keeps this helper self-contained and defers the
        # pipeline machinery until it is really needed.
        from transformers import pipeline

        classifier = pipeline(
            "zero-shot-classification",
            model="facebook/bart-large-mnli",
        )
        return classifier


def clf_text(txt: str | list[str]):
    """Score ``txt`` against ``categories`` with the zero-shot pipeline.

    Args:
        txt: The text, or list of texts, handed to the pipeline.

    Returns:
        When the pipeline returns a single result, a list of
        ``(label, score)`` tuples sorted by descending score. When it returns
        a list of results, a list containing one such sorted list per result.
    """
    logger.info("Classify: %r", txt)
    # perf_counter is monotonic, so the measured duration cannot be skewed
    # by wall-clock adjustments.
    started = time.perf_counter()
    res = _get_classifier()(txt, categories, multi_label=True)
    elapsed = time.perf_counter() - started
    logger.info("Done. %.02fs", elapsed)
    logger.info("Result(s): %r", res)
    if isinstance(res, list):
        return [transform_output(dct) for dct in res]
    return transform_output(res)
# items = sorted(zip(res["labels"], res["scores"]), key=lambda tpl: tpl[1], reverse=True) | |
# d = dict(zip(res["labels"], res["scores"])) | |
# output = [f"{lbl}:\t{score}" for lbl, score in items] | |
# return "\n".join(output) | |
# return list(items) | |
# classifier(sequence_to_classify, candidate_labels) | |
#{'labels': ['travel', 'dancing', 'cooking'], | |
# 'scores': [0.9938651323318481, 0.0032737774308770895, 0.002861034357920289], | |
# 'sequence': 'one day I will see the world'} | |
from transformers import AutoModel

# Keyword arguments for loading the reranker checkpoint.
# NOTE(review): trust_remote_code=True allows code shipped with the checkpoint
# repository to run locally while loading -- confirm the source is trusted.
_RERANKER_LOAD_OPTIONS = dict(
    torch_dtype="auto",
    trust_remote_code=True,
    # Enable the next option only if a flash-attention-compatible GPU is available.
    # attn_implementation="flash_attention_2",
)

# Loaded once at import time; clf_jina scores text pairs with this object.
model = AutoModel.from_pretrained('jinaai/jina-reranker-m0', **_RERANKER_LOAD_OPTIONS)
def clf_jina(txt: str | list[str]):
    """Rank ``categories`` for ``txt`` using the reranker's pair scores.

    Each category is paired with ``txt`` (category first, text second), the
    pairs are scored in a single ``model.compute_score`` call, and the
    categories are returned with their scores, highest score first.

    NOTE(review): the annotation admits a list of strings, but ``txt`` is
    placed into every pair as-is -- confirm whether list input is supported.

    Returns:
        A list of ``(category, score)`` tuples sorted by descending score.
    """
    pairs = []
    for label in categories:
        pairs.append([label, txt])

    relevance = model.compute_score(pairs, max_length=1024, doc_type="text")

    ranked = list(zip(categories, relevance))
    # A stable in-place sort keeps tied categories in their original order.
    ranked.sort(key=lambda item: item[1], reverse=True)
    return ranked
def my_inference_function(name):
    """Return a greeting for ``name`` in the form ``Hello <name>!``."""
    greeting = "Hello " + name
    return greeting + "!"
# Web UI: one text input, JSON output, served by the reranker-based scorer.
# Other handlers defined in this file that can be swapped in for
# experimentation: my_inference_function, clf_text.
gradio_interface = gradio.Interface(
    fn=clf_jina,
    inputs="text",
    outputs=gradio.JSON(),
)

logger.debug("Launch app.")
gradio_interface.launch()