Update src/flux/transformer.py
Browse files- src/flux/transformer.py +4 -2
src/flux/transformer.py
CHANGED
@@ -22,7 +22,6 @@ from diffusers.models.transformers.transformer_flux import (
|
|
22 |
FluxTransformer2DModel,
|
23 |
Transformer2DModelOutput,
|
24 |
USE_PEFT_BACKEND,
|
25 |
-
is_torch_version,
|
26 |
scale_lora_layers,
|
27 |
unscale_lora_layers,
|
28 |
logger,
|
@@ -34,7 +33,10 @@ import torch
|
|
34 |
import torch.nn as nn
|
35 |
import torch.nn.functional as F
|
36 |
|
37 |
-
|
|
|
|
|
|
|
38 |
def prepare_params(
|
39 |
hidden_states: torch.Tensor,
|
40 |
encoder_hidden_states: torch.Tensor = None,
|
|
|
22 |
FluxTransformer2DModel,
|
23 |
Transformer2DModelOutput,
|
24 |
USE_PEFT_BACKEND,
|
|
|
25 |
scale_lora_layers,
|
26 |
unscale_lora_layers,
|
27 |
logger,
|
|
|
33 |
import torch.nn as nn
|
34 |
import torch.nn.functional as F
|
35 |
|
36 |
+
def is_torch_version(spec: str) -> bool:
|
37 |
+
# e.g. spec = ">=1.12.0"
|
38 |
+
return version.parse(torch.__version__) in version.SpecifierSet(spec)
|
39 |
+
|
40 |
def prepare_params(
|
41 |
hidden_states: torch.Tensor,
|
42 |
encoder_hidden_states: torch.Tensor = None,
|