@@ -21,7 +21,7 @@ class VQGANTextureAwareSpatialHierarchyInferenceModel():
 
     def __init__(self, opt):
         self.opt = opt
-        self.device = torch.device('cuda')
+        self.device = torch.device(opt['device'])
         self.is_train = opt['is_train']
 
         self.top_encoder = Encoder(

@@ -20,7 +20,7 @@ class HierarchyVQSpatialTextureAwareModel():
 
     def __init__(self, opt):
         self.opt = opt
-        self.device = torch.device('cuda')
+        self.device = torch.device(opt['device'])
         self.top_encoder = Encoder(
             ch=opt['top_ch'],
             num_res_blocks=opt['top_num_res_blocks'],

@@ -22,7 +22,7 @@ class ParsingGenModel():
 
     def __init__(self, opt):
         self.opt = opt
-        self.device = torch.device('cuda')
+        self.device = torch.device(opt['device'])
         self.is_train = opt['is_train']
 
         self.attr_embedder = ShapeAttrEmbedding(

@@ -23,7 +23,7 @@ class BaseSampleModel():
 
     def __init__(self, opt):
         self.opt = opt
-        self.device = torch.device('cuda')
+        self.device = torch.device(opt['device'])
 
         # hierarchical VQVAE
         self.decoder = Decoder(
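
Reviewer note: the hunks above stop hard-coding CUDA and read the target device from the option dict instead. A minimal sketch of how a caller might populate opt['device'] follows; the CPU fallback and the surrounding keys are illustrative assumptions, not part of this patch.

    import torch

    # Pick a device at config time; fall back to CPU when no GPU is visible.
    # (Illustrative only -- the real option dict carries many more keys.)
    opt = {
        'device': 'cuda' if torch.cuda.is_available() else 'cpu',
        'is_train': False,
        # ... remaining model options omitted ...
    }

    # torch.device() accepts the same strings, e.g. 'cpu', 'cuda', or 'cuda:1'.
    device = torch.device(opt['device'])
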
@@ -123,7 +123,7 @@ class BaseSampleModel():
 
     def load_top_pretrain_models(self):
         # load pretrained vqgan
-        top_vae_checkpoint = torch.load(self.opt['top_vae_path'])
+        top_vae_checkpoint = torch.load(self.opt['top_vae_path'], map_location=self.device)
 
         self.decoder.load_state_dict(
             top_vae_checkpoint['decoder'], strict=True)

@@ -137,7 +137,7 @@ class BaseSampleModel():
         self.top_post_quant_conv.eval()
 
     def load_bot_pretrain_network(self):
-        checkpoint = torch.load(self.opt['bot_vae_path'])
+        checkpoint = torch.load(self.opt['bot_vae_path'], map_location=self.device)
         self.bot_decoder_res.load_state_dict(
             checkpoint['bot_decoder_res'], strict=True)
         self.decoder.load_state_dict(checkpoint['decoder'], strict=True)

@@ -153,7 +153,7 @@ class BaseSampleModel():
 
     def load_pretrained_segm_token(self):
         # load pretrained vqgan for segmentation mask
-        segm_token_checkpoint = torch.load(self.opt['segm_token_path'])
+        segm_token_checkpoint = torch.load(self.opt['segm_token_path'], map_location=self.device)
         self.segm_encoder.load_state_dict(
             segm_token_checkpoint['encoder'], strict=True)
         self.segm_quantizer.load_state_dict(

@@ -166,7 +166,7 @@ class BaseSampleModel():
         self.segm_quant_conv.eval()
 
     def load_index_pred_network(self):
-        checkpoint = torch.load(self.opt['pretrained_index_network'])
+        checkpoint = torch.load(self.opt['pretrained_index_network'], map_location=self.device)
         self.index_pred_guidance_encoder.load_state_dict(
             checkpoint['guidance_encoder'], strict=True)
         self.index_pred_decoder.load_state_dict(

@@ -176,7 +176,7 @@ class BaseSampleModel():
         self.index_pred_decoder.eval()
 
     def load_sampler_pretrained_network(self):
-        checkpoint = torch.load(self.opt['pretrained_sampler'])
+        checkpoint = torch.load(self.opt['pretrained_sampler'], map_location=self.device)
         self.sampler_fn.load_state_dict(checkpoint, strict=True)
         self.sampler_fn.eval()
 

@@ -397,7 +397,7 @@ class SampleFromPoseModel(BaseSampleModel):
                         [185, 210, 205], [130, 165, 180], [225, 141, 151]]
 
     def load_shape_generation_models(self):
-        checkpoint = torch.load(self.opt['pretrained_parsing_gen'])
+        checkpoint = torch.load(self.opt['pretrained_parsing_gen'], map_location=self.device)
 
         self.shape_attr_embedder.load_state_dict(
             checkpoint['embedder'], strict=True)
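
Reviewer note: without map_location, torch.load tries to restore each tensor onto the device it was saved from, which raises an error on CPU-only machines when the checkpoints were written on a GPU. A minimal sketch of the pattern used above; the checkpoint path and key are hypothetical, only the map_location usage matters.

    import torch

    ckpt_path = 'top_vae.pth'  # hypothetical checkpoint file

    # Saved on a CUDA box, loadable anywhere: map_location remaps the
    # stored tensors onto the device the model actually uses.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    checkpoint = torch.load(ckpt_path, map_location=device)
    # decoder.load_state_dict(checkpoint['decoder'], strict=True)
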
@@ -21,7 +21,7 @@ class TransformerTextureAwareModel():
 
     def __init__(self, opt):
         self.opt = opt
-        self.device = torch.device('cuda')
+        self.device = torch.device(opt['device'])
         self.is_train = opt['is_train']
 
         # VQVAE for image

@@ -317,10 +317,10 @@ class TransformerTextureAwareModel():
     def sample_fn(self, temp=1.0, sample_steps=None):
         self._denoise_fn.eval()
 
-        b, device = self.image.size(0), 'cuda'
+        b = self.image.size(0)
         x_t = torch.ones(
-            (b, np.prod(self.shape)), device=device).long() * self.mask_id
-        unmasked = torch.zeros_like(x_t, device=device).bool()
+            (b, np.prod(self.shape)), device=self.device).long() * self.mask_id
+        unmasked = torch.zeros_like(x_t, device=self.device).bool()
         sample_steps = list(range(1, sample_steps + 1))
 
         texture_mask_flatten = self.texture_tokens.view(-1)

@@ -336,11 +336,11 @@ class TransformerTextureAwareModel():
 
         for t in reversed(sample_steps):
             print(f'Sample timestep {t:4d}', end='\r')
-            t = torch.full((b, ), t, device=device, dtype=torch.long)
+            t = torch.full((b, ), t, device=self.device, dtype=torch.long)
 
             # where to unmask
             changes = torch.rand(
-                x_t.shape, device=device) < 1 / t.float().unsqueeze(-1)
+                x_t.shape, device=self.device) < 1 / t.float().unsqueeze(-1)
             # don't unmask somewhere already unmasked
             changes = torch.bitwise_xor(changes,
                                         torch.bitwise_and(changes, unmasked))
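
Reviewer note: the sampling loop above now allocates its work tensors on self.device instead of the hard-coded 'cuda' string. A standalone sketch of the same device-agnostic pattern; the batch size, token count, and mask id are made-up values for illustration.

    import torch

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    b, num_tokens, mask_id = 2, 16, 1024  # made-up sizes

    # Start fully masked, on whatever device the model lives on.
    x_t = torch.ones((b, num_tokens), device=device).long() * mask_id
    unmasked = torch.zeros_like(x_t, device=device).bool()

    # Each step decides per position whether to unmask, also on that device.
    t = torch.full((b,), 10, device=device, dtype=torch.long)
    changes = torch.rand(x_t.shape, device=device) < 1 / t.float().unsqueeze(-1)
    changes = torch.bitwise_xor(changes, torch.bitwise_and(changes, unmasked))
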
@@ -20,7 +20,7 @@ class VQModel():
     def __init__(self, opt):
         super().__init__()
         self.opt = opt
-        self.device = torch.device('cuda')
+        self.device = torch.device(opt['device'])
         self.encoder = Encoder(
             ch=opt['ch'],
             num_res_blocks=opt['num_res_blocks'],

@@ -390,7 +390,7 @@ class VQImageSegmTextureModel(VQImageModel):
 
     def __init__(self, opt):
         self.opt = opt
-        self.device = torch.device('cuda')
+        self.device = torch.device(opt['device'])
         self.encoder = Encoder(
             ch=opt['ch'],
             num_res_blocks=opt['num_res_blocks'],