Update src/flux/transformer.py
Browse files- src/flux/transformer.py +4 -2
src/flux/transformer.py
CHANGED
@@ -22,7 +22,6 @@ from diffusers.models.transformers.transformer_flux import (
|
|
22 |
FluxTransformer2DModel,
|
23 |
Transformer2DModelOutput,
|
24 |
USE_PEFT_BACKEND,
|
25 |
-
is_torch_version,
|
26 |
scale_lora_layers,
|
27 |
unscale_lora_layers,
|
28 |
logger,
|
@@ -63,7 +62,10 @@ def prepare_params(
|
|
63 |
return_dict,
|
64 |
)
|
65 |
|
66 |
-
|
|
|
|
|
|
|
67 |
def tranformer_forward(
|
68 |
transformer: FluxTransformer2DModel,
|
69 |
condition_latents: torch.Tensor,
|
|
|
22 |
FluxTransformer2DModel,
|
23 |
Transformer2DModelOutput,
|
24 |
USE_PEFT_BACKEND,
|
|
|
25 |
scale_lora_layers,
|
26 |
unscale_lora_layers,
|
27 |
logger,
|
|
|
62 |
return_dict,
|
63 |
)
|
64 |
|
65 |
+
def is_torch_version(spec: str) -> bool:
|
66 |
+
# e.g. spec = ">=1.12.0"
|
67 |
+
return version.parse(torch.__version__) in version.SpecifierSet(spec)
|
68 |
+
|
69 |
def tranformer_forward(
|
70 |
transformer: FluxTransformer2DModel,
|
71 |
condition_latents: torch.Tensor,
|