Spaces:
Running
Running
import os | |
import random | |
import gradio as gr | |
import sentencepiece as spm | |
import numpy as np | |
import tensorflow as tf | |
# Number of token ids per model input; encoded text is padded/truncated to this
# length before inference. Per the original note: 13 for the Skyrim model,
# 12 for the Terraria model.
max_seq_len = 12
def custom_pad_sequences(sequences, maxlen, padding='pre', value=0):
    """
    Pad or truncate every sequence to exactly ``maxlen`` items.

    With ``padding='pre'`` the padding goes in front and over-long sequences
    keep their last ``maxlen`` items; with ``padding='post'`` the padding goes
    at the end and over-long sequences keep their first ``maxlen`` items.

    :param sequences: List of lists, where each element is a sequence.
    :param maxlen: Length of every row in the result.
    :param padding: 'pre' or 'post', pad either before or after each sequence.
    :param value: Padding value; the result's dtype is the one np.full infers from it.
    :return: Numpy array with dimensions (number_of_sequences, maxlen)
    :raises ValueError: If ``padding`` is neither 'pre' nor 'post'.
    """
    if padding not in ('pre', 'post'):
        raise ValueError(f"padding must be 'pre' or 'post', got {padding!r}")

    padded_sequences = np.full((len(sequences), maxlen), value)
    for row, seq in enumerate(sequences):
        # Nothing to copy; the row stays all padding. This also avoids the
        # "-0" slice pitfall, where seq[-0:] / row[-0:] select everything.
        if len(seq) == 0 or maxlen == 0:
            continue
        if padding == 'pre':
            kept = seq[-maxlen:]  # the whole sequence when it already fits
            padded_sequences[row, maxlen - len(kept):] = kept
        else:
            kept = seq[:maxlen]
            padded_sequences[row, :len(kept)] = kept
    return padded_sequences
def generate_random_name(interpreter, vocab_size, sp, max_length=10, temperature=0.5, seed_text=""):
    """
    Generate one name by repeatedly sampling the next subword piece from the model.

    :param interpreter: Interpreter exposing get_input_details / get_output_details /
        set_tensor / invoke / get_tensor; its tensors must already be allocated.
    :param vocab_size: Number of token ids to sample from; must equal the length of
        the model's output vector.
    :param sp: SentencePiece processor used to encode text and map ids to pieces.
    :param max_length: At most ``max_length - 1`` pieces are sampled; sampling also
        stops once the decoded text is longer than ``max_length`` characters.
    :param temperature: Divisor applied to the log-probabilities before they are
        re-normalised; must be non-zero.
    :param seed_text: Optional starting text; when empty, a random piece is used.
    :return: The cleaned-up name; may be an empty string if nothing usable was generated.
    """
    # The tensor indices do not change between invocations, so look them up once.
    input_index = interpreter.get_input_details()[0]['index']
    output_index = interpreter.get_output_details()[0]['index']

    if seed_text:
        generated_name = seed_text
    else:
        # Start from a random piece; id 0 is never picked as the starting piece.
        random_index = np.random.randint(1, vocab_size)
        generated_name = sp.id_to_piece(random_index)

    # Decode up front so decoded_name is defined even when the loop below
    # runs zero times (max_length <= 1).
    decoded_name = sp.decode_pieces(generated_name.split())

    for _ in range(max_length - 1):
        token_ids = sp.encode_as_ids(generated_name)
        # Pad to the correct length expected by the model
        model_input = custom_pad_sequences([token_ids], maxlen=max_seq_len, padding='pre')
        # Convert to FLOAT32 before setting the tensor
        model_input = model_input.astype(np.float32)

        interpreter.set_tensor(input_index, model_input)
        interpreter.invoke()
        predicted = interpreter.get_tensor(output_index)[0]

        # Apply temperature in log space, then re-normalise. Work in float64 and
        # subtract the maximum first so that at low temperatures the exponentials
        # cannot all underflow to zero (which would yield 0/0 = NaN probabilities).
        scaled = np.log(np.asarray(predicted, dtype=np.float64) + 1e-8) / temperature
        scaled -= np.max(scaled)
        weights = np.exp(scaled)
        probabilities = weights / np.sum(weights)

        # Sample from the distribution
        next_index = int(np.random.choice(vocab_size, p=probabilities))
        next_token = sp.id_to_piece(next_index)
        generated_name += next_token

        # Decode the generated subword tokens into a string
        decoded_name = sp.decode_pieces(generated_name.split())

        # Stop if end token is predicted (optional) or the text is long enough
        if next_token == '' or len(decoded_name) > max_length:
            break

    # NOTE(review): "β" looks like a mis-encoded SentencePiece whitespace marker
    # (U+2581) -- confirm against the original file before changing it.
    decoded_name = decoded_name.replace("β", " ")
    decoded_name = decoded_name.replace("</s>", "")

    # Keep everything before the last space (the whole text when there is none).
    generated_name = decoded_name.rsplit(' ', 1)[0]

    # Capitalise the first character; guard against an empty result, which
    # would otherwise raise IndexError.
    if generated_name:
        generated_name = generated_name[0].upper() + generated_name[1:]

    # Split the name and drop the last part when it is shorter than 3 characters
    parts = generated_name.split()
    if parts and len(parts[-1]) < 3:
        generated_name = " ".join(parts[:-1])
    return generated_name.strip()
def generateTerrariaNames(amount, max_length=30, temperature=0.5, seed_text=""):
    """
    Generate ``amount`` Terraria-style names.

    Loads the SentencePiece model and the TFLite model from the ``models/``
    directory on every call, then calls generate_random_name once per name.

    :param amount: How many names to produce; converted with int().
    :param max_length: Forwarded to generate_random_name after int() conversion.
    :param temperature: Forwarded unchanged to generate_random_name.
    :param seed_text: Forwarded unchanged to generate_random_name.
    :return: List of generated name strings.
    """
    tokenizer = spm.SentencePieceProcessor()
    tokenizer.load("models/terraria_names.model")

    # Slider values may arrive as floats; range() and the length checks need ints.
    count = int(amount)
    length_limit = int(max_length)

    piece_count = tokenizer.GetPieceSize()

    # Load the TFLite model and allocate its tensors before inference.
    tflite_interpreter = tf.lite.Interpreter(model_path="models/dungen_terraria_model.tflite")
    tflite_interpreter.allocate_tensors()

    return [
        generate_random_name(
            tflite_interpreter,
            piece_count,
            tokenizer,
            seed_text=seed_text,
            max_length=length_limit,
            temperature=temperature,
        )
        for _ in range(count)
    ]
# Gradio UI: three sliders and a text box, passed in this order as the
# positional arguments (amount, max_length, temperature, seed_text) of
# generateTerrariaNames.
demo = gr.Interface(
    fn=generateTerrariaNames,
    inputs=[
        gr.Slider(1, 25, step=1, label='Amount of Names', info='How many names to generate, must be greater than 0'),
        gr.Slider(10, 60, value=30, step=1, label='Max Length', info='Max length of the generated word'),
        gr.Slider(0.1, 1, value=0.5, label='Temperature', info='Controls randomness of generation, higher values = more creative, lower values = more probalistic'),
        # A text box has no min/max range: the stray positional `0, 10` (as used
        # by the sliders) clashed with the `value=` keyword and raised TypeError.
        gr.Text(label='Seed text (optional)', value='', info='The starting text to begin with'),
    ],
    outputs=["text"],
    title='Dungen - Name Generator'
)
demo.launch()