Update app.py
Browse files
    	
        app.py
    CHANGED
    
    | @@ -23,10 +23,11 @@ text_encoder = CLIPTextModel.from_pretrained(model_name, subfolder="text_encoder | |
| 23 | 
             
            unet = UNet2DConditionModel.from_pretrained(model_name, subfolder="unet", torch_dtype=dtype).to(device)
         | 
| 24 | 
             
            scheduler = PNDMScheduler.from_pretrained(model_name, subfolder="scheduler")
         | 
| 25 |  | 
| 26 | 
            -
            def load_examples_from_directory(sample_output_dir=" | 
| 27 | 
             
                """
         | 
| 28 | 
            -
                Load example data from the  | 
| 29 | 
             
                Assumes each image has a corresponding .json file with metadata.
         | 
|  | |
| 30 | 
             
                """
         | 
| 31 | 
             
                examples = []
         | 
| 32 | 
             
                json_files = glob.glob(os.path.join(sample_output_dir, "*.json"))
         | 
| @@ -43,14 +44,15 @@ def load_examples_from_directory(sample_output_dir="sample_output"): | |
| 43 | 
             
                                metadata["width"],
         | 
| 44 | 
             
                                metadata["num_inference_steps"],
         | 
| 45 | 
             
                                metadata["guidance_scale"],
         | 
| 46 | 
            -
                                metadata["seed"]
         | 
|  | |
| 47 | 
             
                            ])
         | 
| 48 | 
             
                    except Exception as e:
         | 
| 49 | 
             
                        print(f"Error loading {json_file}: {e}")
         | 
| 50 |  | 
| 51 | 
             
                if not examples:
         | 
| 52 | 
             
                    examples = [
         | 
| 53 | 
            -
                        ["a serene landscape in Ghibli style", 64, 64, 50, 3.5, 42]
         | 
| 54 | 
             
                    ]
         | 
| 55 | 
             
                return examples
         | 
| 56 |  | 
| @@ -122,25 +124,60 @@ def generate_image(prompt, height, width, num_inference_steps, guidance_scale, s | |
| 122 |  | 
| 123 | 
             
                return pil_image, f"Image generated successfully! Seed used: {seed}"
         | 
| 124 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 125 | 
             
            def on_example_select(selected_row):
         | 
| 126 | 
             
                """
         | 
| 127 | 
             
                Handle selection of a row in the examples Dataframe.
         | 
| 128 | 
            -
                 | 
| 129 | 
             
                """
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 130 | 
             
                return (
         | 
| 131 | 
            -
                     | 
| 132 | 
            -
                    selected_row[1],  # height
         | 
| 133 | 
            -
                    selected_row[2],  # width
         | 
| 134 | 
            -
                    selected_row[3],  # num_inference_steps
         | 
| 135 | 
            -
                    selected_row[4],  # guidance_scale
         | 
| 136 | 
            -
                    selected_row[5],  # seed
         | 
| 137 | 
            -
                    False            # random_seed
         | 
| 138 | 
             
                )
         | 
| 139 |  | 
| 140 | 
             
            # Gradio interface
         | 
| 141 | 
             
            with gr.Blocks() as demo:
         | 
| 142 | 
             
                gr.Markdown("# Ghibli-Style Image Generator")
         | 
| 143 | 
            -
                gr.Markdown("Generate images in Ghibli style using a fine-tuned Stable Diffusion model. Enter a prompt or select an example below to  | 
| 144 | 
             
                gr.Markdown("**Note:** For CPU inference, execution time is long (e.g., for 64x64 resolution with 50 inference steps, time is approximately 1800 seconds).")
         | 
| 145 |  | 
| 146 | 
             
                with gr.Row():
         | 
| @@ -158,10 +195,10 @@ with gr.Blocks() as demo: | |
| 158 | 
             
                        output_text = gr.Textbox(label="Status")
         | 
| 159 |  | 
| 160 | 
             
                gr.Markdown("### Example Prompts")
         | 
| 161 | 
            -
                examples_data = load_examples_from_directory(" | 
| 162 | 
             
                examples = gr.Dataframe(
         | 
| 163 | 
             
                    value=examples_data,
         | 
| 164 | 
            -
                    headers=["Prompt", "Height", "Width", "Inference Steps", "Guidance Scale", "Seed"],
         | 
| 165 | 
             
                    label="Examples",
         | 
| 166 | 
             
                    interactive=True
         | 
| 167 | 
             
                )
         | 
| @@ -170,11 +207,7 @@ with gr.Blocks() as demo: | |
| 170 | 
             
                examples.select(
         | 
| 171 | 
             
                    fn=on_example_select,
         | 
| 172 | 
             
                    inputs=examples,
         | 
| 173 | 
            -
                    outputs=[prompt, height, width, num_inference_steps, guidance_scale, seed, random_seed]
         | 
| 174 | 
            -
                ).then(
         | 
| 175 | 
            -
                    fn=generate_image,
         | 
| 176 | 
            -
                    inputs=[prompt, height, width, num_inference_steps, guidance_scale, seed, random_seed],
         | 
| 177 | 
            -
                    outputs=[output_image, output_text]
         | 
| 178 | 
             
                )
         | 
| 179 |  | 
| 180 | 
             
                generate_btn.click(
         | 
|  | |
| 23 | 
             
            unet = UNet2DConditionModel.from_pretrained(model_name, subfolder="unet", torch_dtype=dtype).to(device)
         | 
| 24 | 
             
            scheduler = PNDMScheduler.from_pretrained(model_name, subfolder="scheduler")
         | 
| 25 |  | 
| 26 | 
            +
            def load_examples_from_directory(sample_output_dir="assets/example"):
         | 
| 27 | 
             
                """
         | 
| 28 | 
            +
                Load example data from the assets/example directory.
         | 
| 29 | 
             
                Assumes each image has a corresponding .json file with metadata.
         | 
| 30 | 
            +
                Returns a list of [prompt, height, width, num_inference_steps, guidance_scale, seed, json_filename].
         | 
| 31 | 
             
                """
         | 
| 32 | 
             
                examples = []
         | 
| 33 | 
             
                json_files = glob.glob(os.path.join(sample_output_dir, "*.json"))
         | 
|  | |
| 44 | 
             
                                metadata["width"],
         | 
| 45 | 
             
                                metadata["num_inference_steps"],
         | 
| 46 | 
             
                                metadata["guidance_scale"],
         | 
| 47 | 
            +
                                metadata["seed"],
         | 
| 48 | 
            +
                                os.path.basename(json_file)  # Store the JSON filename for image lookup
         | 
| 49 | 
             
                            ])
         | 
| 50 | 
             
                    except Exception as e:
         | 
| 51 | 
             
                        print(f"Error loading {json_file}: {e}")
         | 
| 52 |  | 
| 53 | 
             
                if not examples:
         | 
| 54 | 
             
                    examples = [
         | 
| 55 | 
            +
                        ["a serene landscape in Ghibli style", 64, 64, 50, 3.5, 42, None]
         | 
| 56 | 
             
                    ]
         | 
| 57 | 
             
                return examples
         | 
| 58 |  | 
|  | |
| 124 |  | 
| 125 | 
             
                return pil_image, f"Image generated successfully! Seed used: {seed}"
         | 
| 126 |  | 
| 127 | 
            +
            def load_existing_image(json_filename, sample_output_dir="assets/example"):
         | 
| 128 | 
            +
                """
         | 
| 129 | 
            +
                Load an existing image corresponding to the given JSON filename.
         | 
| 130 | 
            +
                Assumes the image has the same base name as the JSON file (e.g., example1.json -> example1.png).
         | 
| 131 | 
            +
                """
         | 
| 132 | 
            +
                if not json_filename:
         | 
| 133 | 
            +
                    return None, "No associated image found."
         | 
| 134 | 
            +
                
         | 
| 135 | 
            +
                # Get the base name without extension
         | 
| 136 | 
            +
                base_name = os.path.splitext(json_filename)[0]
         | 
| 137 | 
            +
                # Look for common image extensions
         | 
| 138 | 
            +
                image_extensions = ["*.png", "*.jpg", "*.jpeg"]
         | 
| 139 | 
            +
                image_path = None
         | 
| 140 | 
            +
                
         | 
| 141 | 
            +
                for ext in image_extensions:
         | 
| 142 | 
            +
                    candidates = glob.glob(os.path.join(sample_output_dir, f"{base_name}{ext}"))
         | 
| 143 | 
            +
                    if candidates:
         | 
| 144 | 
            +
                        image_path = candidates[0]
         | 
| 145 | 
            +
                        break
         | 
| 146 | 
            +
                
         | 
| 147 | 
            +
                if not image_path:
         | 
| 148 | 
            +
                    return None, f"No image found for {json_filename} in {sample_output_dir}."
         | 
| 149 | 
            +
                
         | 
| 150 | 
            +
                try:
         | 
| 151 | 
            +
                    pil_image = Image.open(image_path)
         | 
| 152 | 
            +
                    return pil_image, f"Loaded existing image for {json_filename}."
         | 
| 153 | 
            +
                except Exception as e:
         | 
| 154 | 
            +
                    return None, f"Error loading image {image_path}: {e}"
         | 
| 155 | 
            +
             | 
| 156 | 
             
            def on_example_select(selected_row):
         | 
| 157 | 
             
                """
         | 
| 158 | 
             
                Handle selection of a row in the examples Dataframe.
         | 
| 159 | 
            +
                Loads the existing image and updates input fields.
         | 
| 160 | 
             
                """
         | 
| 161 | 
            +
                prompt = selected_row[0]
         | 
| 162 | 
            +
                height = selected_row[1]
         | 
| 163 | 
            +
                width = selected_row[2]
         | 
| 164 | 
            +
                num_inference_steps = selected_row[3]
         | 
| 165 | 
            +
                guidance_scale = selected_row[4]
         | 
| 166 | 
            +
                seed = selected_row[5]
         | 
| 167 | 
            +
                json_filename = selected_row[6]  # JSON filename for image lookup
         | 
| 168 | 
            +
                random_seed = False
         | 
| 169 | 
            +
             | 
| 170 | 
            +
                # Load the existing image
         | 
| 171 | 
            +
                image, status = load_existing_image(json_filename)
         | 
| 172 | 
            +
                
         | 
| 173 | 
             
                return (
         | 
| 174 | 
            +
                    prompt, height, width, num_inference_steps, guidance_scale, seed, random_seed, image, status
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 175 | 
             
                )
         | 
| 176 |  | 
| 177 | 
             
            # Gradio interface
         | 
| 178 | 
             
            with gr.Blocks() as demo:
         | 
| 179 | 
             
                gr.Markdown("# Ghibli-Style Image Generator")
         | 
| 180 | 
            +
                gr.Markdown("Generate images in Ghibli style using a fine-tuned Stable Diffusion model. Enter a prompt or select an example below to load a pre-generated image.")
         | 
| 181 | 
             
                gr.Markdown("**Note:** For CPU inference, execution time is long (e.g., for 64x64 resolution with 50 inference steps, time is approximately 1800 seconds).")
         | 
| 182 |  | 
| 183 | 
             
                with gr.Row():
         | 
|  | |
| 195 | 
             
                        output_text = gr.Textbox(label="Status")
         | 
| 196 |  | 
| 197 | 
             
                gr.Markdown("### Example Prompts")
         | 
| 198 | 
            +
                examples_data = load_examples_from_directory("assets/example")
         | 
| 199 | 
             
                examples = gr.Dataframe(
         | 
| 200 | 
             
                    value=examples_data,
         | 
| 201 | 
            +
                    headers=["Prompt", "Height", "Width", "Inference Steps", "Guidance Scale", "Seed", "JSON File"],
         | 
| 202 | 
             
                    label="Examples",
         | 
| 203 | 
             
                    interactive=True
         | 
| 204 | 
             
                )
         | 
|  | |
| 207 | 
             
                examples.select(
         | 
| 208 | 
             
                    fn=on_example_select,
         | 
| 209 | 
             
                    inputs=examples,
         | 
| 210 | 
            +
                    outputs=[prompt, height, width, num_inference_steps, guidance_scale, seed, random_seed, output_image, output_text]
         | 
|  | |
|  | |
|  | |
|  | |
| 211 | 
             
                )
         | 
| 212 |  | 
| 213 | 
             
                generate_btn.click(
         | 
