jiuface committed
Commit 7ae7fc2 · verified · 1 Parent(s): 4e9b062

Update app.py

Files changed (1): app.py (+13 -10)
app.py CHANGED
@@ -24,7 +24,8 @@ import random
 import string
 from diffusers import FluxPipeline
 from huggingface_hub import hf_hub_download
-
+from diffusers.quantizers import PipelineQuantizationConfig
+from diffusers import (FluxPriorReduxPipeline, FluxInpaintPipeline, FluxFillPipeline, FluxKontextPipeline, FluxPipeline)
 
 # Login Hugging Face Hub
 HF_TOKEN = os.environ.get("HF_TOKEN")
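Note on this hunk: the new parenthesized import re-imports FluxPipeline, which the unchanged line just above already provides, and only FluxKontextPipeline is actually instantiated in this commit; harmless, but it could be trimmed. A deduplicated sketch (class names exactly as in the diff):

from diffusers import (
    FluxPriorReduxPipeline,
    FluxInpaintPipeline,
    FluxFillPipeline,
    FluxKontextPipeline,  # the only one of these used below
)
from diffusers.quantizers import PipelineQuantizationConfig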
@@ -35,15 +36,16 @@ import diffusers
 dtype = torch.bfloat16
 device = "cuda:0"
 
-print(device)
-#base_model = "black-forest-labs/FLUX.1-dev"
 base_model = "black-forest-labs/FLUX.1-Krea-dev"
-# load pipe
 
-txt2img_pipe = FluxPipeline.from_pretrained(base_model, torch_dtype=dtype)
+pipeline_quant_config = PipelineQuantizationConfig(
+    quant_backend="bitsandbytes_4bit",
+    quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
+    components_to_quantize=["transformer", "text_encoder_2"],
+)
 
+txt2img_pipe = FluxKontextPipeline.from_pretrained(base_model, quantization_config=pipeline_quant_config, torch_dtype=dtype)
 txt2img_pipe = txt2img_pipe.to(device)
-#txt2img_pipe.__class__.load_lora_into_transformer = classmethod(load_lora_into_transformer)
 
 MAX_SEED = 2**32 - 1
 
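This is the substantive change: the pipeline is now built with pipeline-level bitsandbytes 4-bit NF4 quantization of the two largest components (the FLUX transformer and the T5 encoder, text_encoder_2) while compute stays in bfloat16, and the class switches from FluxPipeline to FluxKontextPipeline. Worth double-checking: base_model still points at the text-to-image FLUX.1-Krea-dev checkpoint rather than a Kontext (image-editing) one. A quick post-load sanity check, assuming nothing beyond this hunk (get_memory_footprint is available on diffusers and transformers models; the comments state expectations, not measurements):

# after txt2img_pipe is constructed with pipeline_quant_config:
print(txt2img_pipe.transformer.get_memory_footprint())     # 4-bit NF4 weights, roughly a quarter of the bf16 size
print(txt2img_pipe.text_encoder_2.get_memory_footprint())  # T5 encoder, also quantized
print(txt2img_pipe.text_encoder.get_memory_footprint())    # CLIP encoder, left in bf16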
@@ -157,14 +159,15 @@ def run_lora(prompt, image_url, lora_strings_json, image_strength, cfg_scale, s
         adapter_weights.append(adapter_weight)
         if lora_repo and weights and adapter_name:
             try:
-                #txt2img_pipe.to(device)
                 txt2img_pipe.load_lora_weights(hf_hub_download(lora_repo, weights), adapter_name=lora_name)
             except:
                 print("load lora error")
+
     # set lora weights
-    #if len(lora_names) > 0:
-        #txt2img_pipe.to(device)
-        #txt2img_pipe.set_adapters(lora_names, adapter_weights=adapter_weights)
+    if len(lora_names) > 0:
+        txt2img_pipe.set_adapters(lora_names, adapter_weights=adapter_weights)
+        txt2img_pipe.fuse_lora(adapter_names=lora_names)
+        txt2img_pipe.enable_vae_slicing()
 
     # Generate image
     error_message = ""
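The previously commented-out adapter wiring is now live, with two additions: fuse_lora(), which bakes the selected adapters into the base weights for faster inference, and enable_vae_slicing(), which decodes batched latents one image at a time to cut peak VRAM. One caveat, hedged since only this hunk is visible: if run_lora serves repeated requests with different LoRA sets, fused adapters persist until explicitly undone. A minimal teardown sketch using the standard diffusers LoRA-mixin API:

# after the image has been generated with the fused adapters:
txt2img_pipe.unfuse_lora()          # restore the original base weights
txt2img_pipe.unload_lora_weights()  # drop adapter modules before the next request

Two smaller observations from the context lines: the bare except: silently swallows any LoRA load failure (except Exception as e: print(e) would at least surface it), and the guard tests adapter_name while load_lora_weights receives adapter_name=lora_name; if those are distinct variables, one is probably a typo.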
 