alexnasa commited on
Commit
e7ad2ae
·
verified ·
1 Parent(s): 5f7fd6a

Update src/flux/transformer.py

Browse files
Files changed (1) hide show
  1. src/flux/transformer.py +4 -2
src/flux/transformer.py CHANGED
@@ -22,7 +22,6 @@ from diffusers.models.transformers.transformer_flux import (
22
  FluxTransformer2DModel,
23
  Transformer2DModelOutput,
24
  USE_PEFT_BACKEND,
25
- is_torch_version,
26
  scale_lora_layers,
27
  unscale_lora_layers,
28
  logger,
@@ -34,7 +33,10 @@ import torch
34
  import torch.nn as nn
35
  import torch.nn.functional as F
36
 
37
-
 
 
 
38
  def prepare_params(
39
  hidden_states: torch.Tensor,
40
  encoder_hidden_states: torch.Tensor = None,
 
22
  FluxTransformer2DModel,
23
  Transformer2DModelOutput,
24
  USE_PEFT_BACKEND,
 
25
  scale_lora_layers,
26
  unscale_lora_layers,
27
  logger,
 
33
  import torch.nn as nn
34
  import torch.nn.functional as F
35
 
36
+ def is_torch_version(spec: str) -> bool:
37
+ # e.g. spec = ">=1.12.0"
38
+ return version.parse(torch.__version__) in version.SpecifierSet(spec)
39
+
40
  def prepare_params(
41
  hidden_states: torch.Tensor,
42
  encoder_hidden_states: torch.Tensor = None,