# Dungen-Dev / app.py
# JohanBeytell's picture
# Update app.py
# afba42d verified
# raw
# history blame
# 7.95 kB
# import sentencepiece as spm
# import numpy as np
# import tensorflow as tf
# from tensorflow.keras.preprocessing.sequence import pad_sequences
# from valx import detect_profanity, detect_hate_speech
# import gradio as gr
# sp = spm.SentencePieceProcessor()
# sp.Load("dungen_dev_preview.model")
# model = tf.keras.models.load_model("dungen_dev_preview_model.keras")
# max_seq_len = 25
# def generate_text(seed_text, next_words=30, temperature=0.5):
# seed_text = seed_text.strip().lower()
# if "|" in seed_text:
# gr.Warning("The prompt should not contain the '|' character. Using default prompt.")
# seed_text = 'game name | '
# elif detect_profanity([seed_text], language='All'):
# gr.Warning("Profanity detected in the prompt, using the default prompt.")
# seed_text = 'game name | '
# elif (hate_speech_result := detect_hate_speech(seed_text)) and hate_speech_result[0] in ['Hate Speech', 'Offensive Speech']:
# gr.Warning('Harmful speech detected in the prompt, using default prompt.')
# seed_text = 'game name | '
# else:
# seed_text += ' | '
# generated_text = seed_text
# if generated_text != 'game name | ': # only generate if not the default prompt
# for _ in range(next_words):
# token_list = sp.encode_as_ids(generated_text)
# token_list = pad_sequences([token_list], maxlen=max_seq_len - 1, padding='pre')
# predicted = model.predict(token_list, verbose=0)[0]
# predicted = np.asarray(predicted).astype("float64")
# predicted = np.log(predicted + 1e-8) / temperature
# exp_preds = np.exp(predicted)
# predicted = exp_preds / np.sum(exp_preds)
# next_index = np.random.choice(len(predicted), p=predicted)
# next_token = sp.id_to_piece(next_index)
# generated_text += next_token
# if next_token.endswith('</s>') or next_token.endswith('<unk>'):
# break
# decoded = sp.decode_pieces(sp.encode_as_pieces(generated_text))
# decoded = decoded.replace("</s>", "").replace("<unk>", "").strip()
# if '|' in decoded:
# decoded = decoded.split('|', 1)[1].strip()
# if any(detect_profanity([decoded], language='All')) or (hate_speech_result := detect_hate_speech(decoded)) and hate_speech_result[0] in ['Hate Speech', 'Offensive Speech']:
# gr.Warning("Flagged potentially harmful output.")
# decoded = 'Flagged Output'
# return decoded
# demo = gr.Interface(
# fn=generate_text,
# inputs=[
# gr.Textbox(label="Prompt", value="a female character name", max_lines=1),
# gr.Slider(1, 100, step=1, label='Next Words', value=30),
# gr.Slider(0.1, 1, value=0.5, label='Temperature', info='Controls randomness of generation, higher values = more creative, lower values = more probabilistic')
# ],
# outputs=gr.Textbox(label="Generated Names"),
# title='Dungen Dev - Name Generator',
# description='A prompt-based name generator for game developers. Dungen Dev is an experimental model, and may produce outputs that are inappropriate, biased, or potentially harmful and inaccurate. Caution is advised.',
# examples=[
# ["a male character name", 30, 0.5],
# ["a futuristic city name", 30, 0.5],
# ["an item name", 30, 0.5],
# ["a dark and mysterious forest name", 30, 0.5],
# ["an evil character name", 30, 0.5]
# ]
# )
# demo.launch()
import sentencepiece as spm
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.sequence import pad_sequences
from valx import detect_profanity, detect_hate_speech
import gradio as gr
import csv
from datetime import datetime
# Sequence-length setting shared with generate_text, which left-pads its
# context to ``max_seq_len - 1`` token ids before each prediction.
max_seq_len = 25

# Keras model, loaded once at import time from the working directory.
model = tf.keras.models.load_model("dungen_dev_preview_model.keras")

# SentencePiece tokenizer, loaded once at import time from its model file.
sp = spm.SentencePieceProcessor()
sp.Load("dungen_dev_preview.model")
def generate_text(seed_text, next_words=30, temperature=0.5):
    """Generate text for a prompt by sampling the model one piece at a time.

    The prompt is stripped and lower-cased, then screened. A prompt that
    contains ``'|'`` or that valx reports as profane, hate speech or offensive
    speech is replaced by the default prompt and no tokens are sampled for it.
    Otherwise ``' | '`` is appended and up to ``next_words`` SentencePiece
    pieces are sampled from the model with temperature scaling, stopping early
    on a piece ending in ``</s>`` or ``<unk>``.

    Args:
        seed_text: User prompt. Must not contain ``'|'``, which is used as the
            separator between the prompt and the generated text.
        next_words: Maximum number of pieces to sample (coerced to ``int``).
        temperature: Sampling temperature. Non-positive values are clamped to
            a tiny positive number, which makes sampling effectively greedy.

    Returns:
        The decoded text following the first ``'|'``, or ``'Flagged Output'``
        when the output screen flags it. For a rejected prompt nothing is
        sampled, so only whatever follows the default prompt's ``'|'`` (i.e.
        nothing) is returned.
    """
    default_prompt = 'game name | '
    seed_text = seed_text.strip().lower()

    # Track rejection explicitly. Comparing the text against the default prompt
    # afterwards would also skip generation for a legitimate "game name" prompt.
    prompt_rejected = True
    if "|" in seed_text:
        gr.Warning("The prompt should not contain the '|' character. Using default prompt.")
    elif detect_profanity([seed_text], language='All'):
        gr.Warning("Profanity detected in the prompt, using the default prompt.")
    elif (hate_speech_result := detect_hate_speech(seed_text)) and hate_speech_result[0] in ['Hate Speech', 'Offensive Speech']:
        gr.Warning('Harmful speech detected in the prompt, using default prompt.')
    else:
        prompt_rejected = False

    generated_text = default_prompt if prompt_rejected else seed_text + ' | '

    if not prompt_rejected:  # only generate for an accepted prompt
        # Guard the division below against a zero or negative temperature.
        temperature = max(float(temperature), 1e-6)
        for _ in range(int(next_words)):
            token_ids = sp.encode_as_ids(generated_text)
            model_input = pad_sequences([token_ids], maxlen=max_seq_len - 1, padding='pre')
            probs = np.asarray(model.predict(model_input, verbose=0)[0]).astype("float64")

            # Temperature-scaled softmax over log-probabilities. Subtracting
            # the maximum leaves the distribution unchanged but keeps exp()
            # from underflowing every entry to zero at low temperatures.
            logits = np.log(probs + 1e-8) / temperature
            logits -= np.max(logits)
            exp_logits = np.exp(logits)
            probs = exp_logits / np.sum(exp_logits)

            # Plain int so the tokenizer binding never sees a NumPy scalar.
            next_index = int(np.random.choice(len(probs), p=probs))
            next_token = sp.id_to_piece(next_index)
            generated_text += next_token
            if next_token.endswith(('</s>', '<unk>')):
                break

    # Round-trip through the tokenizer to turn piece markers back into text.
    decoded = sp.decode_pieces(sp.encode_as_pieces(generated_text))
    decoded = decoded.replace("</s>", "").replace("<unk>", "").strip()

    # Keep only what follows the prompt separator.
    if '|' in decoded:
        decoded = decoded.split('|', 1)[1].strip()

    if any(detect_profanity([decoded], language='All')) or (hate_speech_result := detect_hate_speech(decoded)) and hate_speech_result[0] in ['Hate Speech', 'Offensive Speech']:
        gr.Warning("Flagged potentially harmful output.")
        decoded = 'Flagged Output'
    return decoded
# In-memory log of every record flagged during this process's lifetime.
flagged_outputs = []


def flag_output(prompt, generated_text, next_words, temperature):
    """Record a flagged generation in memory and in ``flagged_outputs.csv``.

    Args:
        prompt: The prompt that produced the output.
        generated_text: The output being flagged. ``None``, empty or
            whitespace-only text is rejected.
        next_words: The "Next Words" setting used for the generation.
        temperature: The temperature setting used for the generation.

    Returns:
        ``"Cannot flag an empty output."`` when there is nothing to flag,
        otherwise ``"Output flagged successfully."`` after the record has been
        appended to ``flagged_outputs`` and to the CSV file.
    """
    # An output box that was never filled may arrive as None or as blanks.
    if not generated_text or not generated_text.strip():
        return "Cannot flag an empty output."

    # Build the record once; the same dict feeds the list and the CSV row.
    record = {
        "Prompt": prompt,
        "Generated Text": generated_text,
        "Next Words": next_words,
        "Temperature": temperature,
        "Timestamp": datetime.now().isoformat(),
    }
    flagged_outputs.append(record)

    # Explicit UTF-8 so non-ASCII prompts/names are written the same way on
    # every platform; newline="" is what the csv module expects.
    with open("flagged_outputs.csv", "a", newline="", encoding="utf-8") as file:
        writer = csv.DictWriter(file, fieldnames=list(record))
        # In append mode the position starts at the end of the file, so a
        # position of 0 means the file is new/empty and still needs a header.
        if file.tell() == 0:
            writer.writeheader()
        writer.writerow(record)
    return "Output flagged successfully."
# The UI is assembled with gr.Blocks instead of gr.Interface. An Interface
# needs one output component per value returned by ``fn``; listing a Button as
# a second output leaves generate_text (which returns one string) one value
# short, and never connects the button to flag_output. With Blocks the
# "Flag Output" button gets its own click handler.
with gr.Blocks(title='Dungen Dev - Name Generator') as demo:
    gr.Markdown('# Dungen Dev - Name Generator')
    gr.Markdown('A prompt-based name generator for game developers. Dungen Dev is an experimental model, and may produce outputs that are inappropriate, biased, or potentially harmful and inaccurate. Caution is advised.')

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(label="Prompt", value="a female character name", max_lines=1)
            next_words_input = gr.Slider(1, 100, step=1, label='Next Words', value=30)
            temperature_input = gr.Slider(0.1, 1, value=0.5, label='Temperature', info='Controls randomness of generation, higher values = more creative, lower values = more probabilistic')
            generate_button = gr.Button("Generate", variant="primary")
        with gr.Column():
            name_output = gr.Textbox(label="Generated Name")
            flag_button = gr.Button("Flag Output")
            flag_status = gr.Textbox(label="Flag Status", interactive=False)

    generation_inputs = [prompt_input, next_words_input, temperature_input]

    # Clicking an example fills the three input controls.
    gr.Examples(
        examples=[
            ["a male character name", 30, 0.5],
            ["a futuristic city name", 30, 0.5],
            ["an item name", 30, 0.5],
            ["a dark and mysterious forest name", 30, 0.5],
            ["an evil character name", 30, 0.5],
        ],
        inputs=generation_inputs,
    )

    # Generate on button click or when Enter is pressed in the prompt box.
    generate_button.click(fn=generate_text, inputs=generation_inputs, outputs=name_output)
    prompt_input.submit(fn=generate_text, inputs=generation_inputs, outputs=name_output)

    # Argument order follows flag_output(prompt, generated_text, next_words, temperature).
    flag_button.click(
        fn=flag_output,
        inputs=[prompt_input, name_output, next_words_input, temperature_input],
        outputs=flag_status,
    )

demo.launch()