Update app.py
Browse files
app.py
CHANGED
@@ -113,45 +113,21 @@ taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).
|
|
113 |
good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
|
114 |
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
|
115 |
|
116 |
-
# Function to dynamically merge models
|
117 |
-
def merge_models(base_model, enhancement_model, alpha=0.7):
|
118 |
-
for base_param, enhance_param in zip(base_model.parameters(), enhancement_model.parameters()):
|
119 |
-
base_param.data = alpha * base_param.data + (1 - alpha) * enhance_param.data
|
120 |
-
return base_model
|
121 |
|
122 |
# Gradio interface function
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
# Create the image pipeline with the updated VAE
|
137 |
-
pipe_i2i = AutoPipelineForImage2Image.from_pretrained(
|
138 |
-
base_model,
|
139 |
-
vae=merged_vae,
|
140 |
-
transformer=pipe.transformer,
|
141 |
-
text_encoder=pipe.text_encoder,
|
142 |
-
tokenizer=pipe.tokenizer,
|
143 |
-
text_encoder_2=pipe.text_encoder_2,
|
144 |
-
tokenizer_2=pipe.tokenizer_2,
|
145 |
-
torch_dtype=dtype
|
146 |
-
)
|
147 |
-
generated_image = pipe_i2i(prompt="A generated image")
|
148 |
-
|
149 |
-
if isinstance(generated_image, Image.Image):
|
150 |
-
return generated_image # Gradio can handle PIL Image objects
|
151 |
-
elif isinstance(generated_image, str): # If it returns a file path, return that
|
152 |
-
return generated_image
|
153 |
-
else:
|
154 |
-
return None
|
155 |
|
156 |
MAX_SEED = 2**32 - 1
|
157 |
|
@@ -652,22 +628,7 @@ with gr.Blocks(css=css, delete_cache=(60, 60)) as app:
|
|
652 |
with gr.Row():
|
653 |
randomize_seed = gr.Checkbox(True, label="Randomize seed")
|
654 |
seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
|
655 |
-
|
656 |
-
with gr.Row():
|
657 |
-
with gr.Column():
|
658 |
-
enable_enhancement_checkbox = gr.Checkbox(
|
659 |
-
label="Enable Enhancement Model",
|
660 |
-
value=False,
|
661 |
-
elem_id="enable_enhancement_checkbox"
|
662 |
-
)
|
663 |
-
enhancement_weight_slider = gr.Slider(
|
664 |
-
label="Weight for Enhancement Model",
|
665 |
-
minimum=0.0,
|
666 |
-
maximum=1.0,
|
667 |
-
step=0.05,
|
668 |
-
value=0.75, # Default weight
|
669 |
-
elem_id="enhancement_weight_slider"
|
670 |
-
)
|
671 |
|
672 |
gallery.select(
|
673 |
update_selection,
|
@@ -702,7 +663,7 @@ with gr.Blocks(css=css, delete_cache=(60, 60)) as app:
|
|
702 |
gr.on(
|
703 |
triggers=[generate_button.click, prompt.submit],
|
704 |
fn=run_lora,
|
705 |
-
inputs=[prompt, input_image, image_strength, cfg_scale, steps, selected_indices,
|
706 |
outputs=[result, seed, progress_bar]
|
707 |
).then(
|
708 |
fn=lambda x, history: update_history(x, history),
|
|
|
113 |
good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
|
114 |
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
|
115 |
|
|
|
|
|
|
|
|
|
|
|
116 |
|
117 |
# Gradio interface function
|
118 |
+
taef1 = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype).to(device)
|
119 |
+
good_vae = AutoencoderKL.from_pretrained(base_model, subfolder="vae", torch_dtype=dtype).to(device)
|
120 |
+
pipe = DiffusionPipeline.from_pretrained(base_model, torch_dtype=dtype, vae=taef1).to(device)
|
121 |
+
pipe_i2i = AutoPipelineForImage2Image.from_pretrained(
|
122 |
+
base_model,
|
123 |
+
vae=good_vae,
|
124 |
+
transformer=pipe.transformer,
|
125 |
+
text_encoder=pipe.text_encoder,
|
126 |
+
tokenizer=pipe.tokenizer,
|
127 |
+
text_encoder_2=pipe.text_encoder_2,
|
128 |
+
tokenizer_2=pipe.tokenizer_2,
|
129 |
+
torch_dtype=dtype
|
130 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
131 |
|
132 |
MAX_SEED = 2**32 - 1
|
133 |
|
|
|
628 |
with gr.Row():
|
629 |
randomize_seed = gr.Checkbox(True, label="Randomize seed")
|
630 |
seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
|
631 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
632 |
|
633 |
gallery.select(
|
634 |
update_selection,
|
|
|
663 |
gr.on(
|
664 |
triggers=[generate_button.click, prompt.submit],
|
665 |
fn=run_lora,
|
666 |
+
inputs=[prompt, input_image, image_strength, cfg_scale, steps, selected_indices, lora_scale_1, lora_scale_2, randomize_seed, seed, width, height, loras_state],
|
667 |
outputs=[result, seed, progress_bar]
|
668 |
).then(
|
669 |
fn=lambda x, history: update_history(x, history),
|