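"""Streamlit app for exploring BigBIO dataset statistics.

Pick a dataset and config in the sidebar; the app then plots token-length
distributions, label-counter bar charts, and n-gram overlap between splits.
Launch with the standard Streamlit command: `streamlit run app.py`.
"""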
from collections import Counter
import numpy as np
import pandas as pd
import plotly.express as px
import streamlit as st
from datasets import load_dataset
from matplotlib import pyplot as plt
from matplotlib_venn import venn2, venn3
from ngram import get_tuples_manual_sentences
from rich import print as rprint
from bigbio.dataloader import BigBioConfigHelpers
# from matplotlib_venn_wordcloud import venn2_wordcloud, venn3_wordcloud
# vanilla whitespace tokenizer; `counter` is accepted but currently unused
def tokenizer(text, counter):
    if not text:
        return text, []
    text = text.strip()
    # normalize tabs/newlines to spaces so adjacent words are not glued together
    text = text.replace("\t", " ")
    text = text.replace("\n", " ")
    # split on whitespace, dropping empty tokens
    text_list = text.split()
    return text, text_list
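# mean and standard deviation of a sequence of lengths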
def norm(lengths):
mu = np.mean(lengths)
sigma = np.std(lengths)
return mu, sigma
def load_helper():
conhelps = BigBioConfigHelpers()
conhelps = conhelps.filtered(lambda x: x.dataset_name != "pubtator_central")
conhelps = conhelps.filtered(lambda x: x.is_bigbio_schema)
conhelps = conhelps.filtered(lambda x: not x.is_local)
rprint(
"loaded {} configs from {} datasets".format(
len(conhelps),
            len({helper.dataset_name for helper in conhelps}),
)
)
return conhelps
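# schema name -> fields that carry the free text to be tokenized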
_TEXT_MAPS = {
"bigbio_kb": ["text"],
"bigbio_text": ["text"],
"bigbio_qa": ["question", "context"],
"bigbio_te": ["premise", "hypothesis"],
"bigbio_tp": ["text_1", "text_2"],
"bigbio_pairs": ["text_1", "text_2"],
"bigbio_t2t": ["text_1", "text_2"],
}
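# colorblind-friendly IBM palette, plus black and white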
IBM_COLORS = [
"#648fff",
"#dc267f",
"#ffb000",
"#fe6100",
"#785ef0",
"#000000",
"#ffffff",
]
N = 3  # n-gram order used throughout
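# token length per text field of a single entry; also updates the shared n-gram counter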
def token_length_per_entry(entry, schema, counter):
result = {}
if schema == "bigbio_kb":
for passage in entry["passages"]:
result_key = passage["type"]
for key in _TEXT_MAPS[schema]:
text = passage[key][0]
sents, ngrams = get_tuples_manual_sentences(text.lower(), N)
toks = [tok for sent in sents for tok in sent]
tups = ["_".join(tup) for tup in ngrams]
counter.update(tups)
result[result_key] = len(toks)
else:
for key in _TEXT_MAPS[schema]:
text = entry[key]
sents, ngrams = get_tuples_manual_sentences(text.lower(), N)
toks = [tok for sent in sents for tok in sent]
result[key] = len(toks)
tups = ["_".join(tup) for tup in ngrams]
counter.update(tups)
return result, counter
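# walk every split, collecting per-entry token lengths and one n-gram Counter per split;
# `st` here is a Streamlit container (e.g. the sidebar), shadowing the module import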
def parse_token_length_and_n_gram(dataset, data_config, st=None):
hist_data = []
n_gram_counters = []
rprint(data_config)
for split, data in dataset.items():
my_bar = st.progress(0)
total = len(data)
n_gram_counter = Counter()
for i, entry in enumerate(data):
my_bar.progress(int(i / total * 100))
result, n_gram_counter = token_length_per_entry(
entry, data_config.schema, n_gram_counter
)
result["total_token_length"] = sum([v for k, v in result.items()])
result["split"] = split
hist_data.append(result)
# remove single count
# n_gram_counter = Counter({x: count for x, count in n_gram_counter.items() if count > 1})
n_gram_counters.append(n_gram_counter)
my_bar.empty()
st.write("token lengths complete!")
return pd.DataFrame(hist_data), n_gram_counters
def center_title(fig):
fig.update_layout(
title={"y": 0.9, "x": 0.5, "xanchor": "center", "yanchor": "top"},
font=dict(
size=18,
),
)
return fig
def draw_histogram(hist_data, col_name, st=None):
fig = px.histogram(
hist_data,
x=col_name,
color="split",
color_discrete_sequence=IBM_COLORS,
marginal="box", # or violin, rug
barmode="group",
hover_data=hist_data.columns,
histnorm="probability",
nbins=20,
title=f"{col_name} distribution by split",
)
st.plotly_chart(center_title(fig), use_container_width=True)
def draw_bar(bar_data, x, y, st=None):
fig = px.bar(
bar_data,
x=x,
y=y,
color="split",
color_discrete_sequence=IBM_COLORS,
# marginal="box", # or violin, rug
barmode="group",
hover_data=bar_data.columns,
title=f"{y} distribution by split",
)
st.plotly_chart(center_title(fig), use_container_width=True)
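# render each split's positive integer metadata attributes as Streamlit metrics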
def parse_metrics(metadata, st=None):
    for split, meta in metadata.items():
        # use a distinct inner loop variable to avoid shadowing the split's metadata object
        for attr_name, attr in meta.__dict__.items():
            if isinstance(attr, int) and attr > 0:
                st.metric(label=f"{split}-{attr_name}", value=attr)
def parse_counters(metadata):
metadata = metadata["train"] # using the training counter to fetch the names
counters = []
for k, v in metadata.__dict__.items():
if "counter" in k and len(v) > 0:
counters.append(k)
return counters
# generate the df for histogram
def parse_label_counter(metadata, counter_type):
hist_data = []
for split, m in metadata.items():
metadata_counter = getattr(m, counter_type)
        for k, v in metadata_counter.items():
            hist_data.append({"labels": k, counter_type: v, "split": split})
return pd.DataFrame(hist_data)
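# Streamlit entrypoint: choose a dataset and config in the sidebar, then fetch stats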
if __name__ == "__main__":
# load helpers
conhelps = load_helper()
    configs_set = {conhelper.dataset_name for conhelper in conhelps}
# st.write(sorted(configs_set))
# setup page, sidebar, columns
st.set_page_config(layout="wide")
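    # session state keeps the "fetch" button pressed across Streamlit reruns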
s = st.session_state
if not s:
s.pressed_first_button = False
data_name = st.sidebar.selectbox("dataset", sorted(configs_set))
st.sidebar.write("you selected:", data_name)
st.header(f"Dataset stats for {data_name}")
# setup data configs
data_helpers = conhelps.for_dataset(data_name)
data_configs = [d.config for d in data_helpers]
data_config_names = [d.config.name for d in data_helpers]
    data_config_name = st.sidebar.selectbox("config", sorted(set(data_config_names)))
if st.sidebar.button("fetch") or s.pressed_first_button:
s.pressed_first_button = True
helper = conhelps.for_config_name(data_config_name)
metadata_helper = helper.get_metadata()
parse_metrics(metadata_helper, st.sidebar)
# load HF dataset
data_idx = data_config_names.index(data_config_name)
data_config = data_configs[data_idx]
# st.write(data_name)
        dataset = load_dataset(f"bigbio/{data_name}", name=data_config_name)
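        # preview the train split as a dataframe (assumes the dataset has one)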
ds = pd.DataFrame(dataset["train"])
st.write(ds)
# general token length
tok_hist_data, ngram_counters = parse_token_length_and_n_gram(
dataset, data_config, st.sidebar
)
# draw token distribution
draw_histogram(tok_hist_data, "total_token_length", st)
# general counter(s)
col1, col2 = st.columns([1, 6])
counters = parse_counters(metadata_helper)
counter_type = col1.selectbox("counter_type", counters)
label_df = parse_label_counter(metadata_helper, counter_type)
        # the slider's upper bound sits one below the max count, so the filter can never drop every label
        label_max = int(label_df[counter_type].max() - 1)
        label_min = int(label_df[counter_type].min())
        filter_value = col1.slider("counter_filter (min, max)", label_min, label_max)
        label_df = label_df[label_df[counter_type] >= filter_value]
# draw bar chart for counter
draw_bar(label_df, "labels", counter_type, col2)
        # n-gram overlap between splits, drawn as a Venn diagram;
        # guard against datasets with fewer than 2 or more than 3 splits
        venn_fig, ax = plt.subplots()
        if len(ngram_counters) in (2, 3):
            union_counter = sum(ngram_counters, Counter())
            total = len(union_counter)
            # debug: ten most frequent n-grams per split
            for ngram_counter in ngram_counters:
                print(ngram_counter.most_common(10))
            ngram_counter_sets = [
                set(ngram_counter.keys()) for ngram_counter in ngram_counters
            ]
            venn = venn2 if len(ngram_counters) == 2 else venn3
            venn(
                ngram_counter_sets,
                tuple(dataset.keys()),
                set_colors=IBM_COLORS[: len(ngram_counters)],
                subset_label_formatter=lambda x: f"{(x / total):1.0%}",
            )
            venn_fig.suptitle(f"{N}-gram intersection for {data_name}", fontsize=20)
            st.pyplot(venn_fig)
        else:
            st.write("n-gram Venn diagram needs exactly 2 or 3 splits")
st.sidebar.button("Re-run")