"""Modified from https://github.com/mlfoundations/open_flamingo""" import random import torch.nn as nn from .helpers import GatedCrossAttentionBlock from .utils import getattr_recursive, setattr_recursive class FlamingoLayer(nn.Module): def __init__(self, gated_cross_attn_layer, decoder_layer): super().__init__() self.gated_cross_attn_layer = gated_cross_attn_layer self.decoder_layer = decoder_layer self.vis_x = None self.media_locations = None self.only_lang_x = False def is_conditioned(self) -> bool: """Check whether the layer is conditioned.""" return self.vis_x is not None # Used this great idea from this implementation of Flamingo (https://github.com/dhansmair/flamingo-mini/) def condition_vis_x(self, vis_x): self.vis_x = vis_x def condition_only_lang_x(self, only_lang_x=False): self.only_lang_x = only_lang_x def condition_media_locations(self, media_locations): self.media_locations = media_locations def condition_attend_previous(self, attend_previous): self.attend_previous = attend_previous def forward( self, lang_x, attention_mask=None, **decoder_layer_kwargs, ): if self.gated_cross_attn_layer is None or self.only_lang_x: return self.decoder_layer(lang_x, attention_mask=attention_mask, **decoder_layer_kwargs) if self.vis_x is None: raise ValueError("vis_x must be conditioned before forward pass") if self.media_locations is None: raise ValueError("media_locations must be conditioned before forward pass") lang_x = self.gated_cross_attn_layer( lang_x, self.vis_x, media_locations=self.media_locations, attend_previous=self.attend_previous, ) lang_x = self.decoder_layer(lang_x, attention_mask=attention_mask, **decoder_layer_kwargs) return lang_x class FlamingoLMMixin(nn.Module): """ Mixin to add cross-attention layers to a language model. """ def set_decoder_layers_attr_name(self, decoder_layers_attr_name): self.decoder_layers_attr_name = decoder_layers_attr_name def _get_decoder_layers(self): return getattr_recursive(self, self.decoder_layers_attr_name) def _set_decoder_layers(self, value): setattr_recursive(self, self.decoder_layers_attr_name, value) def init_flamingo( self, media_token_id, vis_hidden_size, cross_attn_every_n_layers, use_media_placement_augmentation, ): """ Initialize Flamingo by adding a new gated cross attn to the decoder. Store the media token id for computing the media locations. """ self.gated_cross_attn_layers = nn.ModuleList( [ GatedCrossAttentionBlock(dim=self.config.hidden_size, dim_visual=vis_hidden_size) if (layer_idx + 1) % cross_attn_every_n_layers == 0 else None for layer_idx, _ in enumerate(self._get_decoder_layers()) ] ) self._set_decoder_layers( nn.ModuleList( [ FlamingoLayer(gated_cross_attn_layer, decoder_layer) for gated_cross_attn_layer, decoder_layer in zip( self.gated_cross_attn_layers, self._get_decoder_layers() ) ] ) ) self.media_token_id = media_token_id self.use_media_placement_augmentation = use_media_placement_augmentation self.initialized_flamingo = True def forward(self, *input, **kwargs): """Condition the Flamingo layers on the media locations before forward()""" if not self.initialized_flamingo: raise ValueError("Flamingo layers are not initialized. 
Please call `init_flamingo` first.") input_ids = kwargs["input_ids"] if "input_ids" in kwargs else input[0] media_locations = input_ids == self.media_token_id attend_previous = (random.random() < 0.5) if self.use_media_placement_augmentation else False for layer in self.get_decoder().layers: layer.condition_media_locations(media_locations) layer.condition_attend_previous(attend_previous) return super().forward(*input, **kwargs) # Call the other parent's forward method def is_conditioned(self) -> bool: """Check whether all decoder layers are already conditioned.""" return all(l.is_conditioned() for l in self._get_decoder_layers()) def clear_conditioned_layers(self): for layer in self._get_decoder_layers(): layer.condition_vis_x(None) layer.condition_media_locations(None) layer.condition_attend_previous(None)
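

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not exercised by this module): how the
# mixin is typically grafted onto a HuggingFace causal LM. It assumes an
# `extend_instance` helper as in upstream open_flamingo (which rebases an
# instance's class onto the mixin); the model name, token id, and tensor
# shapes below are example values, not fixed by this file.
#
#   from transformers import AutoModelForCausalLM
#   from .utils import extend_instance  # present in upstream open_flamingo
#
#   lang_encoder = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
#   extend_instance(lang_encoder, FlamingoLMMixin)
#   lang_encoder.set_decoder_layers_attr_name("model.decoder.layers")  # OPT layout
#   lang_encoder.init_flamingo(
#       media_token_id=50265,              # id of the <image> token (example)
#       vis_hidden_size=1024,              # width of the vision features
#       cross_attn_every_n_layers=4,       # gated x-attn after every 4th block
#       use_media_placement_augmentation=False,
#   )
#
#   # Per forward pass: condition each layer on the vision features
#   # (shape: batch x n_media x n_latents x vis_hidden_size), run, then clear.
#   for layer in lang_encoder._get_decoder_layers():
#       layer.condition_vis_x(vis_x)
#   out = lang_encoder(input_ids=input_ids, attention_mask=attention_mask)
#   lang_encoder.clear_conditioned_layers()
# ---------------------------------------------------------------------------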