# Spaces: Running
# import sentencepiece as spm | |
# import numpy as np | |
# import tensorflow as tf | |
# from tensorflow.keras.preprocessing.sequence import pad_sequences | |
# from valx import detect_profanity, detect_hate_speech | |
# import gradio as gr | |
# sp = spm.SentencePieceProcessor() | |
# sp.Load("dungen_dev_preview.model") | |
# model = tf.keras.models.load_model("dungen_dev_preview_model.keras") | |
# max_seq_len = 25 | |
# def generate_text(seed_text, next_words=30, temperature=0.5): | |
# seed_text = seed_text.strip().lower() | |
# if "|" in seed_text: | |
# gr.Warning("The prompt should not contain the '|' character. Using default prompt.") | |
# seed_text = 'game name | ' | |
# elif detect_profanity([seed_text], language='All'): | |
# gr.Warning("Profanity detected in the prompt, using the default prompt.") | |
# seed_text = 'game name | ' | |
# elif (hate_speech_result := detect_hate_speech(seed_text)) and hate_speech_result[0] in ['Hate Speech', 'Offensive Speech']: | |
# gr.Warning('Harmful speech detected in the prompt, using default prompt.') | |
# seed_text = 'game name | ' | |
# else: | |
# seed_text += ' | ' | |
# generated_text = seed_text | |
# if generated_text != 'game name | ': # only generate if not the default prompt | |
# for _ in range(next_words): | |
# token_list = sp.encode_as_ids(generated_text) | |
# token_list = pad_sequences([token_list], maxlen=max_seq_len - 1, padding='pre') | |
# predicted = model.predict(token_list, verbose=0)[0] | |
# predicted = np.asarray(predicted).astype("float64") | |
# predicted = np.log(predicted + 1e-8) / temperature | |
# exp_preds = np.exp(predicted) | |
# predicted = exp_preds / np.sum(exp_preds) | |
# next_index = np.random.choice(len(predicted), p=predicted) | |
# next_token = sp.id_to_piece(next_index) | |
# generated_text += next_token | |
# if next_token.endswith('</s>') or next_token.endswith('<unk>'): | |
# break | |
# decoded = sp.decode_pieces(sp.encode_as_pieces(generated_text)) | |
# decoded = decoded.replace("</s>", "").replace("<unk>", "").strip() | |
# if '|' in decoded: | |
# decoded = decoded.split('|', 1)[1].strip() | |
# if any(detect_profanity([decoded], language='All')) or (hate_speech_result := detect_hate_speech(decoded)) and hate_speech_result[0] in ['Hate Speech', 'Offensive Speech']: | |
# gr.Warning("Flagged potentially harmful output.") | |
# decoded = 'Flagged Output' | |
# return decoded | |
# demo = gr.Interface( | |
# fn=generate_text, | |
# inputs=[ | |
# gr.Textbox(label="Prompt", value="a female character name", max_lines=1), | |
# gr.Slider(1, 100, step=1, label='Next Words', value=30), | |
# gr.Slider(0.1, 1, value=0.5, label='Temperature', info='Controls randomness of generation, higher values = more creative, lower values = more probabilistic')
# ], | |
# outputs=gr.Textbox(label="Generated Names"), | |
# title='Dungen Dev - Name Generator', | |
# description='A prompt-based name generator for game developers. Dungen Dev is an experimental model, and may produce outputs that are inappropriate, biased, or potentially harmful and inaccurate. Caution is advised.', | |
# examples=[ | |
# ["a male character name", 30, 0.5], | |
# ["a futuristic city name", 30, 0.5], | |
# ["an item name", 30, 0.5], | |
# ["a dark and mysterious forest name", 30, 0.5], | |
# ["an evil character name", 30, 0.5] | |
# ] | |
# ) | |
# demo.launch() | |
import csv
from datetime import datetime

import gradio as gr
import numpy as np
import sentencepiece as spm
import tensorflow as tf
from tensorflow.keras.preprocessing.sequence import pad_sequences
from valx import detect_profanity, detect_hate_speech
# Tokenizer: a SentencePiece model loaded from a file next to this script.
sp = spm.SentencePieceProcessor()
sp.Load("dungen_dev_preview.model")

# Keras model queried by generate_text for the next-token distribution.
model = tf.keras.models.load_model("dungen_dev_preview_model.keras")

# generate_text pads its model input to max_seq_len - 1 token ids.
max_seq_len = 25
def generate_text(seed_text, next_words=30, temperature=0.5):
    """Generate text for a prompt by sampling the model one token at a time.

    The prompt is stripped and lower-cased, then screened. A prompt that
    contains '|', or that the profanity / hate-speech detectors flag, is
    replaced by the default prompt, a ``gr.Warning`` is raised in the UI and
    no tokens are generated. An accepted prompt gets ' | ' appended; the text
    after the first '|' in the decoded result is what is returned.

    Args:
        seed_text: The user prompt.
        next_words: Maximum number of tokens to sample. Coerced to ``int`` so
            a float value (e.g. from a UI slider) does not break ``range``.
        temperature: Divisor applied to the log-probabilities before
            re-normalising; lower values sharpen the distribution.

    Returns:
        The decoded text following the '|' separator, or 'Flagged Output'
        when the result trips the profanity / hate-speech detectors. For a
        rejected prompt nothing follows the separator, so the result is
        presumably an empty string (depends on the tokenizer round-trip).
    """
    default_prompt = 'game name | '
    seed_text = seed_text.strip().lower()

    # Track rejection explicitly. Comparing the text against the default
    # prompt would also skip generation for a genuine "game name" prompt.
    prompt_rejected = True
    if "|" in seed_text:
        gr.Warning("The prompt should not contain the '|' character. Using default prompt.")
    elif detect_profanity([seed_text], language='All'):
        gr.Warning("Profanity detected in the prompt, using the default prompt.")
    elif (hate_speech_result := detect_hate_speech(seed_text)) and hate_speech_result[0] in ['Hate Speech', 'Offensive Speech']:
        gr.Warning('Harmful speech detected in the prompt, using default prompt.')
    else:
        prompt_rejected = False
        seed_text += ' | '
    if prompt_rejected:
        seed_text = default_prompt

    generated_text = seed_text
    if not prompt_rejected:
        for _ in range(int(next_words)):
            # Re-encode the whole running text and left-pad it to the fixed
            # input length used for the model.
            token_list = sp.encode_as_ids(generated_text)
            token_list = pad_sequences([token_list], maxlen=max_seq_len - 1, padding='pre')

            # Temperature sampling: scale log-probabilities, then softmax.
            predicted = model.predict(token_list, verbose=0)[0]
            predicted = np.asarray(predicted).astype("float64")
            predicted = np.log(predicted + 1e-8) / temperature
            exp_preds = np.exp(predicted)
            predicted = exp_preds / np.sum(exp_preds)

            # Cast to a plain int before handing the id to the tokenizer.
            next_index = int(np.random.choice(len(predicted), p=predicted))
            next_token = sp.id_to_piece(next_index)
            generated_text += next_token

            # Stop at end-of-sequence or unknown pieces; they are stripped
            # from the decoded text below.
            if next_token.endswith(('</s>', '<unk>')):
                break

    decoded = sp.decode_pieces(sp.encode_as_pieces(generated_text))
    decoded = decoded.replace("</s>", "").replace("<unk>", "").strip()

    # Keep only what follows the prompt separator.
    if '|' in decoded:
        decoded = decoded.split('|', 1)[1].strip()

    # Screen the output with the same detectors used on the prompt.
    if any(detect_profanity([decoded], language='All')) or (hate_speech_result := detect_hate_speech(decoded)) and hate_speech_result[0] in ['Hate Speech', 'Offensive Speech']:
        gr.Warning("Flagged potentially harmful output.")
        decoded = 'Flagged Output'
    return decoded
# In-memory list of every record flagged since this module was loaded.
flagged_outputs = []


def flag_output(prompt, generated_text, next_words, temperature):
    """Record a user-flagged generation in memory and in flagged_outputs.csv.

    Args:
        prompt: The prompt that produced the output.
        generated_text: The generated text being flagged.
        next_words: The "Next Words" setting used for the generation.
        temperature: The "Temperature" setting used for the generation.

    Returns:
        "Cannot flag an empty output." when ``generated_text`` is missing or
        only whitespace (nothing is recorded in that case); otherwise
        "Output flagged successfully.".
    """
    # Treat None like an empty string instead of failing on .strip().
    if not generated_text or not generated_text.strip():
        return "Cannot flag an empty output."

    # Build the row once and reuse it for both the list and the CSV file.
    record = {
        "Prompt": prompt,
        "Generated Text": generated_text,
        "Next Words": next_words,
        "Temperature": temperature,
        "Timestamp": datetime.now().isoformat(),
    }
    flagged_outputs.append(record)

    # Explicit UTF-8 so non-ASCII prompts do not depend on the locale encoding.
    with open("flagged_outputs.csv", "a", newline="", encoding="utf-8") as file:
        writer = csv.DictWriter(
            file,
            fieldnames=["Prompt", "Generated Text", "Next Words", "Temperature", "Timestamp"],
        )
        # Position 0 in append mode means the file is empty: write the header.
        if file.tell() == 0:
            writer.writeheader()
        writer.writerow(record)
    return "Output flagged successfully."
# The UI is built with gr.Blocks so that each button has its own handler:
# "Generate" calls generate_text (one return value -> one output component)
# and "Flag Output" calls flag_output. Listing a Button among an Interface's
# outputs neither matches generate_text's single return value nor connects
# the button to flag_output.
with gr.Blocks(title='Dungen Dev - Name Generator') as demo:
    gr.Markdown('# Dungen Dev - Name Generator')
    gr.Markdown('A prompt-based name generator for game developers. Dungen Dev is an experimental model, and may produce outputs that are inappropriate, biased, or potentially harmful and inaccurate. Caution is advised.')

    with gr.Row():
        with gr.Column():
            prompt_box = gr.Textbox(label="Prompt", value="a female character name", max_lines=1)
            next_words_slider = gr.Slider(1, 100, step=1, label='Next Words', value=30)
            temperature_slider = gr.Slider(0.1, 1, value=0.5, label='Temperature', info='Controls randomness of generation, higher values = more creative, lower values = more probabilistic')
            generate_button = gr.Button("Generate", variant="primary")
        with gr.Column():
            name_box = gr.Textbox(label="Generated Name")
            flag_button = gr.Button("Flag Output")
            flag_status_box = gr.Textbox(label="Flag Status", interactive=False)

    generation_inputs = [prompt_box, next_words_slider, temperature_slider]

    gr.Examples(
        examples=[
            ["a male character name", 30, 0.5],
            ["a futuristic city name", 30, 0.5],
            ["an item name", 30, 0.5],
            ["a dark and mysterious forest name", 30, 0.5],
            ["an evil character name", 30, 0.5],
        ],
        inputs=generation_inputs,
    )

    # Generate on button click or when Enter is pressed in the prompt box.
    generate_button.click(fn=generate_text, inputs=generation_inputs, outputs=name_box)
    prompt_box.submit(fn=generate_text, inputs=generation_inputs, outputs=name_box)

    # Argument order matches flag_output(prompt, generated_text, next_words, temperature).
    flag_button.click(
        fn=flag_output,
        inputs=[prompt_box, name_box, next_words_slider, temperature_slider],
        outputs=flag_status_box,
    )

demo.launch()