Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Commit 
							
							Β·
						
						929fb7f
	
1
								Parent(s):
							
							c72201c
								
remove pl
Browse files- ldm/models/autoencoder.py +2 -2
- ldm/models/diffusion/ddpm.py +5 -5
- requirements.txt +1 -2
    	
        ldm/models/autoencoder.py
    CHANGED
    
    | @@ -1,5 +1,5 @@ | |
| 1 | 
             
            import torch
         | 
| 2 | 
            -
            import pytorch_lightning as pl
         | 
| 3 | 
             
            import torch.nn.functional as F
         | 
| 4 | 
             
            from contextlib import contextmanager
         | 
| 5 |  | 
| @@ -10,7 +10,7 @@ from ldm.util import instantiate_from_config | |
| 10 | 
             
            from ldm.modules.ema import LitEma
         | 
| 11 |  | 
| 12 |  | 
| 13 | 
            -
            class AutoencoderKL( | 
| 14 | 
             
                def __init__(self,
         | 
| 15 | 
             
                             ddconfig,
         | 
| 16 | 
             
                             lossconfig,
         | 
|  | |
| 1 | 
             
            import torch
         | 
| 2 | 
            +
            # import pytorch_lightning as pl
         | 
| 3 | 
             
            import torch.nn.functional as F
         | 
| 4 | 
             
            from contextlib import contextmanager
         | 
| 5 |  | 
|  | |
| 10 | 
             
            from ldm.modules.ema import LitEma
         | 
| 11 |  | 
| 12 |  | 
| 13 | 
            +
            class AutoencoderKL(nn.Module):
         | 
| 14 | 
             
                def __init__(self,
         | 
| 15 | 
             
                             ddconfig,
         | 
| 16 | 
             
                             lossconfig,
         | 
    	
        ldm/models/diffusion/ddpm.py
    CHANGED
    
    | @@ -9,7 +9,7 @@ https://github.com/CompVis/taming-transformers | |
| 9 | 
             
            import torch
         | 
| 10 | 
             
            import torch.nn as nn
         | 
| 11 | 
             
            import numpy as np
         | 
| 12 | 
            -
            import pytorch_lightning as pl
         | 
| 13 | 
             
            from torch.optim.lr_scheduler import LambdaLR
         | 
| 14 | 
             
            from einops import rearrange, repeat
         | 
| 15 | 
             
            from contextlib import contextmanager, nullcontext
         | 
| @@ -17,7 +17,7 @@ from functools import partial | |
| 17 | 
             
            import itertools
         | 
| 18 | 
             
            from tqdm import tqdm
         | 
| 19 | 
             
            from torchvision.utils import make_grid
         | 
| 20 | 
            -
            from pytorch_lightning.utilities.distributed import rank_zero_only
         | 
| 21 | 
             
            from omegaconf import ListConfig
         | 
| 22 | 
             
            from torchvision.transforms.functional import resize
         | 
| 23 | 
             
            import torchvision.transforms as T
         | 
| @@ -47,7 +47,7 @@ def disabled_train(self, mode=True): | |
| 47 | 
             
            def uniform_on_device(r1, r2, shape, device):
         | 
| 48 | 
             
                return (r1 - r2) * torch.rand(*shape, device=device) + r2
         | 
| 49 |  | 
| 50 | 
            -
            class DDPM( | 
| 51 | 
             
                # classic DDPM with Gaussian diffusion, in image space
         | 
| 52 | 
             
                def __init__(self,
         | 
| 53 | 
             
                             unet_config,
         | 
| @@ -614,7 +614,7 @@ class LatentDiffusion(DDPM): | |
| 614 | 
             
                    ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
         | 
| 615 | 
             
                    self.cond_ids[:self.num_timesteps_cond] = ids
         | 
| 616 |  | 
| 617 | 
            -
                @rank_zero_only
         | 
| 618 | 
             
                @torch.no_grad()
         | 
| 619 | 
             
                def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
         | 
| 620 | 
             
                    # only for very first batch
         | 
| @@ -1387,7 +1387,7 @@ class LatentDiffusion(DDPM): | |
| 1387 | 
             
                    return x
         | 
| 1388 |  | 
| 1389 |  | 
| 1390 | 
            -
            class DiffusionWrapper( | 
| 1391 | 
             
                def __init__(self, diff_model_config, conditioning_key):
         | 
| 1392 | 
             
                    super().__init__()
         | 
| 1393 | 
             
                    self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False)
         | 
|  | |
| 9 | 
             
            import torch
         | 
| 10 | 
             
            import torch.nn as nn
         | 
| 11 | 
             
            import numpy as np
         | 
| 12 | 
            +
            # import pytorch_lightning as pl
         | 
| 13 | 
             
            from torch.optim.lr_scheduler import LambdaLR
         | 
| 14 | 
             
            from einops import rearrange, repeat
         | 
| 15 | 
             
            from contextlib import contextmanager, nullcontext
         | 
|  | |
| 17 | 
             
            import itertools
         | 
| 18 | 
             
            from tqdm import tqdm
         | 
| 19 | 
             
            from torchvision.utils import make_grid
         | 
| 20 | 
            +
            # from pytorch_lightning.utilities.distributed import rank_zero_only
         | 
| 21 | 
             
            from omegaconf import ListConfig
         | 
| 22 | 
             
            from torchvision.transforms.functional import resize
         | 
| 23 | 
             
            import torchvision.transforms as T
         | 
|  | |
| 47 | 
             
            def uniform_on_device(r1, r2, shape, device):
         | 
| 48 | 
             
                return (r1 - r2) * torch.rand(*shape, device=device) + r2
         | 
| 49 |  | 
| 50 | 
            +
            class DDPM(nn.Module):
         | 
| 51 | 
             
                # classic DDPM with Gaussian diffusion, in image space
         | 
| 52 | 
             
                def __init__(self,
         | 
| 53 | 
             
                             unet_config,
         | 
|  | |
| 614 | 
             
                    ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long()
         | 
| 615 | 
             
                    self.cond_ids[:self.num_timesteps_cond] = ids
         | 
| 616 |  | 
| 617 | 
            +
                # @rank_zero_only
         | 
| 618 | 
             
                @torch.no_grad()
         | 
| 619 | 
             
                def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
         | 
| 620 | 
             
                    # only for very first batch
         | 
|  | |
| 1387 | 
             
                    return x
         | 
| 1388 |  | 
| 1389 |  | 
| 1390 | 
            +
            class DiffusionWrapper(nn.Module):
         | 
| 1391 | 
             
                def __init__(self, diff_model_config, conditioning_key):
         | 
| 1392 | 
             
                    super().__init__()
         | 
| 1393 | 
             
                    self.sequential_cross_attn = diff_model_config.pop("sequential_crossattn", False)
         | 
    	
        requirements.txt
    CHANGED
    
    | @@ -20,5 +20,4 @@ cloudpickle | |
| 20 | 
             
            fvcore
         | 
| 21 | 
             
            omegaconf==2.1
         | 
| 22 | 
             
            hydra-core
         | 
| 23 | 
            -
            pycocotools
         | 
| 24 | 
            -
            pytorch-lightning==1.5.0
         | 
|  | |
| 20 | 
             
            fvcore
         | 
| 21 | 
             
            omegaconf==2.1
         | 
| 22 | 
             
            hydra-core
         | 
| 23 | 
            +
            pycocotools
         | 
|  |