"""Modified from https://github.com/mlfoundations/open_flamingo"""
import random
import torch.nn as nn
from .helpers import GatedCrossAttentionBlock
from .utils import getattr_recursive, setattr_recursive
class FlamingoLayer(nn.Module):
def __init__(self, gated_cross_attn_layer, decoder_layer):
super().__init__()
self.gated_cross_attn_layer = gated_cross_attn_layer
self.decoder_layer = decoder_layer
        self.vis_x = None
        self.media_locations = None
        self.attend_previous = None  # set later via condition_attend_previous()
        self.only_lang_x = False
def is_conditioned(self) -> bool:
"""Check whether the layer is conditioned."""
return self.vis_x is not None
    # The layer-conditioning pattern below is borrowed from the flamingo-mini implementation (https://github.com/dhansmair/flamingo-mini/).
def condition_vis_x(self, vis_x):
self.vis_x = vis_x
def condition_only_lang_x(self, only_lang_x=False):
self.only_lang_x = only_lang_x
def condition_media_locations(self, media_locations):
self.media_locations = media_locations
def condition_attend_previous(self, attend_previous):
self.attend_previous = attend_previous
def forward(
self,
lang_x,
attention_mask=None,
**decoder_layer_kwargs,
):
if self.gated_cross_attn_layer is None or self.only_lang_x:
return self.decoder_layer(lang_x, attention_mask=attention_mask, **decoder_layer_kwargs)
if self.vis_x is None:
raise ValueError("vis_x must be conditioned before forward pass")
if self.media_locations is None:
raise ValueError("media_locations must be conditioned before forward pass")
lang_x = self.gated_cross_attn_layer(
lang_x,
self.vis_x,
media_locations=self.media_locations,
attend_previous=self.attend_previous,
)
lang_x = self.decoder_layer(lang_x, attention_mask=attention_mask, **decoder_layer_kwargs)
return lang_x
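# Illustrative conditioning flow for a single FlamingoLayer (hypothetical
# shapes: vis_x is (B, T_img, n_tokens, D_vis) from the perceiver resampler,
# media_locations is a (B, T_txt) boolean mask over <image> token positions):
#   layer.condition_vis_x(vis_x)
#   layer.condition_media_locations(media_locations)
#   layer.condition_attend_previous(True)
#   hidden_states = layer(lang_x, attention_mask=attention_mask)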
class FlamingoLMMixin(nn.Module):
"""
Mixin to add cross-attention layers to a language model.
"""
def set_decoder_layers_attr_name(self, decoder_layers_attr_name):
self.decoder_layers_attr_name = decoder_layers_attr_name
def _get_decoder_layers(self):
return getattr_recursive(self, self.decoder_layers_attr_name)
def _set_decoder_layers(self, value):
setattr_recursive(self, self.decoder_layers_attr_name, value)
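    # `getattr_recursive` / `setattr_recursive` (imported from .utils) resolve a
    # dotted attribute path such as "model.decoder.layers". A minimal sketch of
    # the getter, assuming plain dotted-path lookup:
    #   def getattr_recursive(obj, att):
    #       for part in att.split("."):
    #           obj = getattr(obj, part)
    #       return obj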
def init_flamingo(
self,
media_token_id,
vis_hidden_size,
cross_attn_every_n_layers,
use_media_placement_augmentation,
):
"""
Initialize Flamingo by adding a new gated cross attn to the decoder. Store the media token id for computing the media locations.
"""
self.gated_cross_attn_layers = nn.ModuleList(
[
GatedCrossAttentionBlock(dim=self.config.hidden_size, dim_visual=vis_hidden_size)
if (layer_idx + 1) % cross_attn_every_n_layers == 0
else None
for layer_idx, _ in enumerate(self._get_decoder_layers())
]
)
self._set_decoder_layers(
nn.ModuleList(
[
FlamingoLayer(gated_cross_attn_layer, decoder_layer)
for gated_cross_attn_layer, decoder_layer in zip(
self.gated_cross_attn_layers, self._get_decoder_layers()
)
]
)
)
self.media_token_id = media_token_id
self.use_media_placement_augmentation = use_media_placement_augmentation
self.initialized_flamingo = True
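    # Illustrative initialization (hypothetical names and sizes; the
    # decoder-layers attribute path depends on the underlying LM):
    #   lang_encoder.set_decoder_layers_attr_name("model.layers")
    #   lang_encoder.init_flamingo(
    #       media_token_id=tokenizer.convert_tokens_to_ids("<image>"),
    #       vis_hidden_size=1024,
    #       cross_attn_every_n_layers=4,
    #       use_media_placement_augmentation=True,
    #   )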
    def forward(self, *inputs, **kwargs):
        """Condition the Flamingo layers on the media locations before forward()."""
        if not getattr(self, "initialized_flamingo", False):
            raise ValueError("Flamingo layers are not initialized. Please call `init_flamingo` first.")
        input_ids = kwargs["input_ids"] if "input_ids" in kwargs else inputs[0]
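        # Example: with media_token_id = 32000 and input_ids =
        # [[32000, 201, 88, 32000, 97]] (hypothetical ids), the mask below is
        # [[True, False, False, True, False]].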
media_locations = input_ids == self.media_token_id
attend_previous = (random.random() < 0.5) if self.use_media_placement_augmentation else False
        for layer in self._get_decoder_layers():
            layer.condition_media_locations(media_locations)
            layer.condition_attend_previous(attend_previous)
        return super().forward(*inputs, **kwargs)  # Call the other parent's forward method
def is_conditioned(self) -> bool:
"""Check whether all decoder layers are already conditioned."""
        return all(layer.is_conditioned() for layer in self._get_decoder_layers())
def clear_conditioned_layers(self):
for layer in self._get_decoder_layers():
layer.condition_vis_x(None)
layer.condition_media_locations(None)
layer.condition_attend_previous(None)
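# Typical lifecycle (a hedged sketch; the outer wrapper is not in this file):
# an outer Flamingo model calls condition_vis_x(...) on every layer with the
# resampled visual features, forward() above conditions the media locations on
# each call, and clear_conditioned_layers() is called after generation so
# cached visual features cannot leak into the next batch.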