Keltezaa committed (verified)
Commit ac2913f · Parent: 0f915f4

Update app.py

Files changed (1)
  1. app.py +15 -54
app.py CHANGED
@@ -113,45 +113,21 @@ taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).
 good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
 pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
 
-# Function to dynamically merge models
-def merge_models(base_model, enhancement_model, alpha=0.7):
-    for base_param, enhance_param in zip(base_model.parameters(), enhancement_model.parameters()):
-        base_param.data = alpha * base_param.data + (1 - alpha) * enhance_param.data
-    return base_model
 
 # Gradio interface function
-def process_image(enable_enhancement, weight_slider):
-    # Load enhancement model if enabled
-    if enable_enhancement:
-        enhancement_model_path = "xey/sldr_flux_nsfw_v2-studio"
-        try:
-            enhancement_model = AutoencoderKL.from_pretrained(enhancement_model_path, torch_dtype=dtype).to(device)
-            # Merge with the base VAE using the weight from the slider
-            merged_vae = merge_models(good_vae, enhancement_model, alpha=weight_slider)
-        except Exception as e:
-            return f"Failed to load or merge enhancement model: {e}"
-    else:
-        merged_vae = good_vae  # Use the base VAE if no enhancement is enabled
-
-    # Create the image pipeline with the updated VAE
-    pipe_i2i = AutoPipelineForImage2Image.from_pretrained(
-        base_model,
-        vae=merged_vae,
-        transformer=pipe.transformer,
-        text_encoder=pipe.text_encoder,
-        tokenizer=pipe.tokenizer,
-        text_encoder_2=pipe.text_encoder_2,
-        tokenizer_2=pipe.tokenizer_2,
-        torch_dtype=dtype
-    )
-    generated_image = pipe_i2i(prompt="A generated image")
-
-    if isinstance(generated_image, Image.Image):
-        return generated_image  # Gradio can handle PIL Image objects
-    elif isinstance(generated_image, str):  # If it returns a file path, return that
-        return generated_image
-    else:
-        return None
+taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
+good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
+pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
+pipe_i2i = AutoPipelineForImage2Image.from_pretrained(
+    base_model,
+    vae=good_vae,
+    transformer=pipe.transformer,
+    text_encoder=pipe.text_encoder,
+    tokenizer=pipe.tokenizer,
+    text_encoder_2=pipe.text_encoder_2,
+    tokenizer_2=pipe.tokenizer_2,
+    torch_dtype=dtype
+)
 
 MAX_SEED = 2**32 - 1
 
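For context, the deleted merge_models helper implemented weight-space interpolation: a parameter-wise weighted average of two models. A minimal sketch of that technique, assuming both models share an identical architecture (the name lerp_weights is hypothetical):

```python
import torch

@torch.no_grad()
def lerp_weights(model_a, model_b, alpha=0.7):
    # Blend model_b into model_a in place: alpha * a + (1 - alpha) * b.
    # zip() silently truncates if the architectures differ, which is one
    # hazard of the deleted helper; a production merge should check shapes.
    for p_a, p_b in zip(model_a.parameters(), model_b.parameters()):
        p_a.copy_(alpha * p_a + (1 - alpha) * p_b)
    return model_a
```

Note that the deleted process_image passed the module-level good_vae as the first argument, so every merge mutated the shared VAE in place and repeated calls compounded. The replacement builds pipe_i2i once at import time from the unmodified good_vae; passing pipe's transformer, text encoders, and tokenizers as keyword arguments reuses the already-loaded modules rather than loading them twice (recent diffusers releases also offer AutoPipelineForImage2Image.from_pipe(pipe) for the same kind of component sharing).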
@@ -652,22 +628,7 @@ with gr.Blocks(css=css, delete_cache=(60, 60)) as app:
         with gr.Row():
             randomize_seed = gr.Checkbox(True, label="Randomize seed")
             seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
-        # Gradio UI Elements for Enhancement Model and Weight Slider
-        with gr.Row():
-            with gr.Column():
-                enable_enhancement_checkbox = gr.Checkbox(
-                    label="Enable Enhancement Model",
-                    value=False,
-                    elem_id="enable_enhancement_checkbox"
-                )
-                enhancement_weight_slider = gr.Slider(
-                    label="Weight for Enhancement Model",
-                    minimum=0.0,
-                    maximum=1.0,
-                    step=0.05,
-                    value=0.75,  # Default weight
-                    elem_id="enhancement_weight_slider"
-                )
+        )
 
     gallery.select(
         update_selection,
@@ -702,7 +663,7 @@ with gr.Blocks(css=css, delete_cache=(60, 60)) as app:
     gr.on(
         triggers=[generate_button.click, prompt.submit],
         fn=run_lora,
-        inputs=[prompt, input_image, image_strength, cfg_scale, steps, selected_indices, enable_enhancement_checkbox, enhancement_weight_slider, lora_scale_1, lora_scale_2, randomize_seed, seed, width, height, loras_state],
+        inputs=[prompt, input_image, image_strength, cfg_scale, steps, selected_indices, lora_scale_1, lora_scale_2, randomize_seed, seed, width, height, loras_state],
         outputs=[result, seed, progress_bar]
     ).then(
         fn=lambda x, history: update_history(x, history),
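The inputs list in gr.on is positional: Gradio passes each component's current value to run_lora in the order listed, which is why removing the two enhancement widgets in the previous hunk also requires dropping their entries here. A minimal, self-contained sketch of that positional contract (hypothetical handler and components, unrelated to this app):

```python
import gradio as gr

def greet(name, excited):
    # Parameters line up positionally with the `inputs` list below:
    # name <- the Textbox value, excited <- the Checkbox value.
    return f"Hello, {name}{'!' if excited else '.'}"

with gr.Blocks() as demo:
    name = gr.Textbox(label="Name")
    excited = gr.Checkbox(label="Excited?", value=False)
    out = gr.Textbox(label="Greeting")
    gr.Button("Greet").click(fn=greet, inputs=[name, excited], outputs=out)

# demo.launch()
```

Removing a component from the UI without also removing it from the matching inputs list (or vice versa) leaves the handler's parameters out of step with the values Gradio passes, so the two changes in this commit have to land together.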
 