File size: 1,295 Bytes
9746697
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2de65f4
 
9746697
 
2de65f4
9746697
 
2de65f4
 
 
9746697
 
 
2de65f4
9746697
2de65f4
9746697
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from html import unescape
from unicodedata import normalize
import gradio as gr
from transformers import pipeline
import re

# Matches any run of whitespace; with a str pattern `\s` also covers Unicode
# whitespace such as U+00A0 (no-break space).
re_multispace = re.compile(r"\s+")

def normalize_text(text):
    """Return *text* cleaned up for the classifiers, or None for None.

    HTML entities are decoded and the text is NFKC-normalized first.
    Whitespace is then collapsed to single spaces and the result is
    stripped. Collapsing must come last: decoding (e.g. ``&nbsp;`` or
    ``&#10;``) and NFKC (e.g. full-width/ideographic spaces) can both
    produce new whitespace that an earlier pass would miss.

    Args:
        text: raw input string, or None.

    Returns:
        The normalized string, or None when *text* is None.
    """
    if text is None:
        return None

    text = unescape(text)
    text = normalize("NFKC", text)
    # `\s+` already covers "\n", "\t" and "\r", so no separate replaces are needed.
    text = re_multispace.sub(" ", text)
    return text.strip()


# Display names of the classification tasks. Each one is served by a
# fine-tuned checkpoint in the "hynky" Hub namespace (see _hub_model_id).
models = ["Server", "Category", "Gender", "Day Of Week"]


def _hub_model_id(task_name):
    """Map a task display name to its Hub repo id, e.g. "Day Of Week" -> "hynky/Day_Of_Week"."""
    return "hynky/" + task_name.replace(" ", "_")


# One text-classification pipeline per task, keyed by display name. All of
# them share the same tokenizer. Inputs are truncated to 512 tokens, and each
# pipeline is configured to return its five highest-scoring labels.
pipelines = {
    task_name: pipeline(
        task="text-classification",
        model=_hub_model_id(task_name),
        tokenizer="ufal/robeczech-base",
        truncation=True,
        max_length=512,
        top_k=5,
    )
    for task_name in models
}


def predict(article):
    """Run every task pipeline on one article.

    Args:
        article: raw article text from the UI; it is passed through
            normalize_text() before classification.

    Returns:
        A tuple with one ``{label: score}`` dict per entry of ``models``
        (same order), each score rounded to three decimals.
    """
    cleaned = normalize_text(article)
    per_task = []
    for task in models:
        # Index [0] picks the label/score list for our single input.
        top_labels = pipelines[task](cleaned)[0]
        per_task.append({entry["label"]: round(entry["score"], 3) for entry in top_labels})
    return tuple(per_task)

# Web UI: one text box for the article, and one Label output per task in the
# same order as the tuple returned by predict().
# Top-level gr.Textbox / gr.Label replace the gr.inputs / gr.outputs
# namespaces, which were deprecated in Gradio 3 and removed in Gradio 4.
demo = gr.Interface(
    predict,
    inputs=gr.Textbox(lines=4, placeholder="Paste a news article here..."),
    outputs=[gr.Label(num_top_classes=5, label=model) for model in models],
    title="News Article Classifier",
)
demo.launch()