|
import os |
|
import re |
|
import pandas as pd |
|
import plotly.express as px |
|
import streamlit as st |
|
|
|
# Use Streamlit's "wide" page layout for the whole app.
st.set_page_config(layout="wide")

# JSON dataset read by load_df().
DATA_FILE = "data/aclanthology2016-23_specter2_base.json"

# Colour scale handed to Plotly for each column selectable in the "Color" box.
THEMES = dict(
    cluster="fall",
    year="mint",
    source="phase",
)
|
|
|
|
|
def to_string_authors(list_of_authors):
    """Join author names into a single display string.

    * up to two names are joined with " and ";
    * three to five names are comma-separated, with ", and " before the last;
    * more than five names are cut to the first five followed by ", et al.".
    """
    count = len(list_of_authors)
    if count <= 2:
        return " and ".join(list_of_authors)
    if count <= 5:
        *leading, final = list_of_authors
        return ", ".join(leading) + ", and " + final
    return ", ".join(list_of_authors[:5]) + ", et al."
|
|
|
|
|
def load_df(data_file: os.PathLike):
    """Read the records-oriented JSON file and prepare it for plotting.

    The ``point2d`` pair of every record is split into ``x`` and ``y``
    columns, and ``authors`` is rendered into a display string stored in
    ``authors_trimmed``. When a ``publication_type`` column is present it is
    copied to ``type``. ``point2d`` (and ``publication_type``, if any) are
    dropped from the returned frame.
    """

    def _given_name_first(name):
        # "Family, Given" -> "Given Family"; names without a comma are kept.
        if "," not in name:
            return name
        family, _, given = name.partition(",")
        return given.strip() + " " + family.strip()

    def _format_authors(names):
        return to_string_authors([_given_name_first(name) for name in names])

    frame = pd.read_json(data_file, orient="records")
    frame["x"] = frame["point2d"].apply(lambda point: point[0])
    frame["y"] = frame["point2d"].apply(lambda point: point[1])

    frame["authors_trimmed"] = frame.authors.apply(_format_authors)

    obsolete = ["point2d"]
    if "publication_type" in frame.columns:
        frame["type"] = frame["publication_type"]
        obsolete.append("publication_type")
    return frame.drop(columns=obsolete)
|
|
|
|
|
@st.cache_data
def load_dataframe():
    """Load the dataset at ``DATA_FILE`` via ``load_df``.

    Decorated with ``st.cache_data`` so Streamlit can reuse the result.
    """
    frame = load_df(DATA_FILE)
    return frame
|
|
|
|
|
DF = load_dataframe()

# Every point starts almost transparent; rows that pass the sidebar filters
# are raised to full opacity further down.
DF["opacity"] = 0.04

# Bounds for the publication-year slider.
min_year = DF["year"].min()
max_year = DF["year"].max()
|
|
|
# Sidebar: collect the filter settings, build a boolean row mask from them and
# mark matching rows as fully opaque.
# NOTE(review): the extent of this ``with`` body was reconstructed from a
# paste that had lost its indentation -- confirm that the point counter and
# the "Color" selector are meant to live in the sidebar.
with st.sidebar:
    venue_options = ["ACL", "EMNLP", "NAACL", "TACL"]
    venues = st.multiselect("Venues", venue_options, venue_options)

    start_year, end_year = st.select_slider(
        "Publication year",
        options=[str(y) for y in range(min_year, max_year + 1)],
        value=(str(min_year), str(max_year)),
    )
    author_names = st.text_input("Author names (separated by comma)")

    title = st.text_input("Title")

    # The slider options are strings; convert back for numeric comparison.
    start_year = int(start_year)
    end_year = int(end_year)
    df_mask = (DF["year"] >= start_year) & (DF["year"] <= end_year)

    # Only restrict by venue for a proper subset of the options; with every
    # venue ticked no venue filter is applied at all.
    if 0 < len(venues) < len(venue_options):
        selected_venues = [v.lower() for v in venues]
        df_mask = df_mask & DF["source"].isin(selected_venues)
    elif not venues:
        st.write(":red[Please select a venue]")

    if author_names:
        # Drop empty tokens produced by stray or trailing commas.
        authors = [a.strip() for a in author_names.split(",") if a.strip()]
        # Escape the user's text so characters such as "(" or "+" are matched
        # literally instead of being parsed as regex syntax (which raised
        # re.error); compile once rather than once per row.
        author_patterns = [re.compile(re.escape(a), re.IGNORECASE) for a in authors]
        if author_patterns:
            # A paper matches when every query is a case-insensitive
            # substring of at least one of its author names.
            author_mask = DF.authors.apply(
                lambda row: all(any(p.search(x) for x in row) for p in author_patterns)
            )
            df_mask = df_mask & author_mask

    if title:
        title_query = title.lower()
        df_mask = df_mask & DF.title.apply(lambda x: title_query in x.lower())

    DF.loc[df_mask, "opacity"] = 1.0
    st.write(f"Number of points: {int(df_mask.sum())}")

    color = st.selectbox("Color", ("cluster", "year", "source"))
|
|
|
|
|
# Columns attached to every marker as customdata. The hover template below
# only reads the first four entries.
hover_columns = ["title", "authors_trimmed", "year", "source"]
if "type" in DF.columns:
    # "type" is only present when the dataset had a publication_type column.
    hover_columns.append("type")
# Keep opacity LAST so each trace can recover the opacity of its own points.
hover_columns.append("opacity")

fig = px.scatter(
    DF,
    x="x",
    y="y",
    color=color,
    width=1000,
    height=800,
    custom_data=hover_columns,
    color_continuous_scale=THEMES[color],
)


def _apply_point_opacity(trace):
    """Set per-point marker opacity from the trace's own customdata rows."""
    trace.marker.opacity = [float(row[-1]) for row in trace.customdata]


# px.scatter's ``opacity`` argument is a single float. Passing the whole
# column there hands every trace the full-length array, which is misaligned
# as soon as a discrete colour column splits the figure into several traces.
# Assigning opacity per trace keeps the highlighting aligned with the points.
fig.for_each_trace(_apply_point_opacity)

fig.update_traces(
    hovertemplate="<b>%{customdata[0]}</b><br>%{customdata[1]}<br>%{customdata[2]}<br><i>%{customdata[3]}</i>"
)
fig.update_layout(
    showlegend=False,
    font=dict(
        family="Times New Roman",
        size=30,
    ),
    hoverlabel=dict(
        align="left",
        font_size=14,
        font_family="Rockwell",
        namelength=-1,
    ),
)
# Axis titles are blanked out.
fig.update_xaxes(title="")
fig.update_yaxes(title="")

st.plotly_chart(fig, use_container_width=True)
|
|