# devroop's picture
# Upload folder using huggingface_hub
# 03561be verified
"""Modified from https://github.com/mlfoundations/open_flamingo"""
import random
import torch.nn as nn
from .helpers import GatedCrossAttentionBlock
from .utils import getattr_recursive, setattr_recursive
class FlamingoLayer(nn.Module):
    """Pair a decoder layer with an optional gated cross-attention layer.

    When a gated cross-attention layer is present and language-only mode is
    off, ``forward`` first runs the cross-attention layer on the language
    input (using the conditioned visual input, media locations and
    ``attend_previous`` flag) and then runs the wrapped decoder layer.
    Otherwise only the decoder layer is run.
    """

    def __init__(self, gated_cross_attn_layer, decoder_layer):
        """Store both sub-layers and reset all conditioning state.

        Args:
            gated_cross_attn_layer: Cross-attention module, or ``None`` if this
                layer should behave as a plain decoder layer.
            decoder_layer: The decoder layer being wrapped.
        """
        super().__init__()
        self.gated_cross_attn_layer = gated_cross_attn_layer
        self.decoder_layer = decoder_layer
        self.vis_x = None
        self.media_locations = None
        # Defined up front so forward() cannot fail with an AttributeError when
        # condition_attend_previous() has not been called yet.
        self.attend_previous = None
        self.only_lang_x = False

    def is_conditioned(self) -> bool:
        """Check whether the layer is conditioned."""
        return self.vis_x is not None

    # Used this great idea from this implementation of Flamingo (https://github.com/dhansmair/flamingo-mini/)
    def condition_vis_x(self, vis_x):
        """Set (or clear, with ``None``) the visual input used by cross-attention."""
        self.vis_x = vis_x

    def condition_only_lang_x(self, only_lang_x=False):
        """Enable or disable language-only mode, which skips cross-attention."""
        self.only_lang_x = only_lang_x

    def condition_media_locations(self, media_locations):
        """Set (or clear, with ``None``) the media locations passed to cross-attention."""
        self.media_locations = media_locations

    def condition_attend_previous(self, attend_previous):
        """Set the ``attend_previous`` value passed to cross-attention."""
        self.attend_previous = attend_previous

    def forward(
        self,
        lang_x,
        attention_mask=None,
        **decoder_layer_kwargs,
    ):
        """Run the optional cross-attention layer, then the decoder layer.

        Args:
            lang_x: Language input handed to the sub-layers.
            attention_mask: Forwarded to the decoder layer as a keyword argument.
            **decoder_layer_kwargs: Extra keyword arguments for the decoder layer.

        Returns:
            Whatever the wrapped decoder layer returns.

        Raises:
            ValueError: If cross-attention is active but ``vis_x`` or
                ``media_locations`` has not been conditioned.
        """
        # No cross-attention for this layer, or language-only mode: plain decoder pass.
        if self.gated_cross_attn_layer is None or self.only_lang_x:
            return self.decoder_layer(lang_x, attention_mask=attention_mask, **decoder_layer_kwargs)

        if self.vis_x is None:
            raise ValueError("vis_x must be conditioned before forward pass")

        if self.media_locations is None:
            raise ValueError("media_locations must be conditioned before forward pass")

        # attend_previous is forwarded as-is (it may still be None if never conditioned).
        lang_x = self.gated_cross_attn_layer(
            lang_x,
            self.vis_x,
            media_locations=self.media_locations,
            attend_previous=self.attend_previous,
        )
        lang_x = self.decoder_layer(lang_x, attention_mask=attention_mask, **decoder_layer_kwargs)
        return lang_x
class FlamingoLMMixin(nn.Module):
    """
    Mixin to add cross-attention layers to a language model.

    The host model's decoder layers are located through a configurable
    attribute path (see ``set_decoder_layers_attr_name``) and are replaced by
    ``FlamingoLayer`` wrappers in ``init_flamingo``.
    """

    def set_decoder_layers_attr_name(self, decoder_layers_attr_name):
        """Record the attribute path under which the decoder layers live."""
        self.decoder_layers_attr_name = decoder_layers_attr_name

    def _get_decoder_layers(self):
        """Return the decoder layers found at ``decoder_layers_attr_name``."""
        return getattr_recursive(self, self.decoder_layers_attr_name)

    def _set_decoder_layers(self, value):
        """Replace the decoder layers found at ``decoder_layers_attr_name``."""
        setattr_recursive(self, self.decoder_layers_attr_name, value)

    def init_flamingo(
        self,
        media_token_id,
        vis_hidden_size,
        cross_attn_every_n_layers,
        use_media_placement_augmentation,
    ):
        """
        Initialize Flamingo by adding a new gated cross attn to the decoder. Store the media token id for computing the media locations.

        Args:
            media_token_id: Token id compared against ``input_ids`` in ``forward``.
            vis_hidden_size: Passed as ``dim_visual`` to each cross-attention block.
            cross_attn_every_n_layers: A cross-attention block is created for
                every layer whose 1-based index is a multiple of this value.
            use_media_placement_augmentation: When truthy, ``forward`` picks
                ``attend_previous`` at random; otherwise it is always ``False``.
        """
        # One entry per decoder layer; None marks layers without cross-attention.
        self.gated_cross_attn_layers = nn.ModuleList(
            [
                GatedCrossAttentionBlock(dim=self.config.hidden_size, dim_visual=vis_hidden_size)
                if (layer_idx + 1) % cross_attn_every_n_layers == 0
                else None
                for layer_idx, _ in enumerate(self._get_decoder_layers())
            ]
        )
        # Swap every decoder layer for a FlamingoLayer wrapping it and its (optional) cross-attention.
        self._set_decoder_layers(
            nn.ModuleList(
                [
                    FlamingoLayer(gated_cross_attn_layer, decoder_layer)
                    for gated_cross_attn_layer, decoder_layer in zip(
                        self.gated_cross_attn_layers, self._get_decoder_layers()
                    )
                ]
            )
        )
        self.media_token_id = media_token_id
        self.use_media_placement_augmentation = use_media_placement_augmentation
        self.initialized_flamingo = True

    def forward(self, *input, **kwargs):
        """Condition the Flamingo layers on the media locations before forward()

        ``input_ids`` is read from the keyword arguments if present, otherwise
        from the first positional argument.

        Raises:
            ValueError: If ``init_flamingo`` has not been called.
        """
        # getattr with a default: the flag does not exist at all until
        # init_flamingo() runs, and a plain attribute read would raise
        # AttributeError instead of the intended ValueError.
        if not getattr(self, "initialized_flamingo", False):
            raise ValueError("Flamingo layers are not initialized. Please call `init_flamingo` first.")

        input_ids = kwargs["input_ids"] if "input_ids" in kwargs else input[0]
        media_locations = input_ids == self.media_token_id
        attend_previous = (random.random() < 0.5) if self.use_media_placement_augmentation else False

        # Go through the configured attribute path so this agrees with the
        # layers that init_flamingo() wrapped, whatever the host model calls them.
        for layer in self._get_decoder_layers():
            layer.condition_media_locations(media_locations)
            layer.condition_attend_previous(attend_previous)

        return super().forward(*input, **kwargs)  # Call the other parent's forward method

    def is_conditioned(self) -> bool:
        """Check whether all decoder layers are already conditioned."""
        return all(l.is_conditioned() for l in self._get_decoder_layers())

    def clear_conditioned_layers(self):
        """Reset visual input, media locations and attend_previous on every decoder layer."""
        for layer in self._get_decoder_layers():
            layer.condition_vis_x(None)
            layer.condition_media_locations(None)
            layer.condition_attend_previous(None)