Spaces:
Running
Running
import random

import gradio as gr
import numpy as np

from diffusion_lens import get_images
# Upper bound for user-selectable and randomly drawn seeds: the largest
# signed 32-bit integer (2**31 - 1).
MAX_SEED = int(np.iinfo(np.int32).max)
| # Description | |
| title = r""" | |
| <h1 align="center">Diffusion Lens: Interpreting Text Encoders in Text-to-Image Pipelines</h1> | |
| """ | |
| description = r""" | |
| <b>A demo for the paper <a href='https://arxiv.org/abs/2403.05846' target='_blank'>Diffusion Lens: Interpreting Text Encoders in Text-to-Image Pipelines</a>.<br> | |
| """ | |
| article = r""" | |
| --- | |
| π **Citation** | |
| <br> | |
| If our work is helpful for your research or applications, please cite us via: | |
| ```bibtex | |
| @article{toker2024diffusion, | |
| title={Diffusion Lens: Interpreting Text Encoders in Text-to-Image Pipelines}, | |
| author={Toker, Michael and Orgad, Hadas and Ventura, Mor and Arad, Dana and Belinkov, Yonatan}, | |
| journal={arXiv preprint arXiv:2403.05846}, | |
| year={2024} | |
| } | |
| ``` | |
| π§ **Contact** | |
| <br> | |
| If you have any questions, please feel free to open an issue or directly reach us out at <b>[email protected]</b>. | |
| """ | |
# How many layers generate_images sweeps for each model, keyed by the
# display name the UI passes in.
model_num_of_layers = dict([
    ("Stable Diffusion 1.4", 12),
    ("Stable Diffusion 2.1", 22),
])
def generate_images(prompt, model, seed):
    """Stream one generated image per layer of *model* for *prompt*.

    This is a generator. After each layer is rendered, it yields the full
    list of ``(image, caption)`` pairs produced so far, so the gallery fills
    in progressively.

    Args:
        prompt: Text prompt forwarded to ``get_images``.
        model: Display name of the model; must be a key of
            ``model_num_of_layers``.
        seed: Random seed. ``-1`` means "draw a fresh seed uniformly from
            ``[0, MAX_SEED]``".

    Yields:
        list[tuple]: ``(images[0], "layer_<k>")`` pairs accumulated so far,
        where ``k = total_layers - skip_layers``.

    Raises:
        KeyError: if *model* is not a key of ``model_num_of_layers``.
    """
    # The slider may hand the value over as a float; normalize to an int
    # before comparing with the -1 sentinel and passing it on as a seed.
    seed = int(seed)
    if seed == -1:
        seed = random.randint(0, MAX_SEED)
    print('calling diffusion lens with model:', model, 'and seed:', seed)
    gr.Info('Generating images from intermediate layers..')
    max_num_of_layers = model_num_of_layers[model]
    all_images = []  # accumulated (image, caption) pairs, re-yielded each step
    # Iterate from the largest skip count down to 0, so captions go
    # layer_1 ... layer_<max>.
    for skip_layers in range(max_num_of_layers - 1, -1, -1):
        images = get_images(prompt, skip_layers=skip_layers, model=model, seed=seed)
        # Caption with the model's own layer count (not a fixed 12), so models
        # with more layers get labels layer_1..layer_<max> instead of negative
        # numbers. Assumes skip_layers counts final layers omitted by
        # get_images — TODO confirm against diffusion_lens.
        layer_number = max_num_of_layers - skip_layers
        all_images.append((images[0], f'layer_{layer_number}'))
        yield all_images
# Demo layout: header, input controls, output gallery, and the event that
# wires the "Generate Image" button to generate_images.
with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Column():
        prompt = gr.Textbox(
            label="Prompt",
            value="A photo of Steve Jobs",
        )
        # Offer exactly the models generate_images has layer counts for.
        model = gr.Radio(
            list(model_num_of_layers),
            value="Stable Diffusion 1.4",
            label="Model",
        )
        # -1 is the "random seed" sentinel understood by generate_images.
        seed = gr.Slider(
            minimum=-1,
            maximum=MAX_SEED,
            value=-1,
            step=1,
            label="Seed Value",
        )
        generate_button = gr.Button("Generate Image")
    with gr.Column():
        gallery = gr.Gallery(label="Generated Images", columns=4, rows=3, object_fit="contain", height="auto")
    # Wire inputs/outputs only after every component exists; referencing
    # `gallery` before it is created raises NameError at startup.
    inputs = [prompt, model, seed]
    outputs = [gallery]
    # Only the button triggers a run. Each run calls get_images once per
    # layer, so live triggers on prompt/seed/model edits are presumably too
    # costly.
    gr.on(
        triggers=[generate_button.click],
        fn=generate_images,
        inputs=inputs,
        outputs=outputs,
        show_progress="full",
        show_api=False,
        trigger_mode="always_last",
    )
    gr.Markdown(article)
# Configure the request queue with api_open=False, then start the app with
# show_api=False. Blocks.queue() returns the Blocks instance itself, so the
# two calls can be chained.
demo.queue(api_open=False).launch(show_api=False)