File size: 3,088 Bytes
da27f68
 
 
 
 
 
 
 
 
 
 
 
 
 
610b4ea
 
da27f68
 
610b4ea
664181d
da27f68
 
664181d
610b4ea
 
 
da27f68
 
 
 
 
 
 
664181d
da27f68
 
 
 
 
 
 
 
 
 
 
664181d
 
 
 
 
 
610b4ea
 
664181d
 
610b4ea
 
 
 
da27f68
 
 
610b4ea
 
664181d
610b4ea
 
664181d
da27f68
664181d
 
78c5ea2
 
 
 
 
664181d
da27f68
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import sentencepiece as spm
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.sequence import pad_sequences
from valx import detect_profanity, detect_hate_speech
import gradio as gr

# Tokenizer: a SentencePiece processor loaded once at import time from a
# model file given by a bare relative path.
sp = spm.SentencePieceProcessor()
sp.Load("dungen_dev_preview.model")

# Keras model that generate_text queries for next-token predictions; also
# loaded once at import time.
model = tf.keras.models.load_model("dungen_dev_preview_model.keras")

# generate_text pads/truncates its token-id input to max_seq_len - 1 entries.
# NOTE(review): presumably this matches the window used in training - confirm.
max_seq_len = 25

def generate_text(seed_text, next_words=30, temperature=0.5):
    """Generate text for a prompt by sampling the model one token at a time.

    The prompt is lower-cased, suffixed with the ' | ' separator and screened
    for profanity / harmful speech; a flagged prompt is replaced by a default
    one. Tokens are then sampled with temperature scaling until a stop token
    appears or ``next_words`` tokens have been drawn. The prompt part is
    removed from the result and the remainder is screened again.

    Args:
        seed_text: User prompt. ``None`` is treated as an empty prompt.
        next_words: Maximum number of tokens to sample. Coerced to ``int``
            because UI sliders may deliver a float.
        temperature: Positive sampling temperature; the log-probabilities are
            divided by it before re-normalising.

    Returns:
        The generated text without the prompt, or ``'Flagged Output'`` when
        the output screening flags it.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """
    default_prompt = 'game name | '
    harmful_labels = ('Hate Speech', 'Offensive Speech')
    stop_tokens = ('</s>', '<unk>')

    next_words = int(next_words)
    temperature = float(temperature)
    if temperature <= 0:
        # Zero would divide by zero and a negative value would invert the
        # distribution; fail early with a clear message instead.
        raise ValueError("temperature must be greater than 0")

    seed_text = (seed_text or '').lower() + ' | '
    hate_speech = detect_hate_speech(seed_text)
    profanity = detect_profanity([seed_text], language='All')

    if profanity:
        gr.Warning("Profanity detected in the prompt, using the default prompt.")
        seed_text = default_prompt
    elif hate_speech and hate_speech[0] in harmful_labels:
        gr.Warning('Harmful speech detected in the seed text, using default prompt.')
        seed_text = default_prompt

    generated_text = seed_text
    for _ in range(next_words):
        # Re-encode the running text and keep the trailing max_seq_len - 1
        # ids (pad_sequences left-pads short inputs).
        token_ids = sp.encode_as_ids(generated_text)
        batch = pad_sequences([token_ids], maxlen=max_seq_len - 1, padding='pre')
        # NOTE(review): assumes predict() returns one probability vector per
        # batch row, so [0] is a 1-D distribution over token ids - confirm.
        probs = np.asarray(model.predict(batch, verbose=0)[0], dtype="float64")

        # Temperature scaling in log space. Subtracting the maximum before
        # exponentiating leaves the normalised result unchanged but prevents
        # every weight underflowing to zero at small temperatures.
        scaled = np.log(probs + 1e-8) / temperature
        scaled -= scaled.max()
        weights = np.exp(scaled)
        probs = weights / weights.sum()

        # Cast to a built-in int: np.random.choice yields a NumPy integer,
        # which the tokenizer's native binding may not accept.
        next_index = int(np.random.choice(len(probs), p=probs))
        next_token = sp.id_to_piece(next_index)

        if next_token.endswith(stop_tokens):
            # Stop without appending the marker so it never reaches the
            # encode/decode round trip below.
            break
        generated_text += next_token

    decoded = sp.decode_pieces(sp.encode_as_pieces(generated_text))
    decoded = decoded.replace("</s>", "").replace("<unk>", "").strip()

    # Remove the prompt from the generated text. The prompt may itself contain
    # '|' characters, so skip as many separators as the prompt holds rather
    # than cutting at the first one.
    separator_count = seed_text.count('|')
    segments = decoded.split('|', separator_count)
    if len(segments) > separator_count:
        decoded = segments[-1].strip()
    elif '|' in decoded:
        # Fewer separators survived decoding than expected; fall back to
        # cutting at the first one.
        decoded = decoded.split('|', 1)[1].strip()

    hate_speech2 = detect_hate_speech(decoded)
    profanity2 = detect_profanity([decoded], language='All')

    if profanity2 or (hate_speech2 and hate_speech2[0] in harmful_labels):
        gr.Warning("Flagged potentially harmful output.")
        decoded = 'Flagged Output'

    return decoded

# Gradio UI: a prompt box plus two sliders feed generate_text positionally
# (prompt, next_words, temperature); each example row follows the same order.
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Prompt", value="a female character name", max_lines=1),
        gr.Slider(1, 50, step=1, label='Next Words', value=30),
        gr.Slider(
            0.1,
            1,
            value=0.5,
            label='Temperature',
            info='Controls randomness of generation, higher values = more creative, lower values = more probabilistic',
        ),
    ],
    outputs=gr.Textbox(label="Generated Names"),
    title='Dungen Dev - Name Generator',
    description='A prompt-based name generator for game developers.',
    examples=[
        ["a male elf name", 30, 0.5],
        ["a futuristic city name", 30, 0.5],
        ["a powerful magic item name", 30, 0.5],
        ["a dark and mysterious forest name", 30, 0.5],
        ["a female character name", 30, 0.5],
    ],
)

# Start the web server as soon as the module is executed.
demo.launch()