modified: gradio_app.py (+60 -31)
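
This commit moves the Flux model loading out of load_target_model() and into the @spaces.GPU-decorated infer() handler: the old loading block in load_target_model() is commented out and replaced with an early return, and an equivalent block now runs at the top of infer(), with the load device switched from "cpu" to "cuda".
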
@@ -94,33 +94,35 @@ def load_target_model(selected_model):
     AE_PATH = download_file(ae_repo_id, ae_file)
     LORA_WEIGHTS_PATH = download_file(lora_repo, lora_file)
 
-    logger.info("Loading models...")
-    try:
-        if model is None is None or clip_l is None or t5xxl is None or ae is None:
-            _, model = flux_utils.load_flow_model(
-                BASE_FLUX_CHECKPOINT, torch.float8_e4m3fn, "cpu", disable_mmap=False
-            )
-            clip_l = flux_utils.load_clip_l(CLIP_L_PATH, torch.bfloat16, "cpu", disable_mmap=False)
-            clip_l.eval()
-            t5xxl = flux_utils.load_t5xxl(T5XXL_PATH, torch.bfloat16, "cpu", disable_mmap=False)
-            t5xxl.eval()
-            ae = flux_utils.load_ae(AE_PATH, torch.bfloat16, "cpu", disable_mmap=False)
-
-        # Load LoRA weights
-        multiplier = 1.0
-        weights_sd = load_file(LORA_WEIGHTS_PATH)
-        lora_model, _ = lora_flux.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True)
-        lora_model.apply_to([clip_l, t5xxl], model)
-        info = lora_model.load_state_dict(weights_sd, strict=True)
-        logger.info(f"Loaded LoRA weights from {LORA_WEIGHTS_PATH}: {info}")
-        lora_model.eval()
-
-        logger.info("Models loaded successfully.")
-        return "Models loaded successfully. Using Recraft: {}".format(selected_model)
-
-    except Exception as e:
-        logger.error(f"Error loading models: {e}")
-        return f"Error loading models: {e}"
+    return "Models loaded successfully. Using Recraft: {}".format(selected_model)
+
+    # logger.info("Loading models...")
+    # try:
+    #     if model is None is None or clip_l is None or t5xxl is None or ae is None:
+    #         _, model = flux_utils.load_flow_model(
+    #             BASE_FLUX_CHECKPOINT, torch.float8_e4m3fn, "cpu", disable_mmap=False
+    #         )
+    #         clip_l = flux_utils.load_clip_l(CLIP_L_PATH, torch.bfloat16, "cpu", disable_mmap=False)
+    #         clip_l.eval()
+    #         t5xxl = flux_utils.load_t5xxl(T5XXL_PATH, torch.bfloat16, "cpu", disable_mmap=False)
+    #         t5xxl.eval()
+    #         ae = flux_utils.load_ae(AE_PATH, torch.bfloat16, "cpu", disable_mmap=False)
+
+    #     # Load LoRA weights
+    #     multiplier = 1.0
+    #     weights_sd = load_file(LORA_WEIGHTS_PATH)
+    #     lora_model, _ = lora_flux.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True)
+    #     lora_model.apply_to([clip_l, t5xxl], model)
+    #     info = lora_model.load_state_dict(weights_sd, strict=True)
+    #     logger.info(f"Loaded LoRA weights from {LORA_WEIGHTS_PATH}: {info}")
+    #     lora_model.eval()
+
+    #     logger.info("Models loaded successfully.")
+    #     return "Models loaded successfully. Using Recraft: {}".format(selected_model)
+
+    # except Exception as e:
+    #     logger.error(f"Error loading models: {e}")
+    #     return f"Error loading models: {e}"
 
 # Image pre-processing (resize and padding)
 class ResizeWithPadding:
@@ -154,10 +156,37 @@ class ResizeWithPadding:
 # The function to generate image from a prompt and conditional image
 @spaces.GPU(duration=180)
 def infer(prompt, sample_image, recraft_model, seed=0):
-    global model, clip_l, t5xxl, ae, lora_model
-    if model is None or lora_model is None or clip_l is None or t5xxl is None or ae is None:
-        logger.error("Models not loaded. Please load the models first.")
-        return None
+    # global model, clip_l, t5xxl, ae, lora_model
+    # if model is None or lora_model is None or clip_l is None or t5xxl is None or ae is None:
+    #     logger.error("Models not loaded. Please load the models first.")
+    #     return None
+    logger.info("Loading models...")
+    try:
+        if model is None is None or clip_l is None or t5xxl is None or ae is None:
+            _, model = flux_utils.load_flow_model(
+                BASE_FLUX_CHECKPOINT, torch.float8_e4m3fn, "cuda", disable_mmap=False
+            )
+            clip_l = flux_utils.load_clip_l(CLIP_L_PATH, torch.bfloat16, "cuda", disable_mmap=False)
+            clip_l.eval()
+            t5xxl = flux_utils.load_t5xxl(T5XXL_PATH, torch.bfloat16, "cuda", disable_mmap=False)
+            t5xxl.eval()
+            ae = flux_utils.load_ae(AE_PATH, torch.bfloat16, "cuda", disable_mmap=False)
+
+        # Load LoRA weights
+        multiplier = 1.0
+        weights_sd = load_file(LORA_WEIGHTS_PATH)
+        lora_model, _ = lora_flux.create_network_from_weights(multiplier, None, ae, [clip_l, t5xxl], model, weights_sd, True)
+        lora_model.apply_to([clip_l, t5xxl], model)
+        info = lora_model.load_state_dict(weights_sd, strict=True)
+        logger.info(f"Loaded LoRA weights from {LORA_WEIGHTS_PATH}: {info}")
+        lora_model.eval()
+
+        logger.info("Models loaded successfully.")
+        # return "Models loaded successfully. Using Recraft: {}".format(selected_model)
+
+    except Exception as e:
+        logger.error(f"Error loading models: {e}")
+        return f"Error loading models: {e}"
 
     model_path = model_paths[recraft_model]
     frame_num = model_path['Frame']
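
For context, a minimal sketch of the lazy-load pattern this commit moves toward on ZeroGPU Spaces, where CUDA is only available inside a @spaces.GPU-decorated call. build_model() below is a hypothetical stand-in for the flux_utils / lora_flux loaders, not part of this repo:

    import spaces
    import torch
    from torch import nn

    # Globals start empty so nothing touches CUDA at import time; on ZeroGPU
    # the GPU only exists for the duration of a @spaces.GPU-decorated call.
    model = None

    def build_model():
        # Hypothetical stand-in for flux_utils.load_flow_model and friends.
        return nn.Linear(4, 4)

    @spaces.GPU(duration=180)
    def infer(x):
        global model
        if model is None:
            # Lazy one-time load directly onto the GPU, mirroring the commit's
            # move of the loading block into infer() with device="cuda".
            model = build_model().to("cuda").eval()
        with torch.no_grad():
            return model(torch.tensor(x, device="cuda")).tolist()

Loading inside the handler trades slower first-call latency for compatibility with ZeroGPU's on-demand GPU allocation, which is presumably why the loading block was relocated here rather than kept in load_target_model().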