Add an assert about the attention_mask dtype.
Browse filesThe attention mask in AMPLIFY should be additive, as a consequence, masked positions should have -inf values and unmasked positions should have 0 (or any constant value, but 0 is standard).
- amplify.py +4 -0
amplify.py
CHANGED
@@ -246,6 +246,10 @@ class AMPLIFY(AMPLIFYPreTrainedModel):
|
|
246 |
|
247 |
# Expand and repeat: (Batch, Length) -> (Batch, Heads, Length, Length)
|
248 |
if attention_mask is not None and not torch.all(attention_mask == 0):
|
|
|
|
|
|
|
|
|
249 |
attention_mask = (
|
250 |
attention_mask.unsqueeze(1)
|
251 |
.unsqueeze(1)
|
|
|
246 |
|
247 |
# Expand and repeat: (Batch, Length) -> (Batch, Heads, Length, Length)
|
248 |
if attention_mask is not None and not torch.all(attention_mask == 0):
|
249 |
+
assert attention_mask.dtype != torch.bool and 1.0 not in attention_mask, (
|
250 |
+
"AMPLIFY expects an additive attention_mask.\n"
|
251 |
+
"Modify the output of the tokenizer with attention_mask = torch.where(attention_mask, float(0.0), float("-inf"))"
|
252 |
+
)
|
253 |
attention_mask = (
|
254 |
attention_mask.unsqueeze(1)
|
255 |
.unsqueeze(1)
|