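# Gradio demo: generates Terraria-style names with a TFLite language model
# and a SentencePiece subword tokenizer.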
import os
import random

import gradio as gr
import sentencepiece as spm
import numpy as np
import tensorflow as tf

def custom_pad_sequences(sequences, maxlen, padding='pre', value=0):
    """
    Pads sequences to the same length.

    :param sequences: List of lists, where each element is a sequence.
    :param maxlen: Maximum length of all sequences.
    :param padding: 'pre' or 'post', pad either before or after each sequence.
    :param value: Float, padding value.
    :return: Numpy array with dimensions (number_of_sequences, maxlen)
    """
    padded_sequences = np.full((len(sequences), maxlen), value)
    for i, seq in enumerate(sequences):
        if padding == 'pre':
            if len(seq) <= maxlen:
                padded_sequences[i, -len(seq):] = seq
            else:
                padded_sequences[i, :] = seq[-maxlen:]
        elif padding == 'post':
            if len(seq) <= maxlen:
                padded_sequences[i, :len(seq)] = seq
            else:
                padded_sequences[i, :] = seq[:maxlen]
    return padded_sequences

def generate_random_name(interpreter, vocab_size, sp, max_seq_len, max_length=10, temperature=0.5, seed_text=""):
    # Get the input and output tensor details of the TFLite model
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    if seed_text:
        generated_name = seed_text
    else:
        # No seed given: start from a random vocabulary token
        random_index = np.random.randint(1, vocab_size)
        random_token = sp.id_to_piece(random_index)
        generated_name = random_token

    for _ in range(max_length - 1):
        token_list = sp.encode_as_ids(generated_name)
        # Pad to the sequence length expected by the model
        token_list = custom_pad_sequences([token_list], maxlen=max_seq_len, padding='pre')
        # Convert token_list to FLOAT32 before setting the tensor
        token_list = token_list.astype(np.float32)

        # Set the input tensor and run inference
        interpreter.set_tensor(input_details[0]['index'], token_list)
        interpreter.invoke()

        # Get the output tensor (next-token distribution)
        predicted = interpreter.get_tensor(output_details[0]['index'])[0]

        # Apply temperature to the predictions and renormalize
        predicted = np.log(predicted + 1e-8) / temperature
        predicted = np.exp(predicted) / np.sum(np.exp(predicted))

        # Sample the next token from the distribution
        next_index = int(np.random.choice(range(vocab_size), p=predicted))
        next_token = sp.id_to_piece(next_index)
        generated_name += next_token

        # Decode the generated subword tokens into a string
        decoded_name = sp.decode_pieces(generated_name.split())

        # Stop once the end-of-sequence token is predicted or the name is long enough
        if next_token == '</s>' or len(decoded_name) > max_length:
            break

    # SentencePiece marks word boundaries with '▁'; restore spaces and drop the EOS marker
    decoded_name = decoded_name.replace("▁", " ")
    decoded_name = decoded_name.replace("</s>", "")

    # Drop a possibly incomplete trailing word and capitalize the result
    generated_name = decoded_name.rsplit(' ', 1)[0]
    if generated_name:
        generated_name = generated_name[0].upper() + generated_name[1:]

    # Drop the last word if it is suspiciously short
    parts = generated_name.split()
    if parts and len(parts[-1]) < 3:
        generated_name = " ".join(parts[:-1])

    return generated_name.strip()

def generateTerrariaNames(amount, max_length=30, temperature=0.5, seed_text=""):
    # Load the SentencePiece tokenizer trained on Terraria names
    sp = spm.SentencePieceProcessor()
    sp.load("models/terraria_names.model")

    amount = int(amount)
    names = []

    # Model parameters
    vocab_size = sp.GetPieceSize()
    max_seq_len = 12  # For Skyrim = 13, for Terraria = 12

    # Load the TFLite model
    interpreter = tf.lite.Interpreter(model_path="models/dungen_terraria_model.tflite")
    interpreter.allocate_tensors()

    # Generate the requested number of names
    for _ in range(amount):
        generated_name = generate_random_name(interpreter, vocab_size, sp, max_seq_len,
                                               seed_text=seed_text, max_length=max_length,
                                               temperature=temperature)
        names.append(generated_name)

    return names

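# Gradio UI: the four inputs map to (amount, max_length, temperature, seed_text).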
demo = gr.Interface(
    fn=generateTerrariaNames,
    inputs=[
        gr.Number(value=1, minimum=1, maximum=25, label="Amount of names"),
        gr.Slider(10, 60, value=30, label="Maximum name length"),
        gr.Slider(0.01, 1, value=0.5, label="Temperature"),
        gr.Textbox(value="", label="Seed text (optional)"),
    ],
    outputs=["text"],
)

demo.launch()