# LLM-Open-Generation-Bias — pages/4_Demo_compute_by_domain.py
# Streamlit demo page: domain-wise comparison and selected-pairs comparisons
# of regard scores (commit 4e02702).
import streamlit as st
import pandas as pd
from datasets import load_dataset, Dataset
from random import sample
from utils.pairwise_comparison import one_regard_computation
import matplotlib.pyplot as plt
import os
# Set up the Streamlit interface: page heading shown above the password gate
# and the demo controls.
st.title('Gender Bias Analysis in Text Generation')
def check_password():
    """Render a password prompt and record the result in session state.

    The text input is registered under the widget key ``"password_input"`` so
    the ``on_click`` callback can read the *current* value from
    ``st.session_state``.  The original closed over the local
    ``password_input`` variable, which in a Streamlit callback holds the value
    from the previous script run, so the first submit always compared against
    stale (empty) text.

    Side effects:
        Sets ``st.session_state['password_correct']`` to True/False and shows
        an error message on a wrong or missing password.
    """

    def password_entered():
        # Read the live widget value; compare against the PASSWORD env var.
        if st.session_state.get('password_input') == os.getenv('PASSWORD'):
            st.session_state['password_correct'] = True
        else:
            # Record failure explicitly so the gate stays closed even if a
            # previous run had set the flag.
            st.session_state['password_correct'] = False
            st.error("Incorrect Password, please try again.")

    st.text_input("Enter Password:", type="password", key="password_input")
    submit_button = st.button("Submit", on_click=password_entered)
    if submit_button and not st.session_state.get('password_correct', False):
        st.error("Please enter a valid password to access the demo.")
# Guard clause: until check_password() has recorded success in session state,
# render the prompt and halt this script run.  st.stop() replaces the original
# if/else that nested the entire demo under the else branch.
if not st.session_state.get('password_correct', False):
    check_password()
    st.stop()  # nothing below runs while unauthenticated
st.sidebar.success("Password Verified. Proceed with the demo.")
# Default number of samples per category for the slider below.
if 'data_size' not in st.session_state:
    st.session_state['data_size'] = 10


def _expand_bold(bold_raw: pd.DataFrame) -> pd.DataFrame:
    """Explode each BOLD row into one row per (prompt, wikipedia) pair.

    Each raw row carries parallel lists in ``prompts`` and ``wikipedia``;
    the output has one record per pair, carrying over domain/name/category.
    Collects plain dicts and builds the DataFrame once — the original used
    the private ``DataFrame._append`` per pair, which is quadratic and was
    removed in pandas 2.0.
    """
    records = []
    for _, row in bold_raw.iterrows():
        for bold_prompt, bold_wikipedia in zip(row['prompts'], row['wikipedia']):
            records.append({'domain': row['domain'],
                            'name': row['name'],
                            'category': row['category'],
                            'prompts': bold_prompt,
                            'wikipedia': bold_wikipedia})
    return pd.DataFrame(records)


# Load and expand the BOLD dataset once per session; it is cached in
# session state so reruns skip the (slow) download and expansion.
if 'bold' not in st.session_state:
    bold_raw = pd.DataFrame(load_dataset("AlexaAI/bold", split="train"))
    st.session_state['bold'] = Dataset.from_pandas(_expand_bold(bold_raw))
# Let the user choose a BOLD domain and restrict the working set to it.
bold_df = pd.DataFrame(st.session_state['bold'])
domain = st.selectbox(
    "Select the domain",
    bold_df['domain'].unique())
domain_limited = [entry for entry in st.session_state['bold']
                  if entry['domain'] == domain]
# Samples drawn per category; the default comes from the session's data_size.
st.session_state['sample_size'] = st.slider(
    'Select number of samples per category:',
    min_value=1,
    max_value=50,
    value=st.session_state['data_size'])
if st.button('Compute'):
    answer_dict = {}
    category_list = pd.DataFrame(domain_limited)['category'].unique().tolist()
    # Per-category (positive - negative) regard gaps:
    ref_list = {}     # with the Wikipedia reference subtracted
    no_ref_list = {}  # raw GPT2 generations

    # Iterate category values directly (the original indexed with
    # range(len(...)) and kept an unused ``unique_pairs`` list).
    for o_one in category_list:
        pretty = o_one.replace("_", " ")
        with st.spinner(f'Computing regard results for {pretty}'):
            rmr = one_regard_computation(o_one, st.session_state['bold'],
                                         st.session_state['sample_size'])
            st.session_state['rmr'] = rmr  # keep the last result across reruns
            answer_dict[o_one] = rmr
            st.write(f'Regard results for {pretty} computed successfully.')
            ref_list[o_one] = (rmr['ref_diff_mean']['positive']
                               - rmr['ref_diff_mean']['negative'])
            no_ref_list[o_one] = (rmr['no_ref_diff_mean']['positive']
                                  - rmr['no_ref_diff_mean']['negative'])

        # Stacked bars: regard proportions of raw GPT2 output vs the Wikipedia
        # baseline (ref_diff_mean holds GPT2-minus-Wiki differences, so the
        # Wiki proportions are recovered by subtraction).
        categories = ['GPT2', 'Wiki']
        mp_gpt = rmr['no_ref_diff_mean']['positive']
        mn_gpt = rmr['no_ref_diff_mean']['negative']
        mo_gpt = 1 - (mp_gpt + mn_gpt)
        mp_wiki = mp_gpt - rmr['ref_diff_mean']['positive']
        mn_wiki = mn_gpt - rmr['ref_diff_mean']['negative']
        mo_wiki = 1 - (mn_wiki + mp_wiki)
        positive_m = [mp_gpt, mp_wiki]
        other_m = [mo_gpt, mo_wiki]
        negative_m = [mn_gpt, mn_wiki]

        fig_a, ax_a = plt.subplots()
        ax_a.bar(categories, negative_m, label='Negative', color='blue')
        ax_a.bar(categories, other_m, bottom=negative_m, label='Other',
                 color='orange')
        ax_a.bar(categories, positive_m,
                 bottom=[n + o for n, o in zip(negative_m, other_m)],
                 label='Positive', color='green')
        ax_a.set_ylabel('Proportion')
        ax_a.set_title(f'GPT2 vs Wiki on {pretty} regard')
        ax_a.legend()
        st.pyplot(fig_a)
        plt.close(fig_a)  # free the figure; repeated reruns otherwise leak memory

    domain_pretty = domain.replace("_", " ")

    st.subheader(f'The comparison of absolute regard value in {domain_pretty} by GPT2')
    st.bar_chart(no_ref_list)
    st.write(f'***Max difference of absolute regard values in the {domain_pretty}:***')
    key_max_no_ref = max(no_ref_list, key=no_ref_list.get)
    key_min_no_ref = min(no_ref_list, key=no_ref_list.get)
    # BUG FIX: this section reports the *absolute* (no-reference) values, but
    # the original printed max(ref_list) - min(ref_list) here.
    st.write(f' {key_max_no_ref.replace("_", " ")} regard - {key_min_no_ref.replace("_", " ")} regard ='
             f'{max(no_ref_list.values()) - min(no_ref_list.values())}')

    st.subheader(f'The comparison of regard value in {domain_pretty} with references to Wikipedia by GPT2')
    st.bar_chart(ref_list)
    st.write(f'***Max difference of regard values in the {domain_pretty} with references to Wikipedia:***')
    key_max_ref = max(ref_list, key=ref_list.get)
    key_min_ref = min(ref_list, key=ref_list.get)
    st.write(f' {key_max_ref.replace("_", " ")} regard - {key_min_ref.replace("_", " ")} regard = '
             f'{max(ref_list.values()) - min(ref_list.values())}')