Update src/flux/pipeline_tools.py
src/flux/pipeline_tools.py CHANGED
@@ -510,7 +510,8 @@ class CustomFluxPipeline:
         ckpt_root_condition=None,
         torch_dtype=torch.bfloat16,
     ):
-
+
+        model_path = os.getenv("FLUX_MODEL_PATH", "diffusers/FLUX.1-dev-torchao-int8" if config["model"].get("dit_quant", "None")=="int8-quanto" else "black-forest-labs/FLUX.1-dev")
         print("[CustomFluxPipeline] Loading FLUX Pipeline")
         self.pipe = FluxPipeline.from_pretrained(model_path, torch_dtype=torch_dtype).to(device)
         self.pipe.enable_sequential_cpu_offload()
@@ -518,8 +519,8 @@ class CustomFluxPipeline:
         self.config = config
         self.device = device
         self.dtype = torch_dtype
-        if config["model"].get("dit_quant", "None") != "None":
-            quantization(self.pipe, config["model"]["dit_quant"])
+        # if config["model"].get("dit_quant", "None") != "None":
+        #     quantization(self.pipe, config["model"]["dit_quant"])
 
         self.modulation_adapters = []
         self.pipe.modulation_adapters = []
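The effect of this change is that the pipeline no longer quantizes the DiT on the fly via quantization(); instead it selects a pre-quantized torchao int8 checkpoint when the config asks for "int8-quanto", and the checkpoint can be overridden through the FLUX_MODEL_PATH environment variable. A minimal sketch of the new selection logic is shown below; the config layout ({"model": {"dit_quant": ...}}) is taken from the diff, while the commented CustomFluxPipeline call is hypothetical since its full signature is not visible here.

import os

# Optional override: when FLUX_MODEL_PATH is set, it takes priority over both defaults.
# os.environ["FLUX_MODEL_PATH"] = "/local/checkpoints/FLUX.1-dev"

config = {"model": {"dit_quant": "int8-quanto"}}  # assumed config layout, as implied by the diff

# Same fallback logic as the added line: pre-quantized torchao int8 checkpoint for
# "int8-quanto" configs, full-precision FLUX.1-dev otherwise.
model_path = os.getenv(
    "FLUX_MODEL_PATH",
    "diffusers/FLUX.1-dev-torchao-int8"
    if config["model"].get("dit_quant", "None") == "int8-quanto"
    else "black-forest-labs/FLUX.1-dev",
)
print(model_path)  # -> diffusers/FLUX.1-dev-torchao-int8

# pipeline = CustomFluxPipeline(config, device="cuda")  # hypothetical constructor call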