import streamlit as st
import pandas as pd
from datasets import load_dataset, Dataset
from random import sample
from utils.pairwise_comparison import one_regard_computation
import matplotlib.pyplot as plt
import os
import hmac

# Column layout of the flattened BOLD table: one row per (prompt, wikipedia) pair.
BOLD_COLUMNS = ['domain', 'name', 'category', 'prompts', 'wikipedia']

# Set up the Streamlit interface
st.title('Gender Bias Analysis in Text Generation')


def check_password():
    """Render the password form and set ``st.session_state['password_correct']`` on success.

    The expected password is read from the ``PASSWORD`` environment variable; if it is
    unset, every submission is rejected.
    """

    def password_entered():
        # Read the entered value from session state rather than a closed-over local:
        # Streamlit updates widget state before running callbacks, while a local captured
        # in the previous script run can still hold the old (stale) input.
        entered = st.session_state.get("password", "")
        expected = os.getenv('PASSWORD')
        # compare_digest needs bytes to accept non-ASCII input; it also avoids
        # leaking how many leading characters matched through timing.
        if expected is not None and hmac.compare_digest(entered.encode(), expected.encode()):
            st.session_state['password_correct'] = True
            del st.session_state["password"]  # don't keep the plaintext around
        else:
            st.error("Incorrect Password, please try again.")

    st.text_input("Enter Password:", type="password", key="password")
    submit_button = st.button("Submit", on_click=password_entered)

    if submit_button and not st.session_state.get('password_correct', False):
        st.error("Please enter a valid password to access the demo.")


def _expand_bold(bold_raw):
    """Flatten the raw BOLD table into one row per (prompt, wikipedia) pair.

    Each raw row carries parallel ``prompts`` and ``wikipedia`` sequences; they are
    paired with ``zip`` (surplus entries of the longer one are dropped) and each pair is
    emitted with the row's ``domain``, ``name`` and ``category``. Records are collected
    in a list and converted once, instead of the quadratic, private
    ``DataFrame._append`` per pair.
    """
    records = [
        {'domain': row['domain'], 'name': row['name'], 'category': row['category'],
         'prompts': prompt, 'wikipedia': wikipedia}
        for _, row in bold_raw.iterrows()
        for prompt, wikipedia in zip(row['prompts'], row['wikipedia'])
    ]
    return pd.DataFrame(records, columns=BOLD_COLUMNS)


def _extreme_gap(scores):
    """Return ``(max_key, min_key, max_value - min_value)`` for a non-empty dict.

    Ties resolve to the first key in insertion order, as with ``max``/``min``.
    Raises ValueError if ``scores`` is empty.
    """
    max_key = max(scores, key=scores.get)
    min_key = min(scores, key=scores.get)
    return max_key, min_key, scores[max_key] - scores[min_key]


def _plot_regard_breakdown(rmr, label):
    """Render a stacked negative/other/positive regard bar chart for GPT2 vs Wiki.

    Only ``rmr['no_ref_diff_mean']`` and ``rmr['ref_diff_mean']`` (keys ``'positive'``
    and ``'negative'``) are read. GPT2 proportions are taken from ``no_ref_diff_mean``;
    Wiki proportions are recovered by subtracting ``ref_diff_mean`` from them, i.e. the
    code treats ``ref_diff_mean`` as GPT2 minus Wiki (TODO confirm against
    ``one_regard_computation``). "Other" is whatever remains of 1.
    """
    categories = ['GPT2', 'Wiki']
    no_ref = rmr['no_ref_diff_mean']
    ref = rmr['ref_diff_mean']

    mp_gpt, mn_gpt = no_ref['positive'], no_ref['negative']
    mo_gpt = 1 - (mp_gpt + mn_gpt)
    mp_wiki = mp_gpt - ref['positive']
    mn_wiki = mn_gpt - ref['negative']
    mo_wiki = 1 - (mn_wiki + mp_wiki)

    positive_m = [mp_gpt, mp_wiki]
    other_m = [mo_gpt, mo_wiki]
    negative_m = [mn_gpt, mn_wiki]

    fig, ax = plt.subplots()
    ax.bar(categories, negative_m, label='Negative', color='blue')
    ax.bar(categories, other_m, bottom=negative_m, label='Other', color='orange')
    ax.bar(categories, positive_m, bottom=[n + o for n, o in zip(negative_m, other_m)],
           label='Positive', color='green')
    ax.set_ylabel('Proportion')
    ax.set_title(f'GPT2 vs Wiki on {label} regard')
    ax.legend()

    st.pyplot(fig)
    # pyplot keeps every figure created by plt.subplots() alive until it is closed;
    # without this, figures accumulate across Compute presses and reruns.
    plt.close(fig)


if not st.session_state.get('password_correct', False):
    check_password()
else:
    st.sidebar.success("Password Verified. Proceed with the demo.")

    if 'data_size' not in st.session_state:
        st.session_state['data_size'] = 10
    if 'bold' not in st.session_state:
        # Expanded once per session and kept in session_state across reruns.
        bold_raw = pd.DataFrame(load_dataset("AlexaAI/bold", split="train"))
        st.session_state['bold'] = Dataset.from_pandas(_expand_bold(bold_raw))

    # One columnar conversion per run instead of iterating the Dataset row by row.
    bold_df = st.session_state['bold'].to_pandas()
    domain = st.selectbox(
        "Select the domain",
        bold_df['domain'].unique())

    st.session_state['sample_size'] = st.slider('Select number of samples per category:', min_value=1, max_value=50,
                                                value=st.session_state['data_size'])

    if st.button('Compute'):
        domain_label = domain.replace("_", " ")
        # Categories of the selected domain, in order of first appearance.
        category_list = bold_df.loc[bold_df['domain'] == domain, 'category'].unique().tolist()
        ref_list = {}     # category -> ref_diff_mean positive minus negative
        no_ref_list = {}  # category -> no_ref_diff_mean positive minus negative
        for category in category_list:
            label = category.replace("_", " ")
            with st.spinner(f'Computing regard results for {label}'):
                rmr = one_regard_computation(category, st.session_state['bold'],
                                             st.session_state['sample_size'])
                st.session_state['rmr'] = rmr
                st.write(f'Regard results for {label} computed successfully.')
                ref_list[category] = rmr['ref_diff_mean']['positive'] - rmr['ref_diff_mean']['negative']
                no_ref_list[category] = rmr['no_ref_diff_mean']['positive'] - rmr['no_ref_diff_mean']['negative']
                _plot_regard_breakdown(rmr, label)

        st.subheader(f'The comparison of absolute regard value in {domain_label} by GPT2')
        st.bar_chart(no_ref_list)
        st.write(f'***Max difference of absolute regard values in the {domain_label}:***')
        # Fixed: this gap was previously computed from ref_list, not no_ref_list.
        max_key, min_key, gap = _extreme_gap(no_ref_list)
        st.write(f' {max_key.replace("_", " ")} regard - {min_key.replace("_", " ")} regard = {gap}')

        st.subheader(f'The comparison of regard value in {domain_label} with references to Wikipedia by GPT2')
        st.bar_chart(ref_list)
        st.write(f'***Max difference of regard values in the {domain_label} with references to Wikipedia:***')
        max_key, min_key, gap = _extreme_gap(ref_list)
        st.write(f' {max_key.replace("_", " ")} regard - {min_key.replace("_", " ")} regard = {gap}')