JohanBeytell commited on
Commit
feac64d
·
verified ·
1 Parent(s): 238d2c6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -27
app.py CHANGED
@@ -13,47 +13,46 @@ model = tf.keras.models.load_model("dungen_dev_preview_model.keras")
13
  max_seq_len = 25
14
 
15
  def generate_text(seed_text, next_words=30, temperature=0.5):
16
- seed_text = seed_text.strip()
17
- seed_text = seed_text.lower() + ' | '
18
- hate_speech = detect_hate_speech(seed_text)
19
- profanity = detect_profanity([seed_text], language='All')
20
 
21
- if profanity:
 
 
 
22
  gr.Warning("Profanity detected in the prompt, using the default prompt.")
23
  seed_text = 'game name | '
24
- elif hate_speech and hate_speech[0] in ['Hate Speech', 'Offensive Speech']:
25
- gr.Warning('Harmful speech detected in the seed text, using default prompt.')
26
  seed_text = 'game name | '
 
 
27
 
28
  generated_text = seed_text
29
- for _ in range(next_words):
30
- token_list = sp.encode_as_ids(generated_text)
31
- token_list = pad_sequences([token_list], maxlen=max_seq_len - 1, padding='pre')
32
- predicted = model.predict(token_list, verbose=0)[0]
 
33
 
34
- predicted = np.asarray(predicted).astype("float64")
35
- predicted = np.log(predicted + 1e-8) / temperature
36
- exp_preds = np.exp(predicted)
37
- predicted = exp_preds / np.sum(exp_preds)
38
 
39
- next_index = np.random.choice(len(predicted), p=predicted)
40
- next_token = sp.id_to_piece(next_index)
41
- generated_text += next_token
42
 
43
- if next_token.endswith('</s>') or next_token.endswith('<unk>'):
44
- break
45
 
46
  decoded = sp.decode_pieces(sp.encode_as_pieces(generated_text))
47
  decoded = decoded.replace("</s>", "").replace("<unk>", "").strip()
48
 
49
- # Remove the prompt from the generated text
50
  if '|' in decoded:
51
- decoded = decoded.split('|', 1)[1].strip() #Split at the first occurence of '|' and take the second part
52
-
53
- hate_speech2 = detect_hate_speech(decoded)
54
- profanity2 = detect_profanity([decoded], language='All')
55
 
56
- if profanity2 or (hate_speech2 and hate_speech2[0] in ['Hate Speech', 'Offensive Speech']):
57
  gr.Warning("Flagged potentially harmful output.")
58
  decoded = 'Flagged Output'
59
 
@@ -63,7 +62,7 @@ demo = gr.Interface(
63
  fn=generate_text,
64
  inputs=[
65
  gr.Textbox(label="Prompt", value="a female character name", max_lines=1),
66
- gr.Slider(1, 50, step=1, label='Next Words', value=30),
67
  gr.Slider(0.1, 1, value=0.5, label='Temperature', info='Controls randomness of generation, higher values = more creative, lower values = more probalistic')
68
  ],
69
  outputs=gr.Textbox(label="Generated Names"),
 
13
  max_seq_len = 25
14
 
15
  def generate_text(seed_text, next_words=30, temperature=0.5):
16
+ seed_text = seed_text.strip().lower()
 
 
 
17
 
18
+ if "|" in seed_text: # check for | in seed_text
19
+ gr.Warning("The prompt should not contain the '|' character. Using default prompt.")
20
+ seed_text = 'game name | '
21
+ elif detect_profanity([seed_text], language='All'):
22
  gr.Warning("Profanity detected in the prompt, using the default prompt.")
23
  seed_text = 'game name | '
24
+ elif detect_hate_speech(seed_text) and detect_hate_speech(seed_text)[0] in ['Hate Speech', 'Offensive Speech']:
25
+ gr.Warning('Harmful speech detected in the prompt, using default prompt.')
26
  seed_text = 'game name | '
27
+ else:
28
+ seed_text += ' | '
29
 
30
  generated_text = seed_text
31
+ if generated_text != 'game name | ': # only generate if not the default prompt
32
+ for _ in range(next_words):
33
+ token_list = sp.encode_as_ids(generated_text)
34
+ token_list = pad_sequences([token_list], maxlen=max_seq_len - 1, padding='pre')
35
+ predicted = model.predict(token_list, verbose=0)[0]
36
 
37
+ predicted = np.asarray(predicted).astype("float64")
38
+ predicted = np.log(predicted + 1e-8) / temperature
39
+ exp_preds = np.exp(predicted)
40
+ predicted = exp_preds / np.sum(exp_preds)
41
 
42
+ next_index = np.random.choice(len(predicted), p=predicted)
43
+ next_token = sp.id_to_piece(next_index)
44
+ generated_text += next_token
45
 
46
+ if next_token.endswith('</s>') or next_token.endswith('<unk>'):
47
+ break
48
 
49
  decoded = sp.decode_pieces(sp.encode_as_pieces(generated_text))
50
  decoded = decoded.replace("</s>", "").replace("<unk>", "").strip()
51
 
 
52
  if '|' in decoded:
53
+ decoded = decoded.split('|', 1)[1].strip()
 
 
 
54
 
55
+ if detect_profanity([decoded], language='All') or (detect_hate_speech(decoded) and detect_hate_speech(decoded)[0] in ['Hate Speech', 'Offensive Speech']):
56
  gr.Warning("Flagged potentially harmful output.")
57
  decoded = 'Flagged Output'
58
 
 
62
  fn=generate_text,
63
  inputs=[
64
  gr.Textbox(label="Prompt", value="a female character name", max_lines=1),
65
+ gr.Slider(1, 100, step=1, label='Next Words', value=30),
66
  gr.Slider(0.1, 1, value=0.5, label='Temperature', info='Controls randomness of generation, higher values = more creative, lower values = more probalistic')
67
  ],
68
  outputs=gr.Textbox(label="Generated Names"),