Spaces:
Runtime error
Runtime error
add partition helpers
Browse files- dalle_mini/partitions.py +69 -0
dalle_mini/partitions.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
|
| 3 |
+
from flax.core.frozen_dict import freeze
|
| 4 |
+
from flax.traverse_util import flatten_dict, unflatten_dict
|
| 5 |
+
from jax.experimental import PartitionSpec as P
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# utils adapted from https://github.com/google-research/google-research/blob/master/flax_models/t5x/partitions.py
|
| 9 |
+
# Sentinels
|
| 10 |
+
_unmatched = object()  # placeholder value for a flattened key that no partition rule has matched
|
| 11 |
+
|
| 12 |
+
# For specifying empty leaf dict `{}`
|
| 13 |
+
empty_dict = object()  # NOTE(review): not referenced by any function visible in this file - confirm external use
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _match(qs, ks):
    """Return True if the regexes in qs match any window of strings in tuple ks.

    Args:
        qs: Sequence of regex pattern strings. The i-th pattern is tested
            against the i-th string of each candidate window.
        ks: Tuple of strings to search, such as a key produced by
            ``flatten_dict``.

    Returns:
        True if some run of ``len(qs)`` consecutive entries of ``ks`` is
        matched element-wise, and in full, by the patterns in ``qs``;
        otherwise False. An empty ``qs`` never matches.
    """
    # Compile once up front; fullmatch() below anchors every pattern at both
    # ends. Appending "$" to the pattern text is not equivalent: it binds only
    # to the last branch of an un-grouped alternation ("a|b$") and it also
    # accepts a trailing newline.
    patterns = tuple(re.compile(q) for q in qs)
    if not patterns:
        return False
    width = len(patterns)
    for start in range(len(ks) - width + 1):
        window = ks[start:start + width]
        if all(p.fullmatch(k) for p, k in zip(patterns, window)):
            return True
    return False
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def _replacement_rules(rules):
    """Return a ``replace(key, val)`` function driven by ``rules``.

    ``rules`` is an iterable of ``(pattern_tuple, replacement)`` pairs. The
    returned function yields the replacement of the first rule whose
    patterns match ``key`` (as decided by ``_match``), or ``val`` unchanged
    when no rule matches.
    """

    def replace(key, val):
        # Lazy generator: rules are tried in order and evaluation stops at
        # the first hit, so earlier rules take precedence.
        hits = (spec for patterns, spec in rules if _match(patterns, key))
        return next(hits, val)

    return replace
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _get_partition_rules():
    """Build the ordered list of ``(key patterns, partition spec)`` rules.

    Each entry pairs a tuple of regex strings with either a ``P`` spec over
    the "mp" axis name or ``None``.
    """
    qkv_proj = "(q_proj|k_proj|v_proj)"
    norm_param = "(bias|scale)"
    rules = []

    # embeddings
    for embed in ("embed_positions", "embed_tokens"):
        rules.append(((embed, "embedding"), P("mp", None)))

    # self-attention first, then enc-dec attention
    for attn in ("self_attn", "encoder_attn"):
        rules.append(((attn, qkv_proj, "kernel"), P(None, "mp")))
        rules.append(((attn, "out_proj", "kernel"), P("mp", None)))

    # FFN
    rules.append((("fc1", "kernel"), P(None, "mp")))
    rules.append((("fc2", "kernel"), P("mp", None)))

    # layer norms
    for norm in (
        "layernorm_embedding",
        "self_attn_layer_norm",
        "encoder_attn_layer_norm",
        "final_layer_norm",
    ):
        rules.append(((norm, norm_param), None))

    rules.append((("lm_head", "kernel"), P(None, "mp")))
    return rules
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def set_partitions(in_dict):
    """Replace every leaf of a nested dict with its partition spec.

    Args:
        in_dict: Nested dict accepted by ``flatten_dict``; only its key
            structure is used, the leaf values are ignored.

    Returns:
        A frozen nested dict with the same key structure as ``in_dict``,
        where each leaf is the value of the first rule from
        ``_get_partition_rules()`` that matches the leaf's flattened key.

    Raises:
        AssertionError: If any flattened key matches no rule. The offending
            keys are printed before raising.
    """
    replace = _replacement_rules(_get_partition_rules())
    result = {k: replace(k, _unmatched) for k in flatten_dict(in_dict)}

    # Compare by identity: the sentinel is a unique object, and ``==`` would
    # dispatch to the __eq__ of whatever spec object a rule returned.
    missing = [k for k, v in result.items() if v is _unmatched]
    if missing:
        for k in missing:
            print(k)
        # Raised explicitly (same exception type and message as before) so
        # the check is not stripped when Python runs with -O.
        raise AssertionError("Incomplete partition spec.")
    return freeze(unflatten_dict(result))
|