helloworld-S commited on
Commit
1bb30b7
·
verified ·
1 Parent(s): 6be12ae

Update src/flux/pipeline_tools.py

Browse files
Files changed (1) hide show
  1. src/flux/pipeline_tools.py +4 -3
src/flux/pipeline_tools.py CHANGED
@@ -510,7 +510,8 @@ class CustomFluxPipeline:
510
  ckpt_root_condition=None,
511
  torch_dtype=torch.bfloat16,
512
  ):
513
- model_path = os.getenv("FLUX_MODEL_PATH", "black-forest-labs/FLUX.1-dev")
 
514
  print("[CustomFluxPipeline] Loading FLUX Pipeline")
515
  self.pipe = FluxPipeline.from_pretrained(model_path, torch_dtype=torch_dtype).to(device)
516
  self.pipe.enable_sequential_cpu_offload()
@@ -518,8 +519,8 @@ class CustomFluxPipeline:
518
  self.config = config
519
  self.device = device
520
  self.dtype = torch_dtype
521
- if config["model"].get("dit_quant", "None") != "None":
522
- quantization(self.pipe, config["model"]["dit_quant"])
523
 
524
  self.modulation_adapters = []
525
  self.pipe.modulation_adapters = []
 
510
  ckpt_root_condition=None,
511
  torch_dtype=torch.bfloat16,
512
  ):
513
+
514
+ model_path = os.getenv("FLUX_MODEL_PATH", "diffusers/FLUX.1-dev-torchao-int8" if config["model"].get("dit_quant", "None")=="int8-quanto" else "black-forest-labs/FLUX.1-dev")
515
  print("[CustomFluxPipeline] Loading FLUX Pipeline")
516
  self.pipe = FluxPipeline.from_pretrained(model_path, torch_dtype=torch_dtype).to(device)
517
  self.pipe.enable_sequential_cpu_offload()
 
519
  self.config = config
520
  self.device = device
521
  self.dtype = torch_dtype
522
+ # if config["model"].get("dit_quant", "None") != "None":
523
+ # quantization(self.pipe, config["model"]["dit_quant"])
524
 
525
  self.modulation_adapters = []
526
  self.pipe.modulation_adapters = []