davidhd committed on
Commit 3fe920c · verified · 1 Parent(s): f8b02da

Add an assert about the attention_mask dtype.


The attention mask in AMPLIFY should be additive; as a consequence, masked positions should hold -inf and unmasked positions should hold 0 (or any constant value, though 0 is standard).
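As a minimal sketch of what this means in practice (the tensor values here are illustrative, not taken from the repository): a typical tokenizer emits a boolean or 0/1 padding mask, which can be converted to the additive form with `torch.where`. Adding the result to the raw attention scores drives masked positions to probability 0 after softmax.

```python
import torch

# Hypothetical boolean mask as produced by a typical tokenizer:
# True = real token, False = padding.
bool_mask = torch.tensor([[True, True, True, False, False]])

# Convert to the additive form AMPLIFY expects:
# 0.0 where attention is allowed, -inf where it is masked.
additive_mask = torch.where(bool_mask, float(0.0), float("-inf"))

# Adding the mask to raw attention scores before softmax zeroes
# out the masked positions.
scores = torch.zeros(1, 5)  # placeholder attention scores
probs = torch.softmax(scores + additive_mask, dim=-1)
print(probs)  # masked positions receive probability 0
```

Passing the raw boolean (or 0/1) mask instead would be silently wrong, since adding 1 to a score does not exclude a position; this is exactly the mistake the new assert catches.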

Files changed (1)
  1. amplify.py +4 -0
amplify.py CHANGED
@@ -246,6 +246,10 @@ class AMPLIFY(AMPLIFYPreTrainedModel):
 
         # Expand and repeat: (Batch, Length) -> (Batch, Heads, Length, Length)
         if attention_mask is not None and not torch.all(attention_mask == 0):
+            assert attention_mask.dtype != torch.bool and 1.0 not in attention_mask, (
+                "AMPLIFY expects an additive attention_mask.\n"
+                "Modify the output of the tokenizer with attention_mask = torch.where(attention_mask, float(0.0), float('-inf'))"
+            )
             attention_mask = (
                 attention_mask.unsqueeze(1)
                 .unsqueeze(1)