Zekun Wu
		
	committed on
		
		
					Commit 
							
							·
						
						44466c7
	
1
								Parent(s):
							
							1da3bb7
								
update
Browse files- pages/1_Demo_1.py +63 -0
 - requirements.txt +3 -1
 - utils/__init__.py +0 -0
 - utils/dataset.py +0 -0
 - utils/metric.py +55 -0
 - utils/model.py +19 -0
 
    	
        pages/1_Demo_1.py
    CHANGED
    
    | 
         @@ -0,0 +1,63 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import streamlit as st
         
     | 
| 2 | 
         
            +
            import pandas as pd
         
     | 
| 3 | 
         
            +
            from datasets import load_dataset
         
     | 
| 4 | 
         
            +
            from random import sample
         
     | 
| 5 | 
         
            +
            from utils.metric import Regard
         
     | 
| 6 | 
         
            +
            from utils.model import gpt2
         
     | 
| 7 | 
         
            +
            import os
         
     | 
| 8 | 
         
            +
             
     | 
| 9 | 
         
            +
            # Set up the Streamlit interface
         
     | 
| 10 | 
         
            +
            st.title('Gender Bias Analysis in Text Generation')
         
     | 
| 11 | 
         
            +
             
     | 
| 12 | 
         
            +
             
     | 
def check_password():
    """Gate the demo behind a password read from the ``PASSWORD`` env var.

    Renders a password input plus a Submit button; on a correct submission
    the session flag ``password_correct`` is set and the demo runs, otherwise
    an error message is shown.
    """

    def password_entered():
        # Button callback: compare the entered value against the PASSWORD
        # environment variable and record success in session state.
        if password_input == os.getenv('PASSWORD'):
            st.session_state['password_correct'] = True
        else:
            st.error("Incorrect Password, please try again.")

    password_input = st.text_input("Enter Password:", type="password")
    # Fix: the button's return value was bound to an unused local
    # (`submit_button`); only the on_click callback matters here.
    st.button("Submit", on_click=password_entered)

    if st.session_state.get('password_correct', False):
        load_and_process_data()
    else:
        st.error("Please enter a valid password to access the demo.")
| 27 | 
         
            +
             
     | 
| 28 | 
         
            +
             
     | 
def load_and_process_data():
    """Run the gender-bias demo.

    Samples prompts from the BOLD dataset for American actresses/actors,
    generates GPT-2 continuations for each group, shows samples, and
    displays a regard comparison between the two groups of continuations.
    """
    st.subheader('Loading and Processing Data')
    st.write('Loading the BOLD dataset...')
    bold = load_dataset("AlexaAI/bold", split="train")

    st.write('Sampling 10 female and male American actors...')
    female_bold = sample([p for p in bold if p['category'] == 'American_actresses'], 10)
    male_bold = sample([p for p in bold if p['category'] == 'American_actors'], 10)

    # Only the first prompt per person is used, keeping the demo fast.
    male_prompts = [p['prompts'][0] for p in male_bold]
    female_prompts = [p['prompts'][0] for p in female_bold]

    GPT2 = gpt2()

    st.write('Generating text for male prompts...')
    male_generation = GPT2.generate_text(male_prompts, pad_token_id=50256, max_length=50, do_sample=False)
    # Strip the prompt from each generation so only the continuation remains.
    male_continuations = [gen.replace(prompt, '') for gen, prompt in zip(male_generation, male_prompts)]

    st.write('Generating text for female prompts...')
    female_generation = GPT2.generate_text(female_prompts, pad_token_id=50256, max_length=50, do_sample=False)
    female_continuations = [gen.replace(prompt, '') for gen, prompt in zip(female_generation, female_prompts)]

    st.write('Generated {} male continuations'.format(len(male_continuations)))
    st.write('Generated {} female continuations'.format(len(female_continuations)))

    st.subheader('Sample Generated Texts')
    st.write('**Male Prompt:**', male_prompts[0])
    st.write('**Male Continuation:**', male_continuations[0])
    st.write('**Female Prompt:**', female_prompts[0])
    st.write('**Female Continuation:**', female_continuations[0])

    regard = Regard("compare")
    st.write('Computing regard results to compare male and female continuations...')
    regard_results = regard.compute(data=male_continuations, references=female_continuations)
    # Fix: the original final statement was truncated ("st.write(regard"),
    # a syntax error; display the computed comparison results.
    st.write(regard_results)
    	
        requirements.txt
    CHANGED
    
    | 
         @@ -1 +1,3 @@ 
     | 
|
| 1 | 
         
            -
            openai
         
     | 
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            openai
         
     | 
| 2 | 
         
            +
            transformers
         
     | 
| 3 | 
         
            +
            torch==2.0.1
         
     | 
    	
        utils/__init__.py
    ADDED
    
    | 
         
            File without changes
         
     | 
    	
        utils/dataset.py
    ADDED
    
    | 
         
            File without changes
         
     | 
    	
        utils/metric.py
    ADDED
    
    | 
         @@ -0,0 +1,55 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            from collections import defaultdict
         
     | 
| 2 | 
         
            +
            from statistics import mean
         
     | 
| 3 | 
         
            +
            from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
         
     | 
| 4 | 
         
            +
             
     | 
| 5 | 
         
            +
             
     | 
| 6 | 
         
            +
             
     | 
class Regard:
    """Regard metric backed by the ``sasha/regardv3`` text classifier.

    With ``config_name == "compare"``, :meth:`compute` scores *data* against
    *references*; with any other config it scores *data* alone.
    """

    def __init__(self, config_name):
        self.config_name = config_name
        tok = AutoTokenizer.from_pretrained("sasha/regardv3")
        mdl = AutoModelForSequenceClassification.from_pretrained("sasha/regardv3")
        # top_k=4 returns the score for every regard label per input text.
        self.regard_classifier = pipeline(
            "text-classification", model=mdl, top_k=4, tokenizer=tok, truncation=True)

    def regard(self, group):
        """Classify *group* and return ``(raw predictions, {label: [scores]})``."""
        predictions = self.regard_classifier(group)
        by_label = defaultdict(list)
        for prediction in predictions:
            for entry in prediction:
                by_label[entry["label"]].append(entry["score"])
        return predictions, dict(by_label)

    def compute(
        self,
        data,
        references=None,
        aggregation=None,
    ):
        """Score *data* (and *references* in "compare" mode).

        ``aggregation`` may be ``"maximum"``, ``"average"``, or ``None``
        (per-label difference in compare mode, raw scores otherwise).
        """
        raw_preds, pred_label_scores = self.regard(data)
        pred_mean = {label: mean(scores) for label, scores in pred_label_scores.items()}
        pred_max = {label: max(scores) for label, scores in pred_label_scores.items()}

        if self.config_name == "compare":
            _, ref_label_scores = self.regard(references)
            ref_mean = {label: mean(scores) for label, scores in ref_label_scores.items()}
            ref_max = {label: max(scores) for label, scores in ref_label_scores.items()}
            if aggregation == "maximum":
                return {
                    "max_data_regard": pred_max,
                    "max_references_regard": ref_max,
                }
            if aggregation == "average":
                return {"average_data_regard": pred_mean, "average_references_regard": ref_mean}
            return {"regard_difference": {key: pred_mean[key] - ref_mean.get(key, 0) for key in pred_mean}}

        if aggregation == "maximum":
            return {"max_regard": pred_max}
        if aggregation == "average":
            return {"average_regard": pred_mean}
        return {"regard": raw_preds}
    	
        utils/model.py
    ADDED
    
    | 
         @@ -0,0 +1,19 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            from transformers import pipeline, AutoTokenizer
         
     | 
| 2 | 
         
            +
             
     | 
| 3 | 
         
            +
             
     | 
class gpt2:
    """Thin wrapper around the Hugging Face GPT-2 text-generation pipeline."""

    def __init__(self, device="cpu"):
        # `device` is forwarded to transformers.pipeline ("cpu", "cuda", or
        # a device index).
        self.text_generation = pipeline("text-generation", model="gpt2", device=device)
        self.tokenizer = AutoTokenizer.from_pretrained("gpt2")

    def generate_text(self, texts=None, **kwargs):
        """Generate one continuation per input prompt.

        Parameters
        ----------
        texts : str | list[str] | None
            Prompt(s) to continue. Accepted positionally — the original
            ``**kwargs``-only signature raised a TypeError for callers that
            pass the prompt list positionally (as pages/1_Demo_1.py does).
        **kwargs
            Extra pipeline arguments (e.g. ``max_length``, ``do_sample``,
            ``pad_token_id``).

        Returns
        -------
        list[str]
            The first generated text for each prompt. The original
            ``results[0]`` indexing returned only the candidates for the
            FIRST prompt; for a list of prompts the pipeline yields one
            candidate list per prompt, so we take each prompt's first
            candidate instead.
        """
        if texts is None:
            results = self.text_generation(**kwargs)
        else:
            results = self.text_generation(texts, **kwargs)
        generated = []
        for item in results:
            # List input -> list of candidate dicts per prompt; a single
            # string input yields dicts directly.
            if isinstance(item, list):
                generated.append(item[0]['generated_text'])
            else:
                generated.append(item['generated_text'])
        return generated

    def get_tokenizer(self):
        """Return the GPT-2 tokenizer backing the pipeline."""
        return self.tokenizer
| 16 | 
         
            +
             
     | 
if __name__ == '__main__':
    # Smoke test. Fix: bind the instance to a distinct name — the original
    # `gpt2 = gpt2()` shadowed the class with its own instance.
    model = gpt2()
    print(model.generate_text(["Hello, how are you?","I am fine, thank you."]))