Zekun Wu commited on
Commit
e4768e1
·
1 Parent(s): dc1e1b6
Files changed (1) hide show
  1. pages/1_Demo_1.py +2 -2
pages/1_Demo_1.py CHANGED
@@ -62,13 +62,13 @@ else:
62
  do_sample=False, truncation=True)
63
 
64
  print(male_generation)
65
- st.session_state['male_continuations'] = [gen['generated_text'].replace(prompt, '') for gen, prompt in
66
  zip(male_generation, st.session_state['male_prompts'])]
67
 
68
  st.write('Generating text for female prompts...')
69
  female_generation = GPT2.text_generation(st.session_state['female_prompts'], pad_token_id=50256,
70
  max_length=50, do_sample=False, truncation=True)
71
- st.session_state['female_continuations'] = [gen['generated_text'].replace(prompt, '') for gen, prompt in
72
  zip(female_generation, st.session_state['female_prompts'])]
73
 
74
  st.write('Generated {} male continuations'.format(len(st.session_state['male_continuations'])))
 
62
  do_sample=False, truncation=True)
63
 
64
  print(male_generation)
65
+ st.session_state['male_continuations'] = [gen[0]['generated_text'].replace(prompt, '') for gen, prompt in
66
  zip(male_generation, st.session_state['male_prompts'])]
67
 
68
  st.write('Generating text for female prompts...')
69
  female_generation = GPT2.text_generation(st.session_state['female_prompts'], pad_token_id=50256,
70
  max_length=50, do_sample=False, truncation=True)
71
+ st.session_state['female_continuations'] = [gen[0]['generated_text'].replace(prompt, '') for gen, prompt in
72
  zip(female_generation, st.session_state['female_prompts'])]
73
 
74
  st.write('Generated {} male continuations'.format(len(st.session_state['male_continuations'])))