Spaces:
Running
Running
File size: 4,327 Bytes
bd71161 53200cd bd71161 03980c2 bd71161 b8a7d0a bd71161 53200cd bd71161 575628b bd71161 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
import os
import random
import gradio as gr
import sentencepiece as spm
import numpy as np
import tensorflow as tf
def custom_pad_sequences(sequences, maxlen, padding='pre', value=0):
    """
    Pad (or truncate) every sequence to exactly ``maxlen`` elements.

    :param sequences: List of sequences (e.g. lists of token ids).
    :param maxlen: Length of every output row; must be >= 0.
    :param padding: 'pre' pads on the left and, for over-long sequences,
        keeps the *last* ``maxlen`` elements; 'post' pads on the right and
        keeps the *first* ``maxlen`` elements.
    :param value: Fill value. ``np.full`` infers the output dtype from it,
        so the default int ``0`` yields an integer array.
    :return: Numpy array with shape (len(sequences), maxlen).
    :raises ValueError: If ``padding`` is neither 'pre' nor 'post'.
    """
    if padding not in ('pre', 'post'):
        raise ValueError(f"padding must be 'pre' or 'post', got {padding!r}")

    padded_sequences = np.full((len(sequences), maxlen), value)
    for i, seq in enumerate(sequences):
        if padding == 'pre':
            # Use an explicit start index rather than negative slices such as
            # row[-len(seq):]: for an empty sequence (or maxlen == 0) the
            # slice becomes "-0:", which selects the whole row and fails to
            # broadcast.
            kept = seq[max(len(seq) - maxlen, 0):]
            padded_sequences[i, maxlen - len(kept):] = kept
        else:
            kept = seq[:maxlen]
            padded_sequences[i, :len(kept)] = kept
    return padded_sequences
def _apply_temperature(probs, temperature):
    """Re-weight a probability vector by ``temperature`` and renormalise it."""
    logits = np.log(np.asarray(probs, dtype=np.float64) + 1e-8) / temperature
    # Subtract the maximum before exponentiating: with a small temperature every
    # exp() could otherwise underflow to 0, making the sum 0 and p all-NaN.
    weights = np.exp(logits - logits.max())
    return weights / weights.sum()


def _tidy_name(decoded_name):
    """Turn decoded model text into the final display name.

    Removes leftover '▁' / '</s>' markers, drops everything after the last
    space (presumably because the final word may be incomplete), capitalises
    the first character, then drops a trailing word shorter than 3 characters.
    """
    text = decoded_name.replace("▁", " ").replace("</s>", "")
    name = text.rsplit(' ', 1)[0]
    if name:  # guard: name[0] would raise IndexError on an empty string
        name = name[0].upper() + name[1:]
    parts = name.split()
    if parts and len(parts[-1]) < 3:
        name = " ".join(parts[:-1])
    return name.strip()


def generate_random_name(interpreter, vocab_size, sp, max_length=10, temperature=0.5,
                         seed_text="", max_seq_len=12):
    """
    Sample one name from a next-token TFLite model, one SentencePiece piece at a time.

    The name starts from ``seed_text``, or from a random vocabulary piece when
    ``seed_text`` is empty. Each step works as follows:

    * the current text is encoded and left-padded to ``max_seq_len`` ids;
    * the model returns a distribution over the vocabulary, which is
      re-weighted by ``temperature``;
    * one piece is sampled from that distribution and appended.

    :param interpreter: TFLite interpreter (e.g. ``tf.lite.Interpreter``) whose
        tensors are already allocated.
    :param vocab_size: Vocabulary size (normally ``sp.GetPieceSize()``); must
        equal the length of the model's output distribution.
    :param sp: Loaded ``sentencepiece.SentencePieceProcessor``.
    :param max_length: Maximum number of sampling steps plus one; generation
        also stops once the decoded text is longer than this.
    :param temperature: Sampling temperature, must be > 0. Lower values favour
        the most likely pieces.
    :param seed_text: Optional text the name should start with.
    :param max_seq_len: Input length the model expects (12 for the Terraria
        model, 13 for the Skyrim one).
    :return: The cleaned-up, capitalised name (may be an empty string).
    :raises ValueError: If ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError("temperature must be > 0")

    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    if seed_text:
        generated_name = seed_text
    else:
        # Id 0 is never chosen as a start piece; presumably it is a special
        # token such as <unk> — TODO confirm against the tokenizer.
        generated_name = sp.id_to_piece(np.random.randint(1, vocab_size))

    # Initialised before the loop so that max_length <= 1 (no sampling steps)
    # still has text to post-process instead of raising NameError.
    decoded_name = sp.decode_pieces(generated_name.split())

    for _ in range(max_length - 1):
        token_list = sp.encode_as_ids(generated_name)
        token_list = custom_pad_sequences([token_list], maxlen=max_seq_len, padding='pre')
        # The model's input tensor is fed FLOAT32 values, so cast the ids.
        interpreter.set_tensor(input_details[0]['index'], token_list.astype(np.float32))
        interpreter.invoke()
        predicted = interpreter.get_tensor(output_details[0]['index'])[0]

        next_index = int(np.random.choice(vocab_size, p=_apply_temperature(predicted, temperature)))
        next_token = sp.id_to_piece(next_index)

        # Stop at the end-of-sequence piece rather than appending it (and any
        # text sampled after it) to the name.
        if next_token in ('', '</s>'):
            break

        generated_name += next_token
        decoded_name = sp.decode_pieces(generated_name.split())
        if len(decoded_name) > max_length:
            break

    return _tidy_name(decoded_name)
def generateTerrariaNames(amount, max_length=30, temperature=0.5, seed_text=""):
    """
    Generate ``amount`` Terraria-style names with the bundled TFLite model.

    :param amount: How many names to generate; cast to ``int``, and values
        <= 0 give an empty list.
    :param max_length: Length limit forwarded to ``generate_random_name``
        (cast to ``int``).
    :param temperature: Sampling temperature forwarded to
        ``generate_random_name`` (cast to ``float``).
    :param seed_text: Optional text each name should start with.
    :return: List of generated names.
    """
    # Convert the UI-supplied values first so bad input fails before the
    # (comparatively slow) model loading below.
    amount = int(amount)
    max_length = int(max_length)
    temperature = float(temperature)

    # The tokenizer and model are loaded on every call.
    sp = spm.SentencePieceProcessor()
    sp.load("models/terraria_names.model")
    vocab_size = sp.GetPieceSize()

    interpreter = tf.lite.Interpreter(model_path="models/dungen_terraria_model.tflite")
    interpreter.allocate_tensors()

    return [
        generate_random_name(interpreter, vocab_size, sp, seed_text=seed_text,
                             max_length=max_length, temperature=temperature)
        for _ in range(amount)
    ]
# Upper bound on names per request, so a single call cannot tie up the app.
_MAX_NAMES_PER_REQUEST = 25


def _generate_names_text(amount, max_length, temperature, seed_text):
    """Gradio adapter: return the generated names one per line.

    ``generateTerrariaNames`` returns a list, which a text output would
    otherwise render as a Python list repr.
    """
    # A cleared Number field may arrive as None; treat it as 0 names.
    amount = min(max(int(amount or 0), 0), _MAX_NAMES_PER_REQUEST)
    return "\n".join(generateTerrariaNames(amount, max_length, temperature, seed_text))


demo = gr.Interface(
    fn=_generate_names_text,
    inputs=[
        gr.Number(value=1, label="Amount", precision=0),
        gr.Slider(minimum=10, maximum=60, value=30, step=1, label="Max length"),
        gr.Slider(minimum=0.01, maximum=1, value=0.5, label="Temperature"),
        gr.Textbox(value="", label="Seed text"),
    ],
    outputs=["text"],
)

if __name__ == "__main__":
    demo.launch()