Spaces:
				
			
			
	
			
			
		Running
		
			on 
			
			Zero
	
	
	
			
			
	
	
	
	
		
		
		Running
		
			on 
			
			Zero
	Update ootd/inference_ootd.py
Browse files - ootd/inference_ootd.py +7 -7
    	
        ootd/inference_ootd.py
    CHANGED
    
    | @@ -33,7 +33,7 @@ MODEL_PATH = "./checkpoints/ootd" | |
| 33 | 
             
            class OOTDiffusion:
         | 
| 34 |  | 
| 35 | 
             
                def __init__(self, gpu_id):
         | 
| 36 | 
            -
                    self.gpu_id = 'cuda:' + str(gpu_id)
         | 
| 37 |  | 
| 38 | 
             
                    vae = AutoencoderKL.from_pretrained(
         | 
| 39 | 
             
                        VAE_PATH,
         | 
| @@ -64,12 +64,12 @@ class OOTDiffusion: | |
| 64 | 
             
                        use_safetensors=True,
         | 
| 65 | 
             
                        safety_checker=None,
         | 
| 66 | 
             
                        requires_safety_checker=False,
         | 
| 67 | 
            -
                    ).to(self.gpu_id) | 
| 68 |  | 
| 69 | 
             
                    self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
         | 
| 70 |  | 
| 71 | 
             
                    self.auto_processor = AutoProcessor.from_pretrained(VIT_PATH)
         | 
| 72 | 
            -
                    self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(VIT_PATH).to(self.gpu_id) | 
| 73 |  | 
| 74 | 
             
                    self.tokenizer = CLIPTokenizer.from_pretrained(
         | 
| 75 | 
             
                        MODEL_PATH,
         | 
| @@ -78,7 +78,7 @@ class OOTDiffusion: | |
| 78 | 
             
                    self.text_encoder = CLIPTextModel.from_pretrained(
         | 
| 79 | 
             
                        MODEL_PATH,
         | 
| 80 | 
             
                        subfolder="text_encoder",
         | 
| 81 | 
            -
                    ).to(self.gpu_id) | 
| 82 |  | 
| 83 |  | 
| 84 | 
             
                def tokenize_captions(self, captions, max_length):
         | 
| @@ -107,14 +107,14 @@ class OOTDiffusion: | |
| 107 | 
             
                    generator = torch.manual_seed(seed)
         | 
| 108 |  | 
| 109 | 
             
                    with torch.no_grad():
         | 
| 110 | 
            -
                        prompt_image = self.auto_processor(images=image_garm, return_tensors="pt").to(self.gpu_id) | 
| 111 | 
             
                        prompt_image = self.image_encoder(prompt_image.data['pixel_values']).image_embeds
         | 
| 112 | 
             
                        prompt_image = prompt_image.unsqueeze(1)
         | 
| 113 | 
             
                        if model_type == 'hd':
         | 
| 114 | 
            -
                            prompt_embeds = self.text_encoder(self.tokenize_captions([""], 2).to(self.gpu_id))[0] | 
| 115 | 
             
                            prompt_embeds[:, 1:] = prompt_image[:]
         | 
| 116 | 
             
                        elif model_type == 'dc':
         | 
| 117 | 
            -
                            prompt_embeds = self.text_encoder(self.tokenize_captions([category], 3).to(self.gpu_id))[0] | 
| 118 | 
             
                            prompt_embeds = torch.cat([prompt_embeds, prompt_image], dim=1)
         | 
| 119 | 
             
                        else:
         | 
| 120 | 
             
                            raise ValueError("model_type must be \'hd\' or \'dc\'!")
         | 
|  | |
| 33 | 
             
            class OOTDiffusion:
         | 
| 34 |  | 
| 35 | 
             
                def __init__(self, gpu_id):
         | 
| 36 | 
            +
                    # self.gpu_id = 'cuda:' + str(gpu_id)
         | 
| 37 |  | 
| 38 | 
             
                    vae = AutoencoderKL.from_pretrained(
         | 
| 39 | 
             
                        VAE_PATH,
         | 
|  | |
| 64 | 
             
                        use_safetensors=True,
         | 
| 65 | 
             
                        safety_checker=None,
         | 
| 66 | 
             
                        requires_safety_checker=False,
         | 
| 67 | 
            +
                    )#.to(self.gpu_id)
         | 
| 68 |  | 
| 69 | 
             
                    self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
         | 
| 70 |  | 
| 71 | 
             
                    self.auto_processor = AutoProcessor.from_pretrained(VIT_PATH)
         | 
| 72 | 
            +
                    self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(VIT_PATH)#.to(self.gpu_id)
         | 
| 73 |  | 
| 74 | 
             
                    self.tokenizer = CLIPTokenizer.from_pretrained(
         | 
| 75 | 
             
                        MODEL_PATH,
         | 
|  | |
| 78 | 
             
                    self.text_encoder = CLIPTextModel.from_pretrained(
         | 
| 79 | 
             
                        MODEL_PATH,
         | 
| 80 | 
             
                        subfolder="text_encoder",
         | 
| 81 | 
            +
                    )#.to(self.gpu_id)
         | 
| 82 |  | 
| 83 |  | 
| 84 | 
             
                def tokenize_captions(self, captions, max_length):
         | 
|  | |
| 107 | 
             
                    generator = torch.manual_seed(seed)
         | 
| 108 |  | 
| 109 | 
             
                    with torch.no_grad():
         | 
| 110 | 
            +
                        prompt_image = self.auto_processor(images=image_garm, return_tensors="pt").to('cuda')
         | 
| 111 | 
             
                        prompt_image = self.image_encoder(prompt_image.data['pixel_values']).image_embeds
         | 
| 112 | 
             
                        prompt_image = prompt_image.unsqueeze(1)
         | 
| 113 | 
             
                        if model_type == 'hd':
         | 
| 114 | 
            +
                            prompt_embeds = self.text_encoder(self.tokenize_captions([""], 2).to('cuda'))[0]
         | 
| 115 | 
             
                            prompt_embeds[:, 1:] = prompt_image[:]
         | 
| 116 | 
             
                        elif model_type == 'dc':
         | 
| 117 | 
            +
                            prompt_embeds = self.text_encoder(self.tokenize_captions([category], 3).to('cuda'))[0]
         | 
| 118 | 
             
                            prompt_embeds = torch.cat([prompt_embeds, prompt_image], dim=1)
         | 
| 119 | 
             
                        else:
         | 
| 120 | 
             
                            raise ValueError("model_type must be \'hd\' or \'dc\'!")
         | 
