seawolf2357 commited on
Commit
71e9891
·
verified ·
1 Parent(s): a8c3250

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -13
app.py CHANGED
@@ -27,12 +27,13 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
27
  repo_id = "black-forest-labs/FLUX.1-dev"
28
  adapter_id = "seawolf2357/nsfw-detection" # Changed to Renoir model
29
 
30
- # Initialize pipeline with PEFT support
31
  print("Loading pipeline...")
32
- # Add 'use_peft=True' parameter to enable PEFT support
33
- pipeline = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16, use_peft=True)
34
  print("Loading LoRA weights...")
35
- pipeline.load_lora_weights(adapter_id)
 
36
  pipeline = pipeline.to(device)
37
 
38
  MAX_SEED = np.iinfo(np.int32).max
@@ -111,15 +112,30 @@ def inference(
111
  seed = random.randint(0, MAX_SEED)
112
  generator = torch.Generator(device=device).manual_seed(seed)
113
 
114
- image = pipeline(
115
- prompt=processed_prompt,
116
- guidance_scale=guidance_scale,
117
- num_inference_steps=num_inference_steps,
118
- width=width,
119
- height=height,
120
- generator=generator,
121
- joint_attention_kwargs={"scale": lora_scale},
122
- ).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
  # Save the generated image
125
  filepath = save_generated_image(image, processed_prompt)
 
27
  repo_id = "black-forest-labs/FLUX.1-dev"
28
  adapter_id = "seawolf2357/nsfw-detection" # Changed to Renoir model
29
 
30
+ # Initialize pipeline
31
  print("Loading pipeline...")
32
+ # Use DiffusionPipeline instead of FluxPipeline to ensure proper LoRA compatibility
33
+ pipeline = DiffusionPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16)
34
  print("Loading LoRA weights...")
35
+ # Add low_cpu_mem_usage=False to avoid PEFT assign=True incompatibility issue
36
+ pipeline.load_lora_weights(adapter_id, low_cpu_mem_usage=False)
37
  pipeline = pipeline.to(device)
38
 
39
  MAX_SEED = np.iinfo(np.int32).max
 
112
  seed = random.randint(0, MAX_SEED)
113
  generator = torch.Generator(device=device).manual_seed(seed)
114
 
115
+ # Use joint_attention_kwargs to control LoRA scale
116
+ # (FluxPipeline may use a different parameter name but attempt both)
117
+ try:
118
+ image = pipeline(
119
+ prompt=processed_prompt,
120
+ guidance_scale=guidance_scale,
121
+ num_inference_steps=num_inference_steps,
122
+ width=width,
123
+ height=height,
124
+ generator=generator,
125
+ joint_attention_kwargs={"scale": lora_scale},
126
+ ).images[0]
127
+ except Exception as e:
128
+ # If the above fails, try with cross_attention_kwargs which is more common
129
+ print(f"First attempt failed with: {e}, trying alternative method...")
130
+ image = pipeline(
131
+ prompt=processed_prompt,
132
+ guidance_scale=guidance_scale,
133
+ num_inference_steps=num_inference_steps,
134
+ width=width,
135
+ height=height,
136
+ generator=generator,
137
+ cross_attention_kwargs={"scale": lora_scale},
138
+ ).images[0]
139
 
140
  # Save the generated image
141
  filepath = save_generated_image(image, processed_prompt)