import torch
from transformers import GPT2Model, GPT2Config
# Note: only Conv1D is used below; older transformers releases expose it from
# modeling_utils, newer ones move it to transformers.pytorch_utils.
from transformers.modeling_utils import PreTrainedModel, Conv1D, prune_conv1d_layer, SequenceSummary
from transformers.models.gpt2.modeling_gpt2 import (
    GPT2Block, GPT2Attention, GPT2MLP
)
from torch import nn


class Cond_Attention(GPT2Attention):
    """GPT-2 attention whose keys/values are projected from a conditioning tensor `z`."""

    def __init__(self, nx, n_ctx, config, is_cross_attention=False):
        # Skip GPT2Attention.__init__ on purpose: this module builds its own
        # projections (including the extra key/value projection `c_z` for `z`),
        # but still reuses GPT2Attention's _split_heads/_merge_heads helpers.
        super(GPT2Attention, self).__init__()
        self.output_attentions = config.output_attentions
        n_state = nx
        assert n_state % config.n_head == 0
        self.embed_dim = config.n_embd
        self.num_heads = config.n_head
        self.head_dim = self.embed_dim // self.num_heads
        self.split_size = n_state
        self.scale_attn_weights = config.scale_attn_weights
        self.is_cross_attention = is_cross_attention
        self.c_attn = Conv1D(n_state * 3, nx)
        self.c_proj = Conv1D(n_state, nx)
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.pruned_heads = set()
        self.c_z = Conv1D(n_state * 2, nx)  # key/value projection for the conditioning tensor

    def _attn(self, query, key, value, attention_mask=None, head_mask=None):
        # Plain scaled dot-product attention; `attention_mask` is accepted for
        # interface compatibility but not applied (no causal mask either).
        attn_weights = torch.matmul(query, key.transpose(-1, -2))
        if self.scale_attn_weights:
            attn_weights = attn_weights / torch.full(
                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
            )
        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)
        if head_mask is not None:
            attn_weights = attn_weights * head_mask
        attn_output = torch.matmul(attn_weights, value)
        return attn_output, attn_weights

    def forward(self, x, z, layer_past=None, attention_mask=None, head_mask=None,
                use_cache=True, output_attentions=False):
        x = self.c_attn(x)
        query, key, value = x.split(self.split_size, dim=2)
        query = self._split_heads(query, self.num_heads, self.head_dim)
        key = self._split_heads(key, self.num_heads, self.head_dim)
        value = self._split_heads(value, self.num_heads, self.head_dim)
        if layer_past is not None:
            past_key, past_value = layer_past
            key = torch.cat((past_key, key), dim=-2)
            value = torch.cat((past_value, value), dim=-2)
        if use_cache:
            present = (key, value)
        else:
            present = None
        # The key/value derived from `x` (and any `layer_past`) only populate the
        # returned cache; the attention itself uses key/value projected from `z`.
        z_conv = self.c_z(z)
        key_z, value_z = z_conv.split(self.split_size, dim=2)
        key_z = self._split_heads(key_z, self.num_heads, self.head_dim)
        value_z = self._split_heads(value_z, self.num_heads, self.head_dim)
        key = key_z
        value = value_z
        attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
        attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output)
        outputs = (attn_output, present)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs
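

# --- Illustrative sketch (not part of the original model) ---------------------
# A minimal shape check for Cond_Attention, assuming a transformers release where
# GPT2Attention still provides the _split_heads/_merge_heads helpers used above.
# The config sizes, sequence lengths, and the helper name _demo_cond_attention are
# assumptions for this sketch only; the n_ctx argument is unused by Cond_Attention,
# so n_positions is passed in its place.
def _demo_cond_attention():
    config = GPT2Config(n_embd=64, n_head=4, n_layer=2, n_positions=32)
    attn = Cond_Attention(config.n_embd, config.n_positions, config).eval()
    x = torch.randn(1, 10, config.n_embd)  # hidden states: (batch, query_len, n_embd)
    z = torch.randn(1, 16, config.n_embd)  # conditioning tokens: (batch, z_len, n_embd)
    out, present = attn(x, z, use_cache=True)
    assert out.shape == (1, 10, config.n_embd)  # output keeps the query length
    assert present[0].shape[-2] == x.shape[1]   # cache holds the x-derived keys, not the z keys
    return out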
class Cond_Block(GPT2Block):
    """GPT-2 block extended with optional arousal/valence cross-attention branches."""

    def __init__(self, config, activate_a=False, activate_v=False):
        # Skip GPT2Block.__init__ and assemble the sub-layers explicitly so the
        # extra conditional attention branches can be inserted.
        super(GPT2Block, self).__init__()
        self.activate_a = activate_a
        self.activate_v = activate_v
        nx = config.n_embd
        self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.attn = Cond_Attention(nx, config.n_ctx, config)
        self.attn_a = None if not self.activate_a else Cond_Attention(nx, config.n_ctx, config)
        self.ln_a = None if not self.activate_a else nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.attn_v = None if not self.activate_v else Cond_Attention(nx, config.n_ctx, config)
        self.ln_v = None if not self.activate_v else nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.mlp = GPT2MLP(4 * nx, config)

    def forward(self, x, a, v, layer_past=None, attention_mask=None, head_mask=None):
        residual = x  # note: unused; the residual additions below add onto the normalized activations
        x = self.ln_1(x)
        # Self-attention: the block attends to its own (normalized) hidden states.
        attn_outputs = self.attn(
            x=x,
            z=x
        )
        attn_output = attn_outputs[0]
        # outputs = attn_outputs[1:]
        x = x + attn_output
        if self.activate_a:
            # Cross-attention to the arousal conditioning tokens.
            x = self.ln_a(x)
            cross_attn_outputs = self.attn_a(
                x=x,
                z=a
            )
            cross_attn_output = cross_attn_outputs[0]
            x = x + cross_attn_output
        if self.activate_v:
            # Cross-attention to the valence conditioning tokens.
            x = self.ln_v(x)
            cross_attn_outputs = self.attn_v(
                x=x,
                z=v
            )
            cross_attn_output = cross_attn_outputs[0]
            x = x + cross_attn_output
        m = self.mlp(self.ln_2(x))
        x = x + m
        outputs = (x,)
        return outputs
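

# --- Illustrative sketch (not part of the original model) ---------------------
# One Cond_Block pass with both emotion branches enabled: self-attention on x,
# then cross-attention to the arousal tokens `a`, then to the valence tokens `v`.
# A freshly built GPT2Config may not define `n_ctx` on some transformers releases,
# so the sketch falls back to n_positions; all sizes here are assumptions.
def _demo_cond_block():
    config = GPT2Config(n_embd=64, n_head=4, n_layer=2, n_positions=32)
    if not hasattr(config, "n_ctx"):
        config.n_ctx = config.n_positions
    block = Cond_Block(config, activate_a=True, activate_v=True).eval()
    x = torch.randn(1, 10, config.n_embd)  # token features to refine
    a = torch.randn(1, 16, config.n_embd)  # arousal conditioning tokens
    v = torch.randn(1, 16, config.n_embd)  # valence conditioning tokens
    (out,) = block(x, a, v)
    assert out.shape == x.shape  # the block preserves the hidden-state shape
    return out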
class EmotionInjectionTransformer(GPT2Model):
    """GPT-2 backbone that refines external features while injecting arousal/valence conditioning."""

    def __init__(self, config, final_out_type="Linear+LN", sd_feature_dim=2048):
        # Skip GPT2Model.__init__ and rebuild the stack so the blocks can be
        # replaced with Cond_Block and the feature-space projections added.
        super(GPT2Model, self).__init__(config)
        self.add_attn = True
        self.sd_feature_dim = sd_feature_dim
        self.activate_a = True
        self.activate_v = True
        self.output_hidden_states = config.output_hidden_states
        self.output_attentions = config.output_attentions
        self.use_cache = config.use_cache
        self.embed_dim = config.n_embd
        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
        self.wpe = nn.Embedding(config.n_positions, self.embed_dim)
        self.drop = nn.Dropout(config.embd_pdrop)
        # Project between the external feature space (sd_feature_dim) and GPT-2's hidden size.
        self.xl_feature2gpt_feature = nn.Linear(self.sd_feature_dim, config.n_embd, bias=False)
        self.gpt_feature2xl_feature = nn.Linear(config.n_embd, self.sd_feature_dim, bias=False)
        if final_out_type == "Linear+LN" or final_out_type == "Linear+LN+noResidual":
            self.ln_xl_feature = nn.LayerNorm(self.sd_feature_dim, eps=1e-5)
        elif final_out_type == "Linear+LN+Linear" or final_out_type == "Linear+LN+Linear+noResidual":
            self.ln_xl_feature = nn.LayerNorm(self.sd_feature_dim, eps=1e-5)
            self.ff = nn.Linear(self.sd_feature_dim, self.sd_feature_dim, bias=False)
        else:
            raise NotImplementedError
        self.init_weights()
        # Each scalar emotion value is expanded into `cross_token` conditioning tokens.
        self.cross_token = 16
        self.a_f = nn.Sequential(
            nn.Linear(1, 256),
            nn.ReLU(),
            nn.Linear(256, config.n_embd * self.cross_token if self.activate_a else config.n_embd)
        )
        self.v_f = nn.Sequential(
            nn.Linear(1, 256),
            nn.ReLU(),
            nn.Linear(256, config.n_embd * self.cross_token if self.activate_v else config.n_embd)
        )
        if self.add_attn:
            self.attn_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
            self.h = nn.ModuleList(
                [Cond_Block(config, self.activate_a, self.activate_v) for _ in range(config.n_layer)]
            )
        else:
            self.h = nn.ModuleList([GPT2Block(config) for _ in range(config.n_layer)])
        self.final_out_type = final_out_type
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        arousal=None,
        valence=None,
    ):
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])
        if position_ids is not None:
            position_ids = position_ids.view(-1, input_shape[-1])

        if past_key_values is None:
            past_length = 0
            past_key_values = [None] * len(self.h)
        else:
            past_length = past_key_values[0][0].size(-2)
        if position_ids is None:
            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        if inputs_embeds is None:
            inputs_embeds = self.wte(input_ids)
        else:
            # Feature-driven path: keep the original features for the residual output
            # types and project them down to GPT-2's hidden size. `residual` is only
            # defined on this path, so the residual output types assume inputs_embeds.
            residual = inputs_embeds
            inputs_embeds = self.xl_feature2gpt_feature(inputs_embeds)
        position_embeds = self.wpe(position_ids)
        hidden_states = inputs_embeds + position_embeds
        hidden_states = self.drop(hidden_states)

        # Expand the scalar arousal/valence inputs into cross-attention token sequences.
        a_feature = self.attn_proj(self.a_f(arousal).view(-1, self.cross_token, self.config.n_embd))
        v_feature = self.attn_proj(self.v_f(valence).view(-1, self.cross_token, self.config.n_embd))

        output_shape = input_shape + (hidden_states.size(-1),)
        all_self_attentions = () if self.output_attentions else None
        all_hidden_states = () if self.output_hidden_states else None
        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            outputs = block(
                hidden_states,
                a=a_feature,
                v=v_feature,
                layer_past=layer_past,
                attention_mask=attention_mask,
                head_mask=head_mask[i]
            )
            hidden_states = outputs[0]
            if self.output_attentions:
                # Cond_Block currently returns only the hidden states, so this branch
                # assumes blocks that also return presents/attention weights.
                all_self_attentions = all_self_attentions + (outputs[2 if self.use_cache else 1],)

        hidden_states = self.ln_f(hidden_states)
        # Map back to the external feature space, optionally with LayerNorm, an extra
        # linear layer, and/or a residual connection to the input features.
        if self.final_out_type == "Linear+LN":
            hidden_states = residual + self.ln_xl_feature(self.gpt_feature2xl_feature(hidden_states))
        elif self.final_out_type == "Linear+LN+noResidual":
            hidden_states = self.ln_xl_feature(self.gpt_feature2xl_feature(hidden_states))
        elif self.final_out_type == "Linear+LN+Linear":
            hidden_states = residual + self.ff(self.ln_xl_feature(self.gpt_feature2xl_feature(hidden_states)))
        elif self.final_out_type == "Linear+LN+Linear+noResidual":
            hidden_states = self.ff(self.ln_xl_feature(self.gpt_feature2xl_feature(hidden_states)))
        elif self.final_out_type == "Linear+noResidual":
            hidden_states = self.gpt_feature2xl_feature(hidden_states)
        else:
            hidden_states = residual + self.gpt_feature2xl_feature(hidden_states)

        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)
        outputs = (hidden_states,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            attention_output_shape = input_shape[:-1] + (-1,) + all_self_attentions[0].shape[-2:]
            all_attentions = tuple(t.view(*attention_output_shape) for t in all_self_attentions)
            outputs = outputs + (all_attentions,)
        return outputs
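

# --- Illustrative sketch (not part of the original model) ---------------------
# End-to-end pass through EmotionInjectionTransformer driven by pre-computed
# features (inputs_embeds) rather than token ids, which is the path the residual
# output types assume. The 2048-dim feature size matches the constructor default;
# batch size, sequence length, and the arousal/valence values are assumptions.
def _demo_emotion_injection():
    config = GPT2Config(n_embd=64, n_head=4, n_layer=2, n_positions=77)
    if not hasattr(config, "n_ctx"):
        config.n_ctx = config.n_positions  # Cond_Block reads config.n_ctx
    model = EmotionInjectionTransformer(config, final_out_type="Linear+LN", sd_feature_dim=2048).eval()
    features = torch.randn(2, 77, 2048)  # e.g. prompt embeddings to be refined
    arousal = torch.rand(2, 1)           # one scalar arousal value per sample
    valence = torch.rand(2, 1)           # one scalar valence value per sample
    with torch.no_grad():
        (refined,) = model(inputs_embeds=features, arousal=arousal, valence=valence)
    assert refined.shape == features.shape  # output stays in the 2048-dim feature space
    return refined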