import streamlit as st
import pandas as pd
from datasets import load_dataset, Dataset
from utils.pairwise_comparison import one_regard_computation
import matplotlib.pyplot as plt
import os

st.title('Gender Bias Analysis in Text Generation')

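
# A minimal password gate: the Submit button's on_click callback runs at the
# start of the next rerun, checks the typed value against the PASSWORD
# environment variable, and records the outcome in st.session_state so it
# survives later reruns.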
def check_password():
    def password_entered():
        # Runs when Submit is clicked: compare the submitted value (stored
        # under the text_input's key) with the PASSWORD environment variable.
        if st.session_state.get('password') == os.getenv('PASSWORD'):
            st.session_state['password_correct'] = True
        else:
            st.error("Incorrect password, please try again.")

    st.text_input("Enter password:", type="password", key='password')
    submit_button = st.button("Submit", on_click=password_entered)

    if submit_button and not st.session_state.get('password_correct', False):
        st.error("Please enter a valid password to access the demo.")


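# Render the demo only after the password flag has been set; everything else
# stays inside the else branch below.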
if not st.session_state.get('password_correct', False):
    check_password()
else:
    st.sidebar.success("Password verified. Proceed with the demo.")

    if 'data_size' not in st.session_state:
        st.session_state['data_size'] = 10
    if 'bold' not in st.session_state:
        # Each BOLD record stores parallel lists of prompts and Wikipedia
        # sentences; flatten them into one row per (prompt, wikipedia) pair
        # and cache the result so the dataset is only built once per session.
        bold_raw = pd.DataFrame(load_dataset("AlexaAI/bold", split="train"))
        rows = []
        for _, row in bold_raw.iterrows():
            for bold_prompt, bold_wikipedia in zip(row['prompts'], row['wikipedia']):
                rows.append({'domain': row['domain'], 'name': row['name'],
                             'category': row['category'], 'prompts': bold_prompt,
                             'wikipedia': bold_wikipedia})
        st.session_state['bold'] = Dataset.from_pandas(pd.DataFrame(rows))
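
    # Narrow the flattened BOLD data to a single domain (e.g. gender); the
    # categories within that domain are what get compared below.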
    domain = st.selectbox(
        "Select the domain",
        pd.DataFrame(st.session_state['bold'])['domain'].unique())
    domain_limited = [p for p in st.session_state['bold'] if p['domain'] == domain]

    st.session_state['sample_size'] = st.slider('Select number of samples per category:',
                                                min_value=1, max_value=50,
                                                value=st.session_state['data_size'])

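    # Regard computation triggers text generation, so it runs only on demand.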
    if st.button('Compute'):
        answer_dict = {}
        category_list = pd.DataFrame(domain_limited)['category'].unique().tolist()
        ref_list = {}
        no_ref_list = {}
        for o_one in category_list:
            with st.spinner(f'Computing regard results for {o_one.replace("_", " ")}'):
                st.session_state['rmr'] = one_regard_computation(o_one, st.session_state['bold'],
                                                                 st.session_state['sample_size'])
            answer_dict[o_one] = st.session_state['rmr']
            st.write(f'Regard results for {o_one.replace("_", " ")} computed successfully.')

            # Net regard per category: mean positive share minus mean negative share.
            ref_list[o_one] = st.session_state['rmr']['ref_diff_mean']['positive'] \
                - st.session_state['rmr']['ref_diff_mean']['negative']
            no_ref_list[o_one] = st.session_state['rmr']['no_ref_diff_mean']['positive'] \
                - st.session_state['rmr']['no_ref_diff_mean']['negative']
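
            # Decompose mean regard into positive/negative/other shares and plot
            # GPT-2 against the Wikipedia reference. Assumed semantics of the
            # keys: 'no_ref_diff_mean' holds GPT-2's own regard distribution,
            # and subtracting 'ref_diff_mean' (the GPT-2-minus-Wiki difference)
            # recovers the Wikipedia baseline.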
            categories = ['GPT2', 'Wiki']
            mp_gpt = st.session_state['rmr']['no_ref_diff_mean']['positive']
            mn_gpt = st.session_state['rmr']['no_ref_diff_mean']['negative']
            mo_gpt = 1 - (mp_gpt + mn_gpt)

            mp_wiki = mp_gpt - st.session_state['rmr']['ref_diff_mean']['positive']
            mn_wiki = mn_gpt - st.session_state['rmr']['ref_diff_mean']['negative']
            mo_wiki = 1 - (mp_wiki + mn_wiki)

            positive_m = [mp_gpt, mp_wiki]
            other_m = [mo_gpt, mo_wiki]
            negative_m = [mn_gpt, mn_wiki]

            # Stacked bar: negative at the bottom, then other, then positive.
            fig_a, ax_a = plt.subplots()
            ax_a.bar(categories, negative_m, label='Negative', color='blue')
            ax_a.bar(categories, other_m, bottom=negative_m, label='Other', color='orange')
            ax_a.bar(categories, positive_m,
                     bottom=[negative_m[i] + other_m[i] for i in range(len(negative_m))],
                     label='Positive', color='green')
            ax_a.set_ylabel('Proportion')
            ax_a.set_title(f'GPT2 vs Wiki on {o_one.replace("_", " ")} regard')
            ax_a.legend()

            st.pyplot(fig_a)

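        # After the per-category loop, summarise the spread of net regard
        # across categories, first without the Wikipedia reference.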
        st.subheader(f'The comparison of absolute regard value in {domain.replace("_", " ")} by GPT2')
        st.bar_chart(no_ref_list)
        st.write(f'***Max difference of absolute regard values in the {domain.replace("_", " ")}:***')
        keys_with_max_value_no_ref = max(no_ref_list, key=no_ref_list.get)
        keys_with_min_value_no_ref = min(no_ref_list, key=no_ref_list.get)
        st.write(f'{keys_with_max_value_no_ref.replace("_", " ")} regard - '
                 f'{keys_with_min_value_no_ref.replace("_", " ")} regard = '
                 f'{max(no_ref_list.values()) - min(no_ref_list.values())}')

        st.subheader(f'The comparison of regard value in {domain.replace("_", " ")} with references to Wikipedia by GPT2')
        st.bar_chart(ref_list)
        st.write(f'***Max difference of regard values in the {domain.replace("_", " ")} with references to Wikipedia:***')
        keys_with_max_value_ref = max(ref_list, key=ref_list.get)
        keys_with_min_value_ref = min(ref_list, key=ref_list.get)
        st.write(f'{keys_with_max_value_ref.replace("_", " ")} regard - '
                 f'{keys_with_min_value_ref.replace("_", " ")} regard = '
                 f'{max(ref_list.values()) - min(ref_list.values())}')