from typing import Optional
import torch
from torch import nn as nn
from transformers import LlamaConfig, LlamaModel, LlamaPreTrainedModel, GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions


class T3HuggingfaceBackend(LlamaPreTrainedModel, GenerationMixin):
    """
    Override some HuggingFace interface methods so we can use the standard `generate` method with our
    custom embedding / logit layers.

    NOTE: need to extend "*PreTrainedModel" to avoid re-initializing weights!
    """

    def __init__(
        self,
        config: LlamaConfig,
        llama: LlamaModel,
        *,
        speech_enc,
        speech_head,
        latents_queue=None,
        logits_queue=None,
        alignment_stream_analyzer: 'AlignmentStreamAnalyzer'=None,
    ):
        super().__init__(config)
        self.model = llama
        self.speech_enc = speech_enc
        self.speech_head = speech_head
        # Tracks whether the decoder conditioning prefix has been prepended yet;
        # it must only be added on the first generation step.
        self._added_cond = False
        self.alignment_stream_analyzer = alignment_stream_analyzer

    @torch.inference_mode()
    def prepare_inputs_for_generation(
        self, input_ids: torch.Tensor, decoder_cond: torch.Tensor, use_cache: bool, past_key_values=None,
        # This argument was introduced in a recent version of transformers (>=4.29.1)
        cache_position=None,
    ):
        """
        Called by HuggingFace's `generate()` method. Overridden here to apply our custom
        speech token embedding layer.

        :param input_ids: (B, S) int64 tensor of input tokens.
        :param decoder_cond: (B, T, C) float32 tensor of conditioning (prefixed to `inputs_embeds`).
        """
        # Make use of the kv cache: only the last input ID is new, so trim away all the ones before it.
        if not use_cache:
            past_key_values = None
        if past_key_values is not None:
            input_ids = input_ids[:, -1:]

        # custom speech token embedding layer
        inputs_embeds = self.speech_enc(input_ids)

        # prefix decoder conditioning if applicable
        if not self._added_cond:
            assert past_key_values is not None  # should be first step
            if decoder_cond.size(0) != inputs_embeds.size(0):
                decoder_cond = decoder_cond.expand(inputs_embeds.size(0), -1, -1)
            inputs_embeds = torch.cat([decoder_cond, inputs_embeds], dim=1)
            self._added_cond = True

        return {
            "inputs_embeds": inputs_embeds,
            "past_key_values": past_key_values,
            "use_cache": use_cache,
        }

    @torch.inference_mode()
    def forward(
        self,
        inputs_embeds: torch.Tensor,
        past_key_values: Optional[torch.Tensor]=None,
        use_cache=True,
        output_attentions=False,
        output_hidden_states=True,
        return_dict=True,
    ):
        """
        Called by HuggingFace's `generate()` method. Overridden here to apply our custom
        layer norm and speech logit projection layers.

        :param inputs_embeds: (B, S, C) float32 tensor of input embeddings. If past key values
            are given, S should be 1.
        """
        is_large_input = inputs_embeds.size(1) != 1
        has_cache = past_key_values is not None and len(past_key_values) > 0
        assert not (is_large_input and has_cache)
        assert return_dict
        assert output_hidden_states

        tfmr_out = self.model(
            inputs_embeds=inputs_embeds,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        hidden_states = tfmr_out.hidden_states[-1]  # (B, seq, dim)
        logits = self.speech_head(hidden_states)
        # assert inputs_embeds.size(0) == 1  # (disabled for CFG)

        # NOTE: hallucination handler may modify logits to force emit an EOS token
        # logits = self.alignment_stream_analyzer.step(logits)

        return CausalLMOutputWithCrossAttentions(
            logits=logits,
            past_key_values=tfmr_out.past_key_values,
            hidden_states=tfmr_out.hidden_states,
            attentions=tfmr_out.attentions,
        )
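

if __name__ == "__main__":
    # --- Usage sketch (added for illustration; not part of the original module) ---
    # A minimal smoke test, assuming nothing beyond what this file already imports: it builds
    # a tiny, randomly initialized Llama backbone and drives the backend with a hand-rolled
    # greedy loop instead of `generate()`, just to exercise `prepare_inputs_for_generation`
    # and the overridden `forward` along the kv-cache path. All sizes and token ids below
    # are arbitrary placeholder values, not constants from this repo.
    from transformers import DynamicCache

    n_speech_tokens = 32
    cfg = LlamaConfig(
        vocab_size=n_speech_tokens,
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=4,
    )
    backend = T3HuggingfaceBackend(
        config=cfg,
        llama=LlamaModel(cfg),
        speech_enc=nn.Embedding(n_speech_tokens, cfg.hidden_size),  # stand-in speech token embedding
        speech_head=nn.Linear(cfg.hidden_size, n_speech_tokens),    # stand-in speech logit projection
    )

    cond = torch.zeros(1, 4, cfg.hidden_size)          # (B, T, C) dummy conditioning prefix
    generated = torch.tensor([[0]], dtype=torch.long)  # placeholder start-of-speech token

    with torch.inference_mode():
        # Start from an empty cache so the "should be first step" assert in
        # `prepare_inputs_for_generation` holds, as it would under HF `generate()`.
        past = DynamicCache()
        for _ in range(8):
            inputs = backend.prepare_inputs_for_generation(
                generated, cond, use_cache=True, past_key_values=past,
            )
            out = backend(**inputs)
            next_tok = out.logits[:, -1].argmax(dim=-1, keepdim=True)  # greedy pick for the demo
            generated = torch.cat([generated, next_tok], dim=1)
            past = out.past_key_values

    print(generated)  # (1, 9): placeholder BOS followed by 8 generated speech token ids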