JohanBeytell commited on
Commit
da27f68
·
verified ·
1 Parent(s): d0f0904

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -0
app.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sentencepiece as spm
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ from tensorflow.keras.preprocessing.sequence import pad_sequences
5
+ import re
6
+ from valx import detect_profanity, detect_hate_speech
7
+ import gradio as gr
8
+
9
+ sp = spm.SentencePieceProcessor()
10
+ sp.Load("dungen_dev_preview.model")
11
+
12
+ model = tf.keras.models.load_model("dungen_dev_preview_model.keras")
13
+
14
+ max_seq_len = 25
15
+
16
+ def generate_text(seed_text, next_words=5, temperature=1.0):
17
+ seed_text = seed_text.lowercase() + ' | '
18
+ hate_speech = detect_hate_speech(seed_text)
19
+ profanity = detect_profanity([seed_text], language='All')
20
+
21
+ if len(profanity) > 0:
22
+ gr.Warning("Profanity detected in the prompt, using the default prompt.")
23
+ seed_text = 'game name | '
24
+ else:
25
+ if hate_speech == ['Hate Speech']:
26
+ gr.Warning('Hate speech detected in the seed text, using an empty seed text.')
27
+ seed_text = 'game name | '
28
+ elif hate_speech == ['Offensive Speech']:
29
+ gr.Warning('Offensive speech detected in the seed text, using an empty seed text.')
30
+ seed_text = 'game name | '
31
+
32
+ generated_text = seed_text
33
+ for _ in range(next_words):
34
+ token_list = sp.encode_as_ids(generated_text)
35
+ token_list = pad_sequences([token_list], maxlen=max_seq_len - 1, padding='pre')
36
+ predicted = model.predict(token_list, verbose=0)[0]
37
+
38
+ # Apply temperature
39
+ predicted = np.asarray(predicted).astype("float64")
40
+ predicted = np.log(predicted) / temperature
41
+ exp_preds = np.exp(predicted)
42
+ predicted = exp_preds / np.sum(exp_preds)
43
+
44
+ next_index = np.random.choice(len(predicted), p=predicted)
45
+ next_token = sp.id_to_piece(next_index)
46
+ generated_text += next_token
47
+
48
+ if next_token.endswith('</s>') or next_token.endswith('<unk>'):
49
+ break
50
+
51
+ decoded = sp.decode_pieces(sp.encode_as_pieces(generated_text))
52
+ decoded = decoded.replace("</s>", "")
53
+ decoded = decoded.replace("<unk>", "")
54
+ cleaned_text = decoded.strip()
55
+
56
+ hate_speech2 = detect_hate_speech(cleaned_text)
57
+ profanity2 = detect_profanity([cleaned_text], language='All')
58
+
59
+ if len(profanity2) > 0:
60
+ gr.Warning("Flagged potentially harmful output.")
61
+ cleaned_text = 'Flagged Output'
62
+ else:
63
+ if hate_speech2 == ['Hate Speech']:
64
+ gr.Warning('Flagged potentially harmful output.')
65
+ cleaned_text = 'Flagged Output'
66
+ elif hate_speech2 == ['Offensive Speech']:
67
+ gr.Warning('Flagged potentially harmful output.')
68
+ cleaned_text = 'Flagged Output'
69
+
70
+ return cleaned_text
71
+
72
+ demo = gr.Interface(
73
+ fn=generate_text,
74
+ inputs=[label="Prompt", value="a female character name", max_lines=1), gr.Slider(1,100, step=1, label='Next Words', value=30), gr.Slider(0.1, 1, value=0.5, label='Temperature', info='Controls randomness of generation, higher values = more creative, lower values = more probalistic')],
75
+ outputs=[gr.Dataframe(row_count = (2, "dynamic"), col_count=(1, "fixed"), label="Generated Names", headers=["Names"])],
76
+ title='Dungen Dev - Name Generator',
77
+ description='A prompt-based name generator for game developers.'
78
+ )
79
+
80
+ demo.launch()