Update src/pipeline.py
Browse files- src/pipeline.py +4 -4
src/pipeline.py
CHANGED
@@ -38,10 +38,8 @@ torch.backends.cudnn.enabled = True
|
|
38 |
|
39 |
# globals
|
40 |
Pipeline = None
|
41 |
-
ckpt_id = "
|
42 |
-
ckpt_revision = "
|
43 |
-
TinyVAE = "madebyollin/taef1"
|
44 |
-
TinyVAE_REV = "2d552378e58c9c94201075708d7de4e1163b2689"
|
45 |
|
46 |
def empty_cache():
|
47 |
gc.collect()
|
@@ -50,6 +48,8 @@ def empty_cache():
|
|
50 |
torch.cuda.reset_peak_memory_stats()
|
51 |
|
52 |
def load_pipeline() -> Pipeline:
|
|
|
|
|
53 |
text_encoder_2 = T5EncoderModel.from_pretrained("manbeast3b/flux.1-schnell-full1", revision = "cb1b599b0d712b9aab2c4df3ad27b050a27ec146", subfolder="text_encoder_2",torch_dtype=torch.bfloat16)
|
54 |
path = os.path.join(HF_HUB_CACHE, "models--manbeast3b--flux.1-schnell-full1/snapshots/cb1b599b0d712b9aab2c4df3ad27b050a27ec146/transformer")
|
55 |
transformer = FluxTransformer2DModel.from_pretrained(path, torch_dtype=torch.bfloat16, use_safetensors=False)
|
|
|
38 |
|
39 |
# globals
|
40 |
Pipeline = None
|
41 |
+
ckpt_id = "manbeast3b/Flux.1.schnell-quant2"
|
42 |
+
ckpt_revision = "44eb293715147878512da10bf3bc47cd14ec8c55"
|
|
|
|
|
43 |
|
44 |
def empty_cache():
|
45 |
gc.collect()
|
|
|
48 |
torch.cuda.reset_peak_memory_stats()
|
49 |
|
50 |
def load_pipeline() -> Pipeline:
|
51 |
+
vae = AutoencoderKL.from_pretrained(ckpt_id,revision=ckpt_revision, subfolder="vae", local_files_only=True, torch_dtype=torch.bfloat16,)
|
52 |
+
quantize_(vae, int8_weight_only())
|
53 |
text_encoder_2 = T5EncoderModel.from_pretrained("manbeast3b/flux.1-schnell-full1", revision = "cb1b599b0d712b9aab2c4df3ad27b050a27ec146", subfolder="text_encoder_2",torch_dtype=torch.bfloat16)
|
54 |
path = os.path.join(HF_HUB_CACHE, "models--manbeast3b--flux.1-schnell-full1/snapshots/cb1b599b0d712b9aab2c4df3ad27b050a27ec146/transformer")
|
55 |
transformer = FluxTransformer2DModel.from_pretrained(path, torch_dtype=torch.bfloat16, use_safetensors=False)
|