Spaces:
Runtime error
Runtime error
File size: 5,635 Bytes
0e403da 1d30073 7bbddfb 0e403da 7bbddfb 0e403da 7bbddfb 0e403da 7bbddfb 0e403da 7bbddfb 0e403da 8e1a8c8 5f81dcb 943ee2f 7bbddfb 943ee2f 7bbddfb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 |
import streamlit as st
import jax.numpy as jnp
from transformers import AutoTokenizer
from transformers.models.t5.modeling_flax_t5 import shift_tokens_right
from t5_vae_flax_alt.src.t5_vae import FlaxT5VaeForAutoencoding
import info
# Page-level configuration and introductory text for the demo.
# NOTE(review): "πππ" in page_icon and in the title looks like mis-encoded
# emoji (mojibake); confirm the intended characters. Some Streamlit versions
# may reject a page_icon that is not a single emoji.
st.set_page_config(
    page_title="T5-VAE",
    page_icon="πππ",
    layout="wide",
    initial_sidebar_state="expanded",
)
st.title('T5-VAE πππ')
# st.text renders preformatted plain text, so a Markdown link would appear
# literally as "[here](...)"; st.markdown renders it as a clickable link.
st.markdown('''
This is a variational autoencoder trained on text.
It allows interpolating on text at a high level, try it out!
See how it works [here](http://fras.uk/ml/large%20prior-free%20models/transformer-vae/2020/08/13/Transformers-as-Variational-Autoencoders.html).
''')
st.text('''
Try interpolating between lines of Python code using this T5-VAE.
''')
@st.cache(allow_output_mutation=True)
def get_model():
    """Load the T5 tokenizer and the Python-code T5-VAE model.

    Wrapped in ``st.cache`` so Streamlit reruns reuse the loaded objects.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        ValueError: if the tokenizer vocabulary size differs from the number
            of rows in the model's shared embedding table.
    """
    tokenizer = AutoTokenizer.from_pretrained("t5-base")
    model = FlaxT5VaeForAutoencoding.from_pretrained("flax-community/t5-vae-python")
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    embedding_rows = model.params['t5']['shared']['embedding'].shape[0]
    if embedding_rows != len(tokenizer):
        raise ValueError("T5 Tokenizer doesn't match T5Vae embedding size.")
    return model, tokenizer


# Module-level globals read by the helper functions below
# (get_latent, tokens_from_latent, decode, ...).
model, tokenizer = get_model()
def add_decoder_input_ids(examples):
    """Add decoder-side inputs derived from the encoder inputs.

    Reads the module-level ``tokenizer`` (for the pad id) and ``model``
    (for ``config.decoder_start_token_id``).

    Args:
        examples: mapping with 2-D (batch, seq) ``"input_ids"`` and
            ``"attention_mask"`` entries.

    Returns:
        The same mapping, mutated in place, with ``"decoder_input_ids"`` and
        ``"decoder_attention_mask"`` added as nested Python lists, each one
        column longer than the encoder inputs.
    """
    pad_id = tokenizer.pad_token_id
    start_id = model.config.decoder_start_token_id

    encoder_ids = jnp.array(examples["input_ids"])
    n_rows = encoder_ids.shape[0]
    # One extra pad column, then shift right; shift_tokens_right presumably
    # prepends start_id and drops the last column -- confirm in transformers.
    pad_column = pad_id * jnp.ones((n_rows, 1), dtype=jnp.int32)
    shifted_ids = shift_tokens_right(
        jnp.concatenate((encoder_ids, pad_column), axis=1), pad_id, start_id
    )

    encoder_mask = jnp.array(examples['attention_mask'])
    # Attend to the prepended start position as well as the original tokens.
    leading_ones = jnp.ones((encoder_mask.shape[0], 1), dtype=jnp.int32)
    decoder_mask = jnp.concatenate((leading_ones, encoder_mask), axis=1)

    examples['decoder_input_ids'] = shifted_ids.tolist()
    examples['decoder_attention_mask'] = decoder_mask.tolist()
    return examples
def prepare_inputs(inputs):
    """Convert every tokenizer field to a JAX array, then add decoder inputs.

    ``inputs`` is mutated in place; the value returned is whatever
    ``add_decoder_input_ids`` returns for it.
    """
    converted = {field: jnp.array(values) for field, values in inputs.items()}
    inputs.update(converted)
    return add_decoder_input_ids(inputs)
def get_latent(text):
    """Encode one string with the module-level model and return its latent code.

    The text is tokenized as a batch of one and the first entry of the
    model's ``latent_codes`` output is returned.
    """
    batch = prepare_inputs(tokenizer([text]))
    outputs = model(**batch)
    return outputs.latent_codes[0]
def tokens_from_latent(latent_codes):
    """Generate token ids (1 to 32 tokens) from a single latent code.

    Side effect: sets ``model.config.is_encoder_decoder = True`` on the
    module-level model before every call.

    Returns:
        The object returned by ``model.generate``; callers read its
        ``.sequences`` attribute.
    """
    config = model.config
    # NOTE(review): presumably needed so generate() takes its encoder-decoder
    # path for this custom model -- confirm.
    config.is_encoder_decoder = True
    batched = jnp.array([latent_codes])  # add a leading batch axis of size 1
    return model.generate(
        latent_codes=batched,
        bos_token_id=config.decoder_start_token_id,
        min_length=1,
        max_length=32,
    )
def slerp(ratio, t1, t2, eps=1e-6):
    '''
    Perform a spherical interpolation between 2 vectors.
    Most of the volume of a high-dimensional orange is in the skin, not the pulp.
    This also applies for multivariate Gaussian distributions.
    To that end we can interpolate between samples by following the surface of a n-dimensional sphere rather than a straight line.

    Args:
        ratio: Interpolation ratio; 0.0 yields t1 and 1.0 yields t2.
        t1: Tensor1 -- a 2-D array of row vectors, or a single 1-D vector.
        t2: Tensor2, same shape as t1.
        eps: When |sin(omega)| falls below this, the vectors are (anti)parallel
            and slerp is undefined (0/0); linear interpolation is used instead.

    Returns:
        An array with the same shape as t1.
    '''
    # View 1-D inputs as a single row so the row-wise norms below apply.
    rows_1 = jnp.atleast_2d(t1)
    rows_2 = jnp.atleast_2d(t2)
    low_norm = rows_1 / jnp.linalg.norm(rows_1, axis=1, keepdims=True)
    high_norm = rows_2 / jnp.linalg.norm(rows_2, axis=1, keepdims=True)
    # Rounding can push the cosine of (nearly) parallel unit vectors slightly
    # past +/-1, where arccos returns NaN; clip it back into the valid domain.
    cos_omega = jnp.clip((low_norm * high_norm).sum(1), -1.0, 1.0)
    # NOTE(review): as in the original implementation, only the first row's
    # angle is used and the resulting scalar weights scale every row -- confirm.
    omega = jnp.arccos(cos_omega)[0]
    so = jnp.sin(omega)
    degenerate = jnp.abs(so) < eps
    # Substitute a harmless divisor so the unused branch cannot produce NaN.
    safe_so = jnp.where(degenerate, 1.0, so)
    w1 = jnp.where(degenerate, 1.0 - ratio, jnp.sin((1.0 - ratio) * omega) / safe_so)
    w2 = jnp.where(degenerate, ratio, jnp.sin(ratio * omega) / safe_so)
    return w1 * t1 + w2 * t2
def decode(cnt, ratio, txt_1, txt_2):
    """Interpolate between two texts in latent space and decode the result.

    Args:
        cnt: Streamlit placeholder used to show progress messages.
        ratio: slerp ratio between the two latents.
        txt_1, txt_2: the texts to interpolate between.

    Returns:
        The decoded text with special tokens removed, or ``''`` if either
        input is empty.
    """
    if txt_1 and txt_2:
        cnt.write('Getting latents...')
        start = get_latent(txt_1)
        end = get_latent(txt_2)
        blended = slerp(ratio, start, end)
        cnt.write('Decoding latent...')
        generated = tokens_from_latent(blended)
        first_sequence = generated.sequences[0]
        return tokenizer.decode(first_sequence, skip_special_tokens=True)
    return ''
# Interpolation demo for the Python-code model (module-level model/tokenizer
# currently hold the objects returned by get_model()).
in_1 = st.text_input("A line of Python code.", "x = a - 1")
in_2 = st.text_input("Another line of Python code.", "x = a + 10 * 2")
# An explicit key keeps this slider distinct from the identically configured
# 'Interpolation Ratio' slider in the Wikipedia section; without one Streamlit
# generates the same widget id for both and raises DuplicateWidgetID.
r = st.slider(
    'Interpolation Ratio',
    min_value=0.0,
    max_value=1.0,
    value=0.5,
    key='python_ratio',
)
container = st.empty()
container.write('Loading...')
out = decode(container, r, in_1, in_2)
container.empty()
st.write(out)
st.text('''
Try interpolating between sentences from wikipedia using this T5-VAE.
''')
@st.cache(allow_output_mutation=True)
def get_wiki_model():
    """Load the T5 tokenizer and the Wikipedia-sentence T5-VAE model.

    Wrapped in ``st.cache`` so Streamlit reruns reuse the loaded objects.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        ValueError: if the tokenizer vocabulary size differs from the number
            of rows in the model's shared embedding table.
    """
    tokenizer = AutoTokenizer.from_pretrained("t5-base")
    model = FlaxT5VaeForAutoencoding.from_pretrained("flax-community/t5-vae-wiki")
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    embedding_rows = model.params['t5']['shared']['embedding'].shape[0]
    if embedding_rows != len(tokenizer):
        raise ValueError("T5 Tokenizer doesn't match T5Vae embedding size.")
    return model, tokenizer
# Rebinding the module-level model/tokenizer makes the helpers (get_latent,
# tokens_from_latent, decode, arithmetic) use the Wikipedia model from here on.
model, tokenizer = get_wiki_model()
in_1 = st.text_input("A sentence.", "Children are looking for the water to be clear.")
in_2 = st.text_input("Another sentence.", "There are two people playing soccer.")
# Explicit key: this slider is otherwise identical to the one in the Python
# section, which makes Streamlit raise DuplicateWidgetID.
r = st.slider(
    'Interpolation Ratio',
    min_value=0.0,
    max_value=1.0,
    value=0.5,
    key='wiki_ratio',
)
container = st.empty()
container.write('Loading...')
# decode() takes the progress placeholder as its first argument; omitting it
# raises TypeError (missing required argument 'txt_2').
out = decode(container, r, in_1, in_2)
container.empty()
st.write(out)
st.text('''
Try arithmetic in latent space.
''')
def arithmetic(cnt, txt_a, txt_b, txt_c):
    """Solve the analogy "A is to B as C is to ?" in latent space.

    Computes ``latent(C) + (latent(B) - latent(A))`` and decodes it.

    Args:
        cnt: Streamlit placeholder used to show progress messages.
        txt_a, txt_b, txt_c: the three analogy texts.

    Returns:
        The decoded text with special tokens removed, or ``''`` if any input
        is empty.
    """
    if txt_a and txt_b and txt_c:
        cnt.write('getting latents...')
        latent_a = get_latent(txt_a)
        latent_b = get_latent(txt_b)
        latent_c = get_latent(txt_c)
        # Shift C by the offset that carries A to B.
        target = latent_c + (latent_b - latent_a)
        cnt.write('decoding C + (B - A)...')
        generated = tokens_from_latent(target)
        return tokenizer.decode(generated.sequences[0], skip_special_tokens=True)
    return ''
# Latent-arithmetic demo: three text inputs rendered in order A, B, C.
in_a, in_b, in_c = (
    st.text_input(label, default)
    for label, default in (
        ("A", "Children are looking for the water to be clear."),
        ("B", "There are two people playing soccer."),
        ("C", "Children are looking for the water to be clear."),
    )
)
st.text('''
A is to B as C is to...
''')
# Placeholder shows progress while the analogy is computed, then is cleared.
container = st.empty()
container.write('Loading...')
out = arithmetic(container, in_a, in_b, in_c)
container.empty()
st.write(out)
|