helloworld-S commited on
Commit
a2d4626
·
verified ·
1 Parent(s): 2df105d

Update src/flux/pipeline_tools.py

Browse files
Files changed (1) hide show
  1. src/flux/pipeline_tools.py +11 -3
src/flux/pipeline_tools.py CHANGED
@@ -510,10 +510,18 @@ class CustomFluxPipeline:
510
  ckpt_root_condition=None,
511
  torch_dtype=torch.bfloat16,
512
  ):
513
-
514
- model_path = os.getenv("FLUX_MODEL_PATH", "diffusers/FLUX.1-dev-torchao-int8" if config["model"].get("dit_quant", "None")=="int8-quanto" else "black-forest-labs/FLUX.1-dev")
515
  print("[CustomFluxPipeline] Loading FLUX Pipeline")
516
- self.pipe = FluxPipeline.from_pretrained(model_path, torch_dtype=torch_dtype).to(device)
 
 
 
 
 
 
 
 
 
517
  self.pipe.enable_sequential_cpu_offload()
518
 
519
  self.config = config
 
510
  ckpt_root_condition=None,
511
  torch_dtype=torch.bfloat16,
512
  ):
513
+
 
514
  print("[CustomFluxPipeline] Loading FLUX Pipeline")
515
+ if config["model"].get("dit_quant", "None")=="int8-quanto":
516
+ self.pipe = FluxPipeline.from_pretrained("diffusers/FLUX.1-dev-torchao-int8",
517
+ torch_dtype=torch_dtype,
518
+ use_safetensors=False,
519
+ device_map="balanced").to(device)
520
+
521
+ else:
522
+ self.pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev",
523
+ torch_dtype=torch_dtype).to(device)
524
+
525
  self.pipe.enable_sequential_cpu_offload()
526
 
527
  self.config = config