# Hugging Face Spaces app (scraped page header showed status: "Runtime error").
import sys

# Make the vendored stable_diffusion checkout importable before the
# project-local imports below (train_esd, convertModels, StableDiffuser).
sys.path.insert(0, 'stable_diffusion')

import gradio as gr
from train_esd import train_esd
from convertModels import convert_ldm_unet_checkpoint, create_unet_diffusers_config
from omegaconf import OmegaConf
from StableDiffuser import StableDiffuser
from diffusers import UNet2DConditionModel

# LDM-format Stable Diffusion v1.4 checkpoint and its inference config.
ckpt_path = "stable_diffusion/models/ldm/sd-v1-4-full-ema.ckpt"
config_path = "stable_diffusion/configs/stable-diffusion/v1-inference.yaml"
# diffusers-format config consumed during checkpoint conversion.
diffusers_config_path = "stable_diffusion/config.json"
class Demo:
    """Gradio demo for erasing a concept from Stable Diffusion (ESD).

    The left column collects training hyper-parameters and fine-tunes a
    copy of the model to erase the given prompt/concept; the right column
    generates the same prompt with the original and the edited model so
    the two outputs can be compared side by side.
    """

    def __init__(self) -> None:
        # Build the widget tree inside a Blocks context, then serve it.
        with gr.Blocks() as demo:
            self.layout()
        # Queue long-running train/inference jobs; up to 10 concurrent.
        demo.queue(concurrency_count=10).launch()

    def disable(self):
        # Updates for [train_button, infr_button]: grey both out while a
        # training run is in flight (wired to train_button in layout()).
        return [gr.update(interactive=False), gr.update(interactive=False)]

    def layout(self):
        """Declare all widgets and wire their event handlers."""
        with gr.Row():
            # --- training controls ---
            with gr.Column() as training_column:
                self.prompt_input = gr.Text(
                    placeholder="Enter prompt...",
                    label="Prompt",
                    info="Prompt corresponding to concept to erase"
                )
                self.train_method_input = gr.Dropdown(
                    choices=['noxattn', 'selfattn', 'xattn', 'full'],
                    value='xattn',
                    label='Train Method',
                    info='Method of training'
                )
                self.neg_guidance_input = gr.Number(
                    value=1,
                    label="Negative Guidance",
                    info='Guidance of negative training used to train'
                )
                self.iterations_input = gr.Number(
                    value=1000,
                    precision=0,  # integer-valued field
                    label="Iterations",
                    info='iterations used to train'
                )
                self.lr_input = gr.Number(
                    value=1e-5,
                    label="Learning Rate",
                    info='Learning rate used to train'
                )
                self.train_button = gr.Button(
                    value="Train",
                )
            # --- inference / comparison controls ---
            with gr.Column() as inference_column:
                with gr.Row():
                    self.prompt_input_infr = gr.Text(
                        placeholder="Enter prompt...",
                        label="Prompt",
                        info="Prompt corresponding to concept to erase"
                    )
                with gr.Row():
                    self.image_new = gr.Image(
                        label="New Image",
                        interactive=False
                    )
                    self.image_orig = gr.Image(
                        label="Orig Image",
                        interactive=False
                    )
                with gr.Row():
                    # Disabled until a model has been trained; train()
                    # re-enables it on completion.
                    self.infr_button = gr.Button(
                        value="Generate",
                        interactive=False
                    )
        self.infr_button.click(self.inference, inputs=[
            self.prompt_input_infr,
        ],
            outputs=[
                self.image_new,
                self.image_orig
            ]
        )
        # Two handlers on the same click: the first instantly greys out
        # both buttons; the second runs the (long) training and re-enables
        # them via its return value.
        self.train_button.click(self.disable,
            outputs=[self.train_button, self.infr_button]
        )
        self.train_button.click(self.train, inputs=[
            self.prompt_input,
            self.train_method_input,
            self.neg_guidance_input,
            self.iterations_input,
            self.lr_input
        ],
            outputs=[self.train_button, self.infr_button]
        )

    def train(self, prompt, train_method, neg_guidance, iterations, lr, pbar=gr.Progress(track_tqdm=True)):
        """Fine-tune the model to erase `prompt`, then prime inference.

        Returns updates that re-enable [train_button, infr_button].
        `pbar` mirrors tqdm progress from train_esd into the Gradio UI.
        """
        # NOTE(review): the positional 3 is presumably start_guidance in
        # train_esd's signature — confirm against train_esd. Both model
        # copies are placed on CUDA devices.
        model_orig, model_edited = train_esd(prompt,
                                             train_method,
                                             3,
                                             neg_guidance,
                                             iterations,
                                             lr,
                                             config_path,
                                             ckpt_path,
                                             diffusers_config_path,
                                             ['cuda', 'cuda']
                                             )
        # Convert both LDM-format UNet state dicts to the diffusers layout
        # so they can be loaded into a diffusers UNet2DConditionModel.
        original_config = OmegaConf.load(config_path)
        original_config["model"]["params"]["unet_config"]["params"]["in_channels"] = 4
        unet_config = create_unet_diffusers_config(original_config, image_size=512)
        model_edited_sd = convert_ldm_unet_checkpoint(model_edited.state_dict(), unet_config)
        model_orig_sd = convert_ldm_unet_checkpoint(model_orig.state_dict(), unet_config)
        self.init_inference(model_edited_sd, model_orig_sd, unet_config)
        return [gr.update(interactive=True), gr.update(interactive=True)]

    def init_inference(self, model_edited_sd, model_orig_sd, unet_config):
        """Store both converted state dicts and build a reusable diffuser."""
        self.model_edited_sd = model_edited_sd
        self.model_orig_sd = model_orig_sd
        # 42 is presumably a fixed RNG seed for reproducible generation —
        # confirm against StableDiffuser's constructor.
        self.diffuser = StableDiffuser(42)
        self.diffuser.unet = UNet2DConditionModel(**unet_config)
        self.diffuser.to('cuda')

    def inference(self, prompt):
        """Generate `prompt` with the original and the edited UNet.

        Returns (edited_image, orig_image), matching the
        [image_new, image_orig] outputs wired in layout().
        """
        # Original model first; reseed=True so both runs presumably share
        # the same noise seed and the images are directly comparable.
        self.diffuser.unet.load_state_dict(self.model_orig_sd)
        images = self.diffuser(
            prompt,
            n_steps=50,
            reseed=True
        )
        orig_image = images[0][0]
        # ... then the concept-erased model under identical sampling.
        self.diffuser.unet.load_state_dict(self.model_edited_sd)
        images = self.diffuser(
            prompt,
            n_steps=50,
            reseed=True
        )
        edited_image = images[0][0]
        return edited_image, orig_image
| demo = Demo() | |