Dungen / app.py
Infinitode Pty Ltd
Update app.py
c7c3fb8 verified
raw
history blame
4.72 kB
import os
import random
import gradio as gr
import sentencepiece as spm
import numpy as np
import tensorflow as tf
max_seq_len = 12  # Model input width: 13 for the Skyrim model, 12 for the Terraria model.


def custom_pad_sequences(sequences, maxlen, padding='pre', value=0):
    """
    Pad (or truncate) every sequence to the same length.

    :param sequences: List of lists, where each element is a sequence of token ids.
    :param maxlen: Width of every output row. ``None`` falls back to the
        module-level ``max_seq_len``.
    :param padding: 'pre' or 'post'. Padding is added before ('pre') or after
        ('post') each sequence. Over-long sequences are truncated on the same side,
        so 'pre' keeps the last ``maxlen`` items and 'post' keeps the first ``maxlen``.
    :param value: Padding value.
    :return: Numpy array with shape (number_of_sequences, maxlen).
    :raises ValueError: If ``padding`` is not 'pre' or 'post'.
    """
    if maxlen is None:
        maxlen = max_seq_len
    if padding not in ('pre', 'post'):
        raise ValueError(f"padding must be 'pre' or 'post', got {padding!r}")

    padded_sequences = np.full((len(sequences), maxlen), value)
    # Iterating a 2-D array yields row views, so assigning into `row` fills the result.
    for row, seq in zip(padded_sequences, sequences):
        kept = min(len(seq), maxlen)
        if kept == 0:
            # Nothing to copy; the row stays all padding. A `-0:` slice would
            # otherwise address the whole row and fail to broadcast.
            continue
        if padding == 'pre':
            row[-kept:] = seq[len(seq) - kept:]
        else:
            row[:kept] = seq[:kept]
    return padded_sequences
def _sample_with_temperature(probs, temperature, vocab_size):
    """
    Draw one token id from ``probs`` after re-weighting it by ``temperature``.

    :param probs: One step of model output. These are presumably non-negative
        probabilities of length ``vocab_size`` (TODO confirm against the model).
    :param temperature: Must be > 0. Values below 1 sharpen the distribution.
    :param vocab_size: Number of candidate token ids.
    :return: The sampled token id as a plain ``int``.
    :raises ValueError: If ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature!r}")
    # Work in float64 and subtract the max before exponentiating so the softmax
    # cannot overflow. Subtracting a constant leaves the distribution unchanged.
    logits = np.log(np.asarray(probs, dtype=np.float64) + 1e-8) / temperature
    weights = np.exp(logits - logits.max())
    weights /= weights.sum()
    return int(np.random.choice(vocab_size, p=weights))


def _clean_name(decoded_name, truncated):
    """
    Turn raw decoded generator output into a display name.

    :param decoded_name: Text decoded from the generated SentencePiece pieces.
    :param truncated: True when generation stopped without an end token. In that
        case the final word may have been cut off mid-way, so it is dropped.
    :return: The tidied, capitalized name. It may be empty.
    """
    # Strip any word-boundary markers or literal end tokens left in the text.
    name = decoded_name.replace("▁", " ").replace("</s>", "").strip()
    if truncated:
        name = name.rsplit(' ', 1)[0]
    # Capitalize only after stripping, and only if anything is left; indexing
    # an empty string would raise IndexError.
    if name:
        name = name[0].upper() + name[1:]
    # Drop a trailing fragment that is too short to read as a word.
    parts = name.split()
    if parts and len(parts[-1]) < 3:
        name = " ".join(parts[:-1])
    return name.strip()


def generate_random_name(interpreter, vocab_size, sp, max_length=10, temperature=0.5, seed_text=""):
    """
    Sample one name token-by-token from a TFLite next-token model.

    :param interpreter: An allocated TFLite interpreter. Its first input receives
        the left-padded token ids (as float32). Its first output is read as a
        distribution over ``vocab_size`` token ids.
    :param vocab_size: Number of token ids the model can predict.
    :param sp: SentencePiece processor used to encode and decode text.
    :param max_length: At most ``max_length - 1`` tokens are sampled. Sampling
        also stops once the decoded text is longer than ``max_length``.
    :param temperature: Sampling temperature. Must be > 0.
    :param seed_text: Text to start from. When empty, a random vocabulary piece is used.
    :return: The cleaned, capitalized name. It may be empty.
    """
    # Get input and output tensors
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    # eos_id() is -1 when the SentencePiece model has no end-of-sequence piece,
    # in which case the comparison below never matches.
    eos_id = sp.eos_id()

    if seed_text:
        generated_name = seed_text
    else:
        generated_name = sp.id_to_piece(np.random.randint(1, vocab_size))

    # Assume the name was cut off unless the model explicitly ends it.
    truncated = True
    for _ in range(max_length - 1):
        token_list = sp.encode_as_ids(generated_name)
        # Pad to the width the model expects; the model takes FLOAT32 input.
        token_list = custom_pad_sequences([token_list], maxlen=max_seq_len, padding='pre')
        interpreter.set_tensor(input_details[0]['index'], token_list.astype(np.float32))
        interpreter.invoke()
        predicted = interpreter.get_tensor(output_details[0]['index'])[0]

        next_index = _sample_with_temperature(predicted, temperature, vocab_size)
        next_token = sp.id_to_piece(next_index)
        # Stop on the end token instead of appending it and continuing past it.
        if next_index == eos_id or next_token == '':
            truncated = False
            break
        generated_name += next_token
        if len(sp.decode_pieces(generated_name.split())) > max_length:
            break

    # Decode after the loop so the result exists even when no step ran.
    return _clean_name(sp.decode_pieces(generated_name.split()), truncated)
def generateTerrariaNames(amount, max_length=30, temperature=0.5, seed_text=""):
    """
    Generate a batch of Terraria-style names with the bundled TFLite model.

    The SentencePiece tokenizer and the TFLite interpreter are loaded from
    ``models/`` on every call.

    :param amount: How many names to produce (coerced with ``int``).
    :param max_length: Length limit passed to ``generate_random_name`` (coerced with ``int``).
    :param temperature: Sampling temperature passed to ``generate_random_name``.
    :param seed_text: Optional text each name starts from.
    :return: List with one generated name per requested item.
    """
    tokenizer = spm.SentencePieceProcessor()
    tokenizer.load("models/terraria_names.model")

    count = int(amount)
    length_limit = int(max_length)
    vocab_size = tokenizer.GetPieceSize()

    # Tensors must be allocated before the interpreter can run inference.
    interpreter = tf.lite.Interpreter(model_path="models/dungen_terraria_model.tflite")
    interpreter.allocate_tensors()

    return [
        generate_random_name(
            interpreter,
            vocab_size,
            tokenizer,
            seed_text=seed_text,
            max_length=length_limit,
            temperature=temperature,
        )
        for _ in range(count)
    ]
# Gradio UI: each input widget maps positionally onto a generateTerrariaNames parameter
# (amount, max_length, temperature, seed_text).
demo = gr.Interface(
    fn=generateTerrariaNames,
    inputs=[
        gr.Slider(1, 25, step=1, label='Amount of Names',
                  info='How many names to generate, must be greater than 0'),
        gr.Slider(10, 60, value=30, step=1, label='Max Length',
                  info='Max length of the generated word'),
        gr.Slider(0.1, 1, value=0.5, label='Temperature',
                  info='Controls randomness of generation, higher values = more creative, lower values = more predictable'),
        # Textbox options after `value` are keyword-only, so pass everything by keyword.
        gr.Text(label='Seed text (optional)', value='', info='The starting text to begin with'),
    ],
    outputs=["text"],
    title='Dungen - Name Generator',
)

# Launch only when run as a script, so importing this module has no side effects.
if __name__ == "__main__":
    demo.launch()