# Copyright (c) Meta Platforms, Inc. and affiliates.

import os
from dataclasses import replace

import numpy as np
import pytest
import torch

from bytelatent.constants import BLT_DATA
from bytelatent.data.data_types import Batch
from bytelatent.data.ngram_processor import NgramProcessor
from bytelatent.model.blt import (
    ByteLatentTransformer,
    ByteLatentTransformerArgs,
    EmbeddingType,
    compute_hash_embeddings,
    create_global_transformer,
    create_local_decoder,
    create_local_encoder,
    cross_attn_mask,
    decoder_patch_ids_from_lengths,
    get_blt_input,
    init_embeddings,
    patch_ids_from_lengths,
)
from bytelatent.model.latent_transformer import CrossAttention
from bytelatent.model.utils import create_causal_mask
from bytelatent.optim import OptimArgs, build_optimizer
from bytelatent.tokenizers.constants import EOS_ID
from bytelatent.train import compute_loss

def batch_to_tensors_and_gpu(batch):
    # Convert the numpy arrays in a Batch to torch tensors and move them to the
    # GPU when one is available; optional fields stay None if they are absent.
    x = torch.from_numpy(batch.x)
    y = torch.from_numpy(batch.y)
    mask = None if batch.mask is None else torch.from_numpy(batch.mask)
    patch_lengths = (
        None if batch.patch_lengths is None else torch.from_numpy(batch.patch_lengths)
    )
    ngram_ids = None if batch.ngram_ids is None else torch.from_numpy(batch.ngram_ids)
    if torch.cuda.is_available():
        x = x.cuda()
        y = y.cuda()
        if mask is not None:
            mask = mask.cuda()
        if patch_lengths is not None:
            patch_lengths = patch_lengths.cuda()
        if ngram_ids is not None:
            ngram_ids = ngram_ids.cuda()
    return x, y, mask, patch_lengths, ngram_ids

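# The serialized test batch also carries fields ("x2", "y2", "src_names") that the
# model tests do not consume, so they are dropped before constructing a Batch.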
def fake_batch():
    batch_dict = torch.load(
        os.path.join(BLT_DATA, "test_batch.pt"), weights_only=False
    )
    del batch_dict["x2"]
    del batch_dict["y2"]
    del batch_dict["src_names"]
    return Batch(**batch_dict)

def create_args(cross_attention=False):
    transformer_args = ByteLatentTransformerArgs(
        # Base args provided
        n_heads=8,
        dim=512,
        vocab_size=260,
        # Additional args from command line
        dim_token=256,
        patch_size=6,
        patching_mode="space",
        tie_local_encoder_decoder_logits=False,
        patch_in_forward=False,
        max_encoder_seq_length=12288,
        pad_to_max_length=True,
        encoder_lm_loss=False,
        patching_threshold=3.1439168453216553,
        encoder_hash_byte_group_size=[4],
        encoder_hash_byte_group_vocab=50002,
        encoder_hash_byte_group_nb_functions=3,
        cross_attn_encoder=cross_attention,
        cross_attn_decoder=cross_attention,
        cross_attn_window_encoder=512,
        cross_attn_window_decoder=512,
        dim_local_encoder=256,
        dim_local_decoder=256,
        cross_attn_k=8,
        cross_attn_nheads=4,
        cross_attn_all_layers_decoder=True,
        cross_attn_all_layers_encoder=True,
        cross_attn_use_flex_attention=True,
        cross_attn_init_by_pooling=True,
        log_patch_lengths=True,
        non_linearity="swiglu",
        use_rope=True,
        recompute_fc1_out=False,
        recompute_fc3_out=False,
        recompute_attn=False,
        custom_bwd=False,
        layer_ckpt="none",
        use_local_encoder_transformer=True,
        init_use_gaussian=True,
        init_use_depth="current",
        attn_bias_type="block_causal",
        attn_impl="xformers",
        alpha_depth="disabled",
        max_length=256,
        local_attention_window_len=512,
        max_seqlen=12288,
        downsampling_by_pooling="max",
        eos_id=EOS_ID,
    )
    return transformer_args

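# Each test below builds one BLT component (or the full model) from the args above
# and checks its outputs, in several cases against reference tensors stored under
# BLT_DATA.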
class TestByteLatentTransformer:
    def test_local_encoder(self):
        args = create_args()
        device = torch.device("cuda")
        local_encoder = create_local_encoder(args).to(device)
        batch = fake_batch()
        tokens, _, _, patch_lengths, _ = batch_to_tensors_and_gpu(batch)
        local_encoder_tokens, _, _ = get_blt_input(
            tokens=tokens,
            enforce_patch_size_multiple=False,
            nb_boe=0,
            patch_size=local_encoder.patch_size,
            boe_id=local_encoder.boe_id,
        )
        patch_ids = patch_ids_from_lengths(
            patch_lengths, local_encoder_tokens.shape[-1]
        )
        encoder_hash_tok_embedding = init_embeddings(
            args,
            EmbeddingType.HASH_TOK,
            local_encoder_dim=local_encoder.dim,
            encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
        ).to(device)
        local_encoder_embeds = compute_hash_embeddings(
            local_encoder_tokens=local_encoder_tokens,
            local_encoder=local_encoder,
            encoder_hash_tok_embedding=encoder_hash_tok_embedding,
            encoder_hash_byte_group_nb_functions=args.encoder_hash_byte_group_nb_functions,
            encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
            encoder_hash_byte_group_vocab=args.encoder_hash_byte_group_vocab,
        )
        reference_path = os.path.join(BLT_DATA, "local_encoder_tokens.pt")
        reference_tokens = torch.load(reference_path).to(device)
        torch.testing.assert_close(
            local_encoder_tokens,
            reference_tokens,
            msg="Generated tokens don't match reference tokens",
        )
        (h_encoder, h_cross), cache_encoder = local_encoder(
            tokens=local_encoder_tokens,
            embeds=local_encoder_embeds,
            patch_embeds=None,
            cross_mask=None,
            num_patches=patch_lengths.shape[1],
            patch_ids=patch_ids,
        )
        assert h_encoder is not None
        assert h_cross is None
        assert cache_encoder is None
        expected_shape = (
            local_encoder_tokens.shape[0],
            local_encoder_tokens.shape[1],
            local_encoder.dim,
        )
        assert h_encoder.shape == expected_shape

    def test_local_encoder_cross_attention(self):
        args = create_args(cross_attention=True)
        device = torch.device("cuda")
        local_encoder = create_local_encoder(args).to(device)
        batch = fake_batch()
        tokens, _, _, patch_lengths, _ = batch_to_tensors_and_gpu(batch)
        local_encoder_tokens, _, _ = get_blt_input(
            tokens=tokens,
            enforce_patch_size_multiple=False,
            nb_boe=0,
            patch_size=local_encoder.patch_size,
            boe_id=local_encoder.boe_id,
        )
        patch_ids = patch_ids_from_lengths(
            patch_lengths, local_encoder_tokens.shape[-1]
        )
        encoder_hash_tok_embedding = init_embeddings(
            args,
            EmbeddingType.HASH_TOK,
            local_encoder_dim=local_encoder.dim,
            encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
        ).to(device)
        cross_attn_mask_enc = cross_attn_mask(
            patch_ids,
            patch_lengths,
            local_encoder_tokens.shape[-1],
            patches_as_queries=True,
            cross_attn_k=args.cross_attn_k,
            window=args.cross_attn_window_encoder,
            block_mask=True,
        )
        local_encoder_embeds = compute_hash_embeddings(
            local_encoder_tokens=local_encoder_tokens,
            local_encoder=local_encoder,
            encoder_hash_tok_embedding=encoder_hash_tok_embedding,
            encoder_hash_byte_group_nb_functions=args.encoder_hash_byte_group_nb_functions,
            encoder_hash_byte_group_size=args.encoder_hash_byte_group_size,
            encoder_hash_byte_group_vocab=args.encoder_hash_byte_group_vocab,
        )
        (h_encoder, h_cross), cache_encoder = local_encoder(
            tokens=local_encoder_tokens,
            embeds=local_encoder_embeds,
            patch_embeds=None,
            cross_mask=cross_attn_mask_enc,
            num_patches=patch_lengths.shape[1],
            patch_ids=patch_ids,
        )
        assert h_encoder is not None
        assert h_cross is not None
        assert cache_encoder is None
        expected_shape = (
            local_encoder_tokens.shape[0],
            local_encoder_tokens.shape[1],
            local_encoder.dim,
        )
        assert h_encoder.shape == expected_shape
        assert h_cross.shape == (2, 2048, local_encoder.dim)

    def test_local_decoder_cross_attention(self):
        args = create_args(cross_attention=True)
        device = torch.device("cuda")
        local_decoder = create_local_decoder(args).to(device)
        test_files = {
            "dec_embeds": "dec_embeds.pt",
            "decoder_tokens": "local_decoder_tokens.pt",
            "patch_embeds": "decoder_patch_cross_embeds.pt",
        }
        batch = fake_batch()
        _, _, _, patch_lengths, _ = batch_to_tensors_and_gpu(batch)
        tensors = {
            name: torch.load(os.path.join(BLT_DATA, filename)).float().to(device)
            for name, filename in test_files.items()
        }
        decoder_patch_ids = decoder_patch_ids_from_lengths(
            patch_lengths, 0, tensors["decoder_tokens"].shape[-1]
        )
        cross_attn_mask_dec = cross_attn_mask(
            decoder_patch_ids,
            patch_lengths,
            tensors["decoder_tokens"].shape[-1],
            patches_as_queries=False,
            cross_attn_k=args.cross_attn_k,
            window=args.cross_attn_window_decoder,
            block_mask=True,
        )
        output, _ = local_decoder(
            embeds=tensors["dec_embeds"],
            patch_embeds=tensors["patch_embeds"],
            tokens=tensors["decoder_tokens"],
            cross_mask=cross_attn_mask_dec,
            cache=None,
        )
        assert output is not None
        assert output.shape == (2, tensors["decoder_tokens"].shape[1], args.vocab_size)

    def test_local_decoder(self):
        args = create_args()
        device = torch.device("cuda")
        local_decoder = create_local_decoder(args).to(device)
        test_files = {
            "dec_embeds": "dec_embeds.pt",
            "decoder_tokens": "local_decoder_tokens.pt",
            "patch_embeds": "decoder_patch_embeds.pt",
        }
        tensors = {
            name: torch.load(os.path.join(BLT_DATA, filename)).float().to(device)
            for name, filename in test_files.items()
        }
        output, cache_decoder = local_decoder(
            embeds=tensors["dec_embeds"],
            patch_embeds=tensors["patch_embeds"],
            tokens=tensors["decoder_tokens"],
            cross_mask=None,
            cache=None,
        )
        assert output is not None
        expected_shape = (
            tensors["decoder_tokens"].shape[0],
            tensors["decoder_tokens"].shape[1],
            args.vocab_size,
        )
        assert output.shape == expected_shape
        assert cache_decoder is None

    def test_global_transformer(self):
        args = create_args()
        device = torch.device("cuda")
        global_transformer = create_global_transformer(args).to(device)
        test_files = {
            "global_embeds": "global_embeds.pt",
            "global_tokens": "global_tokens.pt",
        }
        tensors = {
            name: torch.load(os.path.join(BLT_DATA, filename)).float().to(device)
            for name, filename in test_files.items()
        }
        h, cache = global_transformer(
            embeds=tensors["global_embeds"], tokens=tensors["global_tokens"]
        )
        assert h is not None
        assert h.shape == (2, 256, 512)
        assert cache is None

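    # Smoke test: constructing the full ByteLatentTransformer from the test args
    # should succeed without running a forward pass.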
    def test_blt_transformer_init(self):
        args = create_args()
        model = ByteLatentTransformer(args)
        assert model is not None

    # attn_impl is parametrized so the forward pass is checked with both attention
    # backends referenced in this file ("sdpa" below, "xformers" elsewhere).
    @pytest.mark.parametrize("attn_impl", ["sdpa", "xformers"])
    def test_blt_transformer_forward(self, attn_impl):
        args = create_args()
        if attn_impl == "sdpa":
            os.environ["BLT_SUPPRESS_ATTN_ERROR"] = "1"
        else:
            os.environ["BLT_SUPPRESS_ATTN_ERROR"] = "0"
        args = args.model_copy(update=dict(attn_impl=attn_impl))
        model = ByteLatentTransformer(args)
        model = model.cuda()
        batch = fake_batch()
        x, _, _, patch_lengths, ngram_ids = batch_to_tensors_and_gpu(batch)
        output = model(
            tokens=x,
            patch_lengths=patch_lengths,
            ngram_ids=ngram_ids,
        )
        assert output is not None
        expected_shape = (
            x.shape[0],
            x.shape[1],
            args.vocab_size,
        )
        assert output.shape == expected_shape

    def test_blt_transformer_cross_attn_forward(self):
        args = create_args(cross_attention=True)
        model = ByteLatentTransformer(args)
        model = model.cuda()
        batch = fake_batch()
        x, y, mask, patch_lengths, ngram_ids = batch_to_tensors_and_gpu(batch)
        output = model(
            tokens=x,
            patch_lengths=patch_lengths,
            ngram_ids=ngram_ids,
        )
        assert output is not None
        expected_shape = (
            x.shape[0],
            x.shape[1],
            args.vocab_size,
        )
        assert output.shape == expected_shape

    def test_cross_attention_rand(self):
        x = torch.randn(2, 256, 512, device="cuda")
        kv = torch.randn(2, 256, 512, device="cuda")
        cross_attention = CrossAttention(
            dim=512,
            head_dim=64,
            n_heads=8,
            n_kv_heads=4,
            norm_eps=1e-6,
        ).to("cuda")
        mask = create_causal_mask(
            x.shape[1], "flex_attention", None, sliding_window=None
        )
        output = cross_attention(x, kv, mask)
        assert output is not None
        assert output.shape == (2, 256, 512)

    def test_ngram_embeddings(self):
        ngram_to_size = {
            2: 38396,
            3: 50000,
            4: 50000,
            5: 50000,
            6: 50000,
            7: 50000,
            8: 50000,
        }
        batch = fake_batch()
        ngram_processor = NgramProcessor(BLT_DATA, ngram_to_size)
        ngram_ids = ngram_processor.encode_token_ngrams(batch.x)
        ngram_ids = np.stack(ngram_ids, axis=0)
        batch = replace(batch, ngram_ids=ngram_ids)
        args = create_args(cross_attention=True)
        args = args.model_copy(
            update=dict(
                encoder_ngram_to_size_str="2:38396,3:50000,4:50000,5:50000,6:50000,7:50000,8:50000",
                encoder_enable_byte_ngrams=True,
                ngram_vocab_sizes=ngram_processor.ngram_vocab_sizes,
            )
        )
        model = ByteLatentTransformer(args)
        model = model.cuda()
        x, _, _, patch_lengths, ngram_ids = batch_to_tensors_and_gpu(batch)
        output = model(
            tokens=x,
            patch_lengths=patch_lengths,
            ngram_ids=ngram_ids,
        )
        assert output is not None
        expected_shape = (
            x.shape[0],
            x.shape[1],
            args.vocab_size,
        )
        assert output.shape == expected_shape

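    # End-to-end sanity check: a handful of optimizer steps on a single batch
    # should drive the loss below its initial value.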
    def test_loss_backward(self):
        args = create_args()
        args = args.model_copy(update=dict(attn_impl="xformers"))
        batch = fake_batch()
        model = ByteLatentTransformer(args)
        steps = 10
        optimizer, scheduler = build_optimizer(model, OptimArgs(lr=4e-04), steps)
        model = model.cuda()
        x, y, mask, patch_lengths, ngram_ids = batch_to_tensors_and_gpu(batch)
        initial_loss = None
        final_loss = None
        for step in range(steps):
            output = model(
                tokens=x,
                patch_lengths=patch_lengths,
                ngram_ids=ngram_ids,
            )
            loss, _ = compute_loss(output, y, mask, 1.0)
            if step == 0:
                initial_loss = loss.item()
            if step == steps - 1:
                final_loss = loss.item()
            prev_loss = loss.item()
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
        assert (
            final_loss < initial_loss
        ), f"Training did not reduce loss: initial {initial_loss}, final {final_loss}"