Spaces:
Running
on
Zero
Running
on
Zero
Update src/flux/pipeline_tools.py
Browse files- src/flux/pipeline_tools.py +11 -3
src/flux/pipeline_tools.py
CHANGED
@@ -510,10 +510,18 @@ class CustomFluxPipeline:
|
|
510 |
ckpt_root_condition=None,
|
511 |
torch_dtype=torch.bfloat16,
|
512 |
):
|
513 |
-
|
514 |
-
model_path = os.getenv("FLUX_MODEL_PATH", "diffusers/FLUX.1-dev-torchao-int8" if config["model"].get("dit_quant", "None")=="int8-quanto" else "black-forest-labs/FLUX.1-dev")
|
515 |
print("[CustomFluxPipeline] Loading FLUX Pipeline")
|
516 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
517 |
self.pipe.enable_sequential_cpu_offload()
|
518 |
|
519 |
self.config = config
|
|
|
510 |
ckpt_root_condition=None,
|
511 |
torch_dtype=torch.bfloat16,
|
512 |
):
|
513 |
+
|
|
|
514 |
print("[CustomFluxPipeline] Loading FLUX Pipeline")
|
515 |
+
if config["model"].get("dit_quant", "None")=="int8-quanto":
|
516 |
+
self.pipe = FluxPipeline.from_pretrained("diffusers/FLUX.1-dev-torchao-int8",
|
517 |
+
torch_dtype=torch_dtype,
|
518 |
+
use_safetensors=False,
|
519 |
+
device_map="balanced").to(device)
|
520 |
+
|
521 |
+
else:
|
522 |
+
self.pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev",
|
523 |
+
torch_dtype=torch_dtype).to(device)
|
524 |
+
|
525 |
self.pipe.enable_sequential_cpu_offload()
|
526 |
|
527 |
self.config = config
|