Spaces:
Running
on
Zero
Running
on
Zero
Update pipelines/flux_pipeline/transformer.py
Browse files
pipelines/flux_pipeline/transformer.py
CHANGED
|
@@ -41,19 +41,6 @@ from diffusers.utils.torch_utils import maybe_allow_in_graph
|
|
| 41 |
|
| 42 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 43 |
|
| 44 |
-
def log_scale_masking(value, min_value=1, max_value=10):
|
| 45 |
-
# Convert the value into a positive domain for the logarithmic function
|
| 46 |
-
normalized_value = 1*value
|
| 47 |
-
|
| 48 |
-
# Apply logarithmic scaling
|
| 49 |
-
# log_scaled_value = 1-np.exp(-normalized_value)
|
| 50 |
-
log_scaled_value = 2.0* math.log(normalized_value+1, 2) / math.log(2, 2) # np.log1p(x) = log(1 + x)
|
| 51 |
-
# print(log_scaled_value)
|
| 52 |
-
|
| 53 |
-
# Rescale to original range
|
| 54 |
-
scaled_value = log_scaled_value * (max_value - min_value) + min_value
|
| 55 |
-
|
| 56 |
-
return min(max_value, int(scaled_value))
|
| 57 |
|
| 58 |
class FluxAttnProcessor2_0:
|
| 59 |
"""Attention processor used typically in processing the SD3-like self-attention projections."""
|
|
@@ -137,7 +124,7 @@ class FluxAttnProcessor2_0:
|
|
| 137 |
if neg_mode:
|
| 138 |
res = int(math.sqrt((end_of_hidden_states-(text_seq if encoder_hidden_states is None else 0)) // num))
|
| 139 |
hw = res*res
|
| 140 |
-
mask_ = torch.
|
| 141 |
for i in range(num):
|
| 142 |
mask_[:, :, i*res:(i+1)*res, :, i*res:(i+1)*res] = 1
|
| 143 |
mask_ = rearrange(mask_, "b h w h1 w1 -> b (h w) (h1 w1)")
|
|
|
|
| 41 |
|
| 42 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
class FluxAttnProcessor2_0:
|
| 46 |
"""Attention processor used typically in processing the SD3-like self-attention projections."""
|
|
|
|
| 124 |
if neg_mode:
|
| 125 |
res = int(math.sqrt((end_of_hidden_states-(text_seq if encoder_hidden_states is None else 0)) // num))
|
| 126 |
hw = res*res
|
| 127 |
+
mask_ = torch.zeros(1, res, num*res, res, num*res).to(query.device)
|
| 128 |
for i in range(num):
|
| 129 |
mask_[:, :, i*res:(i+1)*res, :, i*res:(i+1)*res] = 1
|
| 130 |
mask_ = rearrange(mask_, "b h w h1 w1 -> b (h w) (h1 w1)")
|