Spaces:
Running
Running
File size: 3,393 Bytes
da27f68 610b4ea feac64d 610b4ea 261d6fa feac64d da27f68 261d6fa feac64d 610b4ea feac64d 610b4ea da27f68 feac64d da27f68 feac64d da27f68 feac64d da27f68 feac64d da27f68 664181d feac64d 664181d 261d6fa 610b4ea da27f68 610b4ea feac64d 610b4ea 664181d da27f68 261d6fa 664181d 238d2c6 78c5ea2 238d2c6 78c5ea2 238d2c6 664181d da27f68 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
import sentencepiece as spm
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.sequence import pad_sequences
from valx import detect_profanity, detect_hate_speech
import gradio as gr
# Sequence length used by the generator; model inputs are built from at most
# ``max_seq_len - 1`` token ids.
max_seq_len = 25

# On-disk artifacts, resolved relative to the working directory.
_TOKENIZER_PATH = "dungen_dev_preview.model"
_MODEL_PATH = "dungen_dev_preview_model.keras"

# SentencePiece tokenizer shared by prompt encoding and output decoding.
sp = spm.SentencePieceProcessor()
sp.Load(_TOKENIZER_PATH)

# Keras model whose output is sampled as a distribution over next token ids.
model = tf.keras.models.load_model(_MODEL_PATH)
# Labels returned by ``detect_hate_speech`` that cause a prompt or an output
# to be rejected.
_HARMFUL_SPEECH_LABELS = ("Hate Speech", "Offensive Speech")


def _is_profane(text):
    """Return True if valx reports any profanity in *text* (all languages)."""
    # ``any`` keeps the check independent of whether the detector returns a
    # list of findings or a list of per-sentence flags.
    return any(detect_profanity([text], language='All'))


def _is_hate_speech(text):
    """Return True if valx classifies *text* as hate or offensive speech."""
    result = detect_hate_speech(text)
    return bool(result) and result[0] in _HARMFUL_SPEECH_LABELS


def _sample_index(probs, temperature):
    """Draw a token id from the probability vector *probs*, temperature-scaled.

    Non-positive temperatures fall back to greedy (argmax) decoding instead
    of dividing by zero.
    """
    probs = np.asarray(probs, dtype="float64")
    if temperature <= 0:
        return int(np.argmax(probs))
    logits = np.log(probs + 1e-8) / temperature
    # Shifting so the largest logit is 0 leaves the softmax unchanged, but it
    # stops every weight underflowing to 0 (and the normalisation below
    # producing NaN) at very low temperatures.
    logits -= logits.max()
    weights = np.exp(logits)
    weights /= weights.sum()
    return int(np.random.choice(len(weights), p=weights))


def generate_text(seed_text, next_words=30, temperature=0.5):
    """Generate a name from a free-text prompt.

    The prompt is stripped, lower-cased and screened. An accepted prompt is
    rewritten as ``"<prompt> | "`` and extended one SentencePiece piece at a
    time by sampling from ``model``.

    Args:
        seed_text: User prompt, e.g. ``"a female character name"``. Must not
            contain ``'|'``, which separates the prompt from the output.
        next_words: Maximum number of pieces to sample. Coerced to ``int`` in
            case the UI slider delivers a float.
        temperature: Sampling temperature; values <= 0 use greedy decoding.

    Returns:
        The generated text after the ``'|'`` separator, ``'Flagged Output'``
        if the output trips a content filter, or ``""`` if the prompt was
        rejected.
    """
    seed_text = seed_text.strip().lower()

    # Screen the prompt. A rejected prompt produces no generation and an
    # empty result (the warning wording notwithstanding). Returning early,
    # instead of comparing against a sentinel prompt string, means a user who
    # literally types "game name" still gets a result.
    if "|" in seed_text:
        gr.Warning("The prompt should not contain the '|' character. Using default prompt.")
        return ""
    if _is_profane(seed_text):
        gr.Warning("Profanity detected in the prompt, using the default prompt.")
        return ""
    if _is_hate_speech(seed_text):
        gr.Warning('Harmful speech detected in the prompt, using default prompt.')
        return ""

    generated_text = seed_text + ' | '
    for _ in range(int(next_words)):
        # Condition on at most the last max_seq_len - 1 token ids,
        # left-padded ('pre') to that fixed length.
        token_ids = sp.encode_as_ids(generated_text)
        padded = pad_sequences([token_ids], maxlen=max_seq_len - 1, padding='pre')
        # A direct model call avoids predict()'s per-call overhead for a
        # single small batch.
        probs = np.asarray(model(padded, training=False))[0]
        next_token = sp.id_to_piece(_sample_index(probs, temperature))
        generated_text += next_token
        if next_token.endswith(("</s>", "<unk>")):
            break

    # Re-encode to pieces and decode so the raw piece strings appended above
    # are rendered back as plain text, then drop control-symbol text.
    decoded = sp.decode_pieces(sp.encode_as_pieces(generated_text))
    decoded = decoded.replace("</s>", "").replace("<unk>", "").strip()
    # Keep only what follows the prompt separator.
    _, separator, tail = decoded.partition('|')
    if separator:
        decoded = tail.strip()

    if _is_profane(decoded) or _is_hate_speech(decoded):
        gr.Warning("Flagged potentially harmful output.")
        decoded = 'Flagged Output'
    return decoded
# Gradio UI: a one-line prompt plus sliders for the number of sampled pieces
# and the sampling temperature, wired to ``generate_text``. Each example row
# follows the input order [prompt, next_words, temperature].
demo = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(label="Prompt", value="a female character name", max_lines=1),
        gr.Slider(1, 100, step=1, label='Next Words', value=30),
        gr.Slider(
            0.1,
            1,
            value=0.5,
            label='Temperature',
            # Lower temperature sharpens the sampling distribution, so output
            # becomes more predictable (the old text said "more probalistic").
            info='Controls randomness of generation, higher values = more creative, lower values = more predictable',
        ),
    ],
    outputs=gr.Textbox(label="Generated Names"),
    title='Dungen Dev - Name Generator',
    description='A prompt-based name generator for game developers. Dungen Dev is an experimental model, and may produce outputs that are inappropriate, biased, or potentially harmful and inaccurate. Caution is advised.',
    examples=[
        ["a male character name", 30, 0.5],
        ["a futuristic city name", 30, 0.5],
        ["an item name", 30, 0.5],
        ["a dark and mysterious forest name", 30, 0.5],
        ["an evil character name", 30, 0.5],
    ],
)
demo.launch()