File size: 1,821 Bytes
a452638
 
dbf4eb2
7c7d7c6
57fe04c
 
 
dbf4eb2
7c7d7c6
 
 
 
 
 
 
dbf4eb2
57fe04c
 
 
 
 
 
 
 
a452638
57fe04c
 
7c7d7c6
57fe04c
 
 
aa13915
 
dbf4eb2
 
7c7d7c6
 
 
dbf4eb2
a452638
 
 
dbf4eb2
 
76b4c44
 
dbf4eb2
 
2b07bf2
 
 
dbf4eb2
2b07bf2
 
dbf4eb2
2b07bf2
 
 
dbf4eb2
2b07bf2
dbf4eb2
2b07bf2
38f2cf5
2b07bf2
 
a452638
2b07bf2
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from io import BytesIO

import pandas as pd
import streamlit as st
import tokenizers
import torch
from transformers import Pipeline, pipeline

# Page-level settings: browser tab title and a wide layout; the icon, sidebar
# state and menu entries are left at the values given here.
st.set_page_config(
    layout="wide",
    page_title="Zero-shot classification from tabular data",
    initial_sidebar_state="auto",
    page_icon=None,
    menu_items=None,
)


def _skip_hashing(_obj) -> None:
    """Return a constant hash so Streamlit's cache does not inspect the object."""
    return None


# NOTE(review): `st.cache` is deprecated in newer Streamlit releases in favour
# of `st.cache_resource` — confirm the targeted Streamlit version before
# migrating.
@st.cache(
    allow_output_mutation=True,
    show_spinner=False,
    hash_funcs={
        # These types are registered with a constant hash function.
        torch.nn.parameter.Parameter: _skip_hashing,
        tokenizers.Tokenizer: _skip_hashing,
        tokenizers.AddedToken: _skip_hashing,
    },
)
def load_classifier() -> Pipeline:
    """Build the zero-shot classification pipeline.

    The pipeline is created for the ``zero-shot-classification`` task with the
    ``facebook/bart-large-mnli`` model. The function is wrapped in
    ``st.cache`` with ``allow_output_mutation=True`` and Streamlit's own
    spinner disabled.

    Returns:
        The ``transformers`` pipeline object.
    """
    return pipeline("zero-shot-classification", model="facebook/bart-large-mnli")


# Obtain the classifier while a spinner tells the user what is going on.
with st.spinner("Setting stuff up related to the inference engine..."):
    classifier = load_classifier()

# Page heading followed by a one-line description of the app.
st.title("Zero-shot classification from tabular data")
st.text(
    "Upload an Excel table and perform zero-shot classification on a set of custom labels"
)

# User inputs: the spreadsheet to classify and the candidate labels.
data = st.file_uploader(
    label="Upload Excel file (it should contain a column named `text` in its header):"
)
labels = st.text_input(label="Enter comma-separated labels:")

# Classification runs only when the user presses the button. Every row whose
# `text` cell is a string longer than 10 characters is classified, one
# pipeline call per row, and the best-scoring label is stored.
if st.button("Calculate labels"):

    # Split the input into clean labels, dropping blanks such as the trailing
    # entry produced by "a, b,".
    labels_list = [label.strip() for label in labels.split(",") if label.strip()]

    # Validate the inputs up front so the user gets a readable message
    # instead of a traceback.
    if data is None:
        st.error("Please upload an Excel file first.")
        st.stop()
    if not labels_list:
        st.error("Please enter at least one label.")
        st.stop()

    table = pd.read_excel(data)
    if "text" not in table.columns:
        st.error("The uploaded table has no column named `text`.")
        st.stop()

    # Empty cells are read as NaN (a float) and numeric cells stay numeric, so
    # a bare len() would raise; only string cells can pass the length filter.
    is_long_text = (
        table["text"]
        .apply(lambda value: isinstance(value, str) and len(value) > 10)
        .astype(bool)  # keeps the mask boolean even when the table is empty
    )
    table = table.loc[is_long_text].reset_index(drop=True)

    prog_bar = st.progress(0)
    total = len(table)
    preds = []

    for done, text in enumerate(table["text"], start=1):
        # The pipeline result lists labels from most to least likely; keep
        # the top one.
        preds.append(classifier(text, labels_list)["labels"][0])
        prog_bar.progress(done / total)

    table["label"] = preds

    result = table[["text", "label"]]
    st.table(result)

    # Serialize the result to an in-memory workbook for the download button.
    buf = BytesIO()
    result.to_excel(buf)

    st.download_button(
        label="Download table", data=buf.getvalue(), file_name="output.xlsx"
    )