Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	bugfix/Add default pipe
Browse files
    	
        model.py
    CHANGED
    
    | @@ -48,7 +48,7 @@ class Model: | |
| 48 | 
             
                    self.model_name = ""
         | 
| 49 |  | 
| 50 | 
             
                def set_model(self, model_type: ModelType, model_id: str, **kwargs):
         | 
| 51 | 
            -
                    if self.pipe is not None:
         | 
| 52 | 
             
                        del self.pipe
         | 
| 53 | 
             
                    torch.cuda.empty_cache()
         | 
| 54 | 
             
                    gc.collect()
         | 
| @@ -59,7 +59,7 @@ class Model: | |
| 59 | 
             
                    self.model_name = model_id
         | 
| 60 |  | 
| 61 | 
             
                def inference_chunk(self, frame_ids, **kwargs):
         | 
| 62 | 
            -
                    if self.pipe is None:
         | 
| 63 | 
             
                        return
         | 
| 64 |  | 
| 65 | 
             
                    prompt = np.array(kwargs.pop('prompt'))
         | 
| @@ -80,15 +80,14 @@ class Model: | |
| 80 | 
             
                                     **kwargs)
         | 
| 81 |  | 
| 82 | 
             
                def inference(self, split_to_chunks=False, chunk_size=8, **kwargs):
         | 
| 83 | 
            -
                    if self.pipe is None:
         | 
| 84 | 
             
                        return
         | 
| 85 | 
            -
             | 
| 86 | 
             
                    if "merging_ratio" in kwargs:
         | 
| 87 | 
             
                        merging_ratio = kwargs.pop("merging_ratio")
         | 
| 88 |  | 
| 89 | 
            -
                        if merging_ratio > 0:
         | 
| 90 | 
            -
             | 
| 91 | 
            -
                            tomesd.apply_patch(self.pipe, ratio=merging_ratio)
         | 
| 92 | 
             
                    seed = kwargs.pop('seed', 0)
         | 
| 93 | 
             
                    if seed < 0:
         | 
| 94 | 
             
                        seed = self.generator.seed()
         | 
| @@ -144,7 +143,7 @@ class Model: | |
| 144 | 
             
                                             resolution=512,
         | 
| 145 | 
             
                                             use_cf_attn=True,
         | 
| 146 | 
             
                                             save_path=None):
         | 
| 147 | 
            -
                    print(" | 
| 148 | 
             
                    video_path = gradio_utils.edge_path_to_video_path(video_path)
         | 
| 149 | 
             
                    if self.model_type != ModelType.ControlNetCanny:
         | 
| 150 | 
             
                        controlnet = ControlNetModel.from_pretrained(
         | 
| @@ -203,7 +202,7 @@ class Model: | |
| 203 | 
             
                                            resolution=512,
         | 
| 204 | 
             
                                            use_cf_attn=True,
         | 
| 205 | 
             
                                            save_path=None):
         | 
| 206 | 
            -
                    print(" | 
| 207 | 
             
                    video_path = gradio_utils.motion_to_video_path(video_path)
         | 
| 208 | 
             
                    if self.model_type != ModelType.ControlNetPose:
         | 
| 209 | 
             
                        controlnet = ControlNetModel.from_pretrained(
         | 
| @@ -268,7 +267,7 @@ class Model: | |
| 268 | 
             
                                                resolution=512,
         | 
| 269 | 
             
                                                use_cf_attn=True,
         | 
| 270 | 
             
                                                save_path=None):
         | 
| 271 | 
            -
                    print(" | 
| 272 | 
             
                    db_path = gradio_utils.get_model_from_db_selection(db_path)
         | 
| 273 | 
             
                    video_path = gradio_utils.get_video_from_canny_selection(video_path)
         | 
| 274 | 
             
                    # Load db and controlnet weights
         | 
| @@ -331,7 +330,7 @@ class Model: | |
| 331 | 
             
                                    merging_ratio=0.0,
         | 
| 332 | 
             
                                    use_cf_attn=True,
         | 
| 333 | 
             
                                    save_path=None,):
         | 
| 334 | 
            -
                    print(" | 
| 335 | 
             
                    if self.model_type != ModelType.Pix2Pix_Video:
         | 
| 336 | 
             
                        self.set_model(ModelType.Pix2Pix_Video,
         | 
| 337 | 
             
                                       model_id="timbrooks/instruct-pix2pix")
         | 
| @@ -375,7 +374,7 @@ class Model: | |
| 375 | 
             
                                       smooth_bg=False,
         | 
| 376 | 
             
                                       smooth_bg_strength=0.4,
         | 
| 377 | 
             
                                       path=None):
         | 
| 378 | 
            -
                    print(" | 
| 379 | 
             
                    if self.model_type != ModelType.Text2Video or model_name != self.model_name:
         | 
| 380 | 
             
                        print("Model update")
         | 
| 381 | 
             
                        unet = UNet2DConditionModel.from_pretrained(
         | 
|  | |
| 48 | 
             
                    self.model_name = ""
         | 
| 49 |  | 
| 50 | 
             
                def set_model(self, model_type: ModelType, model_id: str, **kwargs):
         | 
| 51 | 
            +
                    if hasattr(self, "pipe") and self.pipe is not None:
         | 
| 52 | 
             
                        del self.pipe
         | 
| 53 | 
             
                    torch.cuda.empty_cache()
         | 
| 54 | 
             
                    gc.collect()
         | 
|  | |
| 59 | 
             
                    self.model_name = model_id
         | 
| 60 |  | 
| 61 | 
             
                def inference_chunk(self, frame_ids, **kwargs):
         | 
| 62 | 
            +
                    if not hasattr(self, "pipe") or self.pipe is None:
         | 
| 63 | 
             
                        return
         | 
| 64 |  | 
| 65 | 
             
                    prompt = np.array(kwargs.pop('prompt'))
         | 
|  | |
| 80 | 
             
                                     **kwargs)
         | 
| 81 |  | 
| 82 | 
             
                def inference(self, split_to_chunks=False, chunk_size=8, **kwargs):
         | 
| 83 | 
            +
                    if not hasattr(self, "pipe") or self.pipe is None:
         | 
| 84 | 
             
                        return
         | 
| 85 | 
            +
             | 
| 86 | 
             
                    if "merging_ratio" in kwargs:
         | 
| 87 | 
             
                        merging_ratio = kwargs.pop("merging_ratio")
         | 
| 88 |  | 
| 89 | 
            +
                        # if merging_ratio > 0:
         | 
| 90 | 
            +
                        tomesd.apply_patch(self.pipe, ratio=merging_ratio)
         | 
|  | |
| 91 | 
             
                    seed = kwargs.pop('seed', 0)
         | 
| 92 | 
             
                    if seed < 0:
         | 
| 93 | 
             
                        seed = self.generator.seed()
         | 
|  | |
| 143 | 
             
                                             resolution=512,
         | 
| 144 | 
             
                                             use_cf_attn=True,
         | 
| 145 | 
             
                                             save_path=None):
         | 
| 146 | 
            +
                    print("Module Canny")
         | 
| 147 | 
             
                    video_path = gradio_utils.edge_path_to_video_path(video_path)
         | 
| 148 | 
             
                    if self.model_type != ModelType.ControlNetCanny:
         | 
| 149 | 
             
                        controlnet = ControlNetModel.from_pretrained(
         | 
|  | |
| 202 | 
             
                                            resolution=512,
         | 
| 203 | 
             
                                            use_cf_attn=True,
         | 
| 204 | 
             
                                            save_path=None):
         | 
| 205 | 
            +
                    print("Module Pose")
         | 
| 206 | 
             
                    video_path = gradio_utils.motion_to_video_path(video_path)
         | 
| 207 | 
             
                    if self.model_type != ModelType.ControlNetPose:
         | 
| 208 | 
             
                        controlnet = ControlNetModel.from_pretrained(
         | 
|  | |
| 267 | 
             
                                                resolution=512,
         | 
| 268 | 
             
                                                use_cf_attn=True,
         | 
| 269 | 
             
                                                save_path=None):
         | 
| 270 | 
            +
                    print("Module Canny_DB")
         | 
| 271 | 
             
                    db_path = gradio_utils.get_model_from_db_selection(db_path)
         | 
| 272 | 
             
                    video_path = gradio_utils.get_video_from_canny_selection(video_path)
         | 
| 273 | 
             
                    # Load db and controlnet weights
         | 
|  | |
| 330 | 
             
                                    merging_ratio=0.0,
         | 
| 331 | 
             
                                    use_cf_attn=True,
         | 
| 332 | 
             
                                    save_path=None,):
         | 
| 333 | 
            +
                    print("Module Pix2Pix")
         | 
| 334 | 
             
                    if self.model_type != ModelType.Pix2Pix_Video:
         | 
| 335 | 
             
                        self.set_model(ModelType.Pix2Pix_Video,
         | 
| 336 | 
             
                                       model_id="timbrooks/instruct-pix2pix")
         | 
|  | |
| 374 | 
             
                                       smooth_bg=False,
         | 
| 375 | 
             
                                       smooth_bg_strength=0.4,
         | 
| 376 | 
             
                                       path=None):
         | 
| 377 | 
            +
                    print("Module Text2Video")
         | 
| 378 | 
             
                    if self.model_type != ModelType.Text2Video or model_name != self.model_name:
         | 
| 379 | 
             
                        print("Model update")
         | 
| 380 | 
             
                        unet = UNet2DConditionModel.from_pretrained(
         | 
 
			

