import sentencepiece as spm
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.sequence import pad_sequences
from valx import detect_profanity, detect_hate_speech
import gradio as gr

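# Load the SentencePiece tokenizer used to encode prompts and decode generated pieces.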
sp = spm.SentencePieceProcessor()
sp.Load("dungen_dev_preview.model")

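# Load the trained Keras next-token prediction model.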
model = tf.keras.models.load_model("dungen_dev_preview_model.keras")

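# Sequence length the model was trained with; inputs are padded/truncated to max_seq_len - 1.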
max_seq_len = 25

def generate_text(seed_text, next_words=30, temperature=0.5):
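    """Generate text from the prompt, filtering both the prompt and the output with valx content checks."""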
    seed_text = seed_text.strip().lower()

    if "|" in seed_text:
        gr.Warning("The prompt should not contain the '|' character. Using default prompt.")
        seed_text = 'game name | '
    elif detect_profanity([seed_text], language='All'):
        gr.Warning("Profanity detected in the prompt, using the default prompt.")
        seed_text = 'game name | '
    elif (hate_speech_result := detect_hate_speech(seed_text)) and hate_speech_result[0] in ['Hate Speech', 'Offensive Speech']:
        gr.Warning('Harmful speech detected in the prompt, using default prompt.')
        seed_text = 'game name | '
    else:
        seed_text += ' | '

    generated_text = seed_text
    if generated_text != 'game name | ': # skip generation when the prompt was flagged and replaced with the default
        for _ in range(next_words):
            token_list = sp.encode_as_ids(generated_text)
            token_list = pad_sequences([token_list], maxlen=max_seq_len - 1, padding='pre')
            predicted = model.predict(token_list, verbose=0)[0]

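            # Apply temperature scaling in log space, then renormalize (softmax) into a sampling distribution.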
            predicted = np.asarray(predicted).astype("float64")
            predicted = np.log(predicted + 1e-8) / temperature
            exp_preds = np.exp(predicted)
            predicted = exp_preds / np.sum(exp_preds)

            # Sample the next token id; cast to a plain Python int for the SentencePiece API.
            next_index = int(np.random.choice(len(predicted), p=predicted))
            next_token = sp.id_to_piece(next_index)
            generated_text += next_token

            if next_token.endswith('</s>') or next_token.endswith('<unk>'):
                break

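    # Normalize the generated pieces back to plain text and strip special tokens.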
    decoded = sp.decode_pieces(sp.encode_as_pieces(generated_text))
    decoded = decoded.replace("</s>", "").replace("<unk>", "").strip()

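    # Keep only the part after the 'prompt | ' delimiter.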
    if '|' in decoded:
        decoded = decoded.split('|', 1)[1].strip()

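    # Run the generated output through the same profanity/hate-speech filters as the prompt.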
    if any(detect_profanity([decoded], language='All')) or (
        (hate_speech_result := detect_hate_speech(decoded))
        and hate_speech_result[0] in ['Hate Speech', 'Offensive Speech']
    ):
        gr.Warning("Flagged potentially harmful output.")
        decoded = 'Flagged Output'

    return decoded

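# Gradio UI: a prompt textbox plus sliders for word count and sampling temperature.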
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Prompt", value="a female character name", max_lines=1),
        gr.Slider(1, 100, step=1, label='Next Words', value=30),
        gr.Slider(0.1, 1, value=0.5, label='Temperature', info='Controls randomness of generation; higher values = more creative, lower values = more deterministic')
    ],
    outputs=gr.Textbox(label="Generated Names"),
    title='Dungen Dev - Name Generator',
    description='A prompt-based name generator for game developers. Dungen Dev is an experimental model and may produce outputs that are inappropriate, biased, inaccurate, or otherwise harmful. Caution is advised.',
    examples=[
        ["a male character name", 30, 0.5],
        ["a futuristic city name", 30, 0.5],
        ["an item name", 30, 0.5],
        ["a dark and mysterious forest name", 30, 0.5],
        ["an evil character name", 30, 0.5]
    ]
)

demo.launch()