fengyutong committed
Commit 052fb73 · 1 Parent(s): e4df51f

fix pipeline load

Files changed (1):
app.py  +7 -18
app.py CHANGED
@@ -26,13 +26,13 @@ weight_dtype = torch.bfloat16
 args = OmegaConf.load('configs/omnitry_v1_unified.yaml')
 
 # init model
-transformer = FluxTransformer2DModel.from_pretrained('./FLUX.1-Fill-dev/transformer').requires_grad_(False).to(device, dtype=weight_dtype)
-vae = diffusers.AutoencoderKL.from_pretrained('./FLUX.1-Fill-dev/vae').requires_grad_(False).to(device, dtype=weight_dtype)
-text_encoder = transformers.CLIPTextModel.from_pretrained('./FLUX.1-Fill-dev/text_encoder').requires_grad_(False).to(device, dtype=weight_dtype)
-text_encoder_2 = transformers.T5EncoderModel.from_pretrained('./FLUX.1-Fill-dev/text_encoder_2').requires_grad_(False).to(device, dtype=weight_dtype)
-scheduler = diffusers.FlowMatchEulerDiscreteScheduler.from_pretrained('./FLUX.1-Fill-dev/scheduler')
-tokenizer = transformers.CLIPTokenizer.from_pretrained('./FLUX.1-Fill-dev/tokenizer')
-tokenizer_2 = transformers.T5TokenizerFast.from_pretrained('./FLUX.1-Fill-dev/tokenizer_2')
+transformer = FluxTransformer2DModel.from_pretrained('black-forest-labs/FLUX.1-Fill-dev', subfolder='transformer').requires_grad_(False).to(device, dtype=weight_dtype)
+pipeline = FluxFillPipeline.from_pretrained(
+    'black-forest-labs/FLUX.1-Fill-dev',
+    transformer=transformer,
+    torch_dtype=weight_dtype
+).to(device)
+
 
 # insert LoRA
 lora_config = LoraConfig(
@@ -81,17 +81,6 @@ for n, m in transformer.named_modules():
     if isinstance(m, peft.tuners.lora.layer.Linear):
         m.forward = create_hacked_forward(m)
 
-# init pipeline
-pipeline = FluxFillPipeline(
-    transformer=transformer.eval(),
-    scheduler=copy.deepcopy(scheduler),
-    vae=vae,
-    text_encoder=text_encoder,
-    text_encoder_2=text_encoder_2,
-    tokenizer=tokenizer,
-    tokenizer_2=tokenizer_2,
-)
-
 
 def seed_everything(seed=0):
     random.seed(seed)
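
For reference, the new loading path follows the standard diffusers pattern: components passed as keyword arguments to from_pretrained override the corresponding subfolders of the Hub repo, so the separately loaded (and later LoRA-patched) transformer ends up inside the pipeline while the VAE, text encoders, tokenizers and scheduler are pulled from black-forest-labs/FLUX.1-Fill-dev automatically. A minimal standalone sketch of that pattern; the repo id and dtype come from the diff, and the smoke-test call at the end is hypothetical, not part of the commit:

import torch
from PIL import Image
from diffusers import FluxFillPipeline, FluxTransformer2DModel

device = 'cuda'
weight_dtype = torch.bfloat16

# Load the transformer on its own (e.g. so LoRA adapters can be injected before use) ...
transformer = FluxTransformer2DModel.from_pretrained(
    'black-forest-labs/FLUX.1-Fill-dev', subfolder='transformer', torch_dtype=weight_dtype
).requires_grad_(False)

# ... then let from_pretrained assemble the remaining components around it.
# Keyword components override the ones stored in the Hub repo.
pipeline = FluxFillPipeline.from_pretrained(
    'black-forest-labs/FLUX.1-Fill-dev',
    transformer=transformer,
    torch_dtype=weight_dtype,
).to(device)

# Hypothetical smoke test: inpaint a blank image to confirm the pipeline is wired up.
image = Image.new('RGB', (768, 1024))
mask = Image.new('L', (768, 1024), 255)
result = pipeline(
    prompt='a person wearing the garment',
    image=image,
    mask_image=mask,
    height=1024,
    width=768,
    num_inference_steps=20,
    guidance_scale=30.0,
).images[0]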