Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Update train/train_t2i.py
Browse files- train/train_t2i.py +806 -807
    	
        train/train_t2i.py
    CHANGED
    
    | @@ -1,807 +1,806 @@ | |
| 1 | 
            -
            import torch
         | 
| 2 | 
            -
            import json
         | 
| 3 | 
            -
            import yaml
         | 
| 4 | 
            -
            import torchvision
         | 
| 5 | 
            -
            from torch import nn, optim
         | 
| 6 | 
            -
            from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPVisionModelWithProjection
         | 
| 7 | 
            -
            from warmup_scheduler import GradualWarmupScheduler
         | 
| 8 | 
            -
            import torch.multiprocessing as mp
         | 
| 9 | 
            -
            import numpy as np
         | 
| 10 | 
            -
            import os
         | 
| 11 | 
            -
            import sys
         | 
| 12 | 
            -
            sys.path.append(os.path.abspath('./'))
         | 
| 13 | 
            -
            from dataclasses import dataclass
         | 
| 14 | 
            -
            from torch.distributed import init_process_group, destroy_process_group, barrier
         | 
| 15 | 
            -
            from gdf import GDF_dual_fixlrt as GDF
         | 
| 16 | 
            -
            from gdf import EpsilonTarget, CosineSchedule
         | 
| 17 | 
            -
            from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight
         | 
| 18 | 
            -
            from torchtools.transforms import SmartCrop
         | 
| 19 | 
            -
            from fractions import Fraction
         | 
| 20 | 
            -
            from modules.effnet import EfficientNetEncoder
         | 
| 21 | 
            -
             | 
| 22 | 
            -
            from modules.model_4stage_lite import StageC, ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock
         | 
| 23 | 
            -
            from modules.previewer import Previewer
         | 
| 24 | 
            -
            from core.data import Bucketeer
         | 
| 25 | 
            -
            from train.base import DataCore, TrainingCore
         | 
| 26 | 
            -
            from tqdm import tqdm
         | 
| 27 | 
            -
            from core import WarpCore
         | 
| 28 | 
            -
            from core.utils import EXPECTED, EXPECTED_TRAIN, load_or_fail
         | 
| 29 | 
            -
             | 
| 30 | 
            -
            from accelerate import init_empty_weights
         | 
| 31 | 
            -
            from accelerate.utils import set_module_tensor_to_device
         | 
| 32 | 
            -
            from contextlib import contextmanager
         | 
| 33 | 
            -
            from train.dist_core import *
         | 
| 34 | 
            -
            import glob
         | 
| 35 | 
            -
            from torch.utils.data import DataLoader, Dataset
         | 
| 36 | 
            -
            from torch.nn.parallel import DistributedDataParallel as DDP
         | 
| 37 | 
            -
            from torch.utils.data.distributed import DistributedSampler
         | 
| 38 | 
            -
            from PIL import Image
         | 
| 39 | 
            -
            from core.utils import EXPECTED, EXPECTED_TRAIN, update_weights_ema, create_folder_if_necessary
         | 
| 40 | 
            -
            from core.utils import Base
         | 
| 41 | 
            -
            from modules.common_ckpt import LayerNorm2d, GlobalResponseNorm
         | 
| 42 | 
            -
            import torch.nn.functional as F
         | 
| 43 | 
            -
            import functools
         | 
| 44 | 
            -
            import math
         | 
| 45 | 
            -
            import copy
         | 
| 46 | 
            -
            import random
         | 
| 47 | 
            -
            from modules.lora import apply_lora, apply_retoken, LoRA, ReToken
         | 
| 48 | 
            -
            Image.MAX_IMAGE_PIXELS = None
         | 
| 49 | 
            -
            torch.manual_seed(23)
         | 
| 50 | 
            -
            random.seed(23)
         | 
| 51 | 
            -
            np.random.seed(23)
         | 
| 52 | 
            -
            #7978026
         | 
| 53 | 
            -
             | 
| 54 | 
            -
            class Null_Model(torch.nn.Module):
         | 
| 55 | 
            -
                def __init__(self):
         | 
| 56 | 
            -
                    super().__init__()
         | 
| 57 | 
            -
                def forward(self, x):
         | 
| 58 | 
            -
                    pass
         | 
| 59 | 
            -
             | 
| 60 | 
            -
             | 
| 61 | 
            -
             | 
| 62 | 
            -
             | 
| 63 | 
            -
            def identity(x):
         | 
| 64 | 
            -
                if isinstance(x, bytes):
         | 
| 65 | 
            -
                    x = x.decode('utf-8')
         | 
| 66 | 
            -
                return x
         | 
| 67 | 
            -
            def check_nan_inmodel(model, meta=''):
         | 
| 68 | 
            -
                    for name, param in model.named_parameters():
         | 
| 69 | 
            -
                        if torch.isnan(param).any():
         | 
| 70 | 
            -
                            print(f"nan detected in {name}", meta)
         | 
| 71 | 
            -
                            return True
         | 
| 72 | 
            -
                    print('no nan', meta)
         | 
| 73 | 
            -
                    return False  
         | 
| 74 | 
            -
            class mydist_dataset(Dataset):
         | 
| 75 | 
            -
                def __init__(self, rootpath, img_processor=None):
         | 
| 76 | 
            -
             | 
| 77 | 
            -
                    self.img_pathlist = glob.glob(os.path.join(rootpath, '*', '*.jpg'))
         | 
| 78 | 
            -
                    self.img_processor = img_processor
         | 
| 79 | 
            -
                    self.length = len( self.img_pathlist)
         | 
| 80 | 
            -
             | 
| 81 | 
            -
                  
         | 
| 82 | 
            -
                  
         | 
| 83 | 
            -
                def __getitem__(self, idx):
         | 
| 84 | 
            -
                    
         | 
| 85 | 
            -
                    imgpath = self.img_pathlist[idx]
         | 
| 86 | 
            -
                    json_file = imgpath.replace('.jpg', '.json') 
         | 
| 87 | 
            -
                   
         | 
| 88 | 
            -
                    with open(json_file, 'r') as file:
         | 
| 89 | 
            -
                        info = json.load(file)
         | 
| 90 | 
            -
                    txt = info['caption']
         | 
| 91 | 
            -
                    if txt is None:
         | 
| 92 | 
            -
                        txt = ' ' 
         | 
| 93 | 
            -
                    try:  
         | 
| 94 | 
            -
                      img = Image.open(imgpath).convert('RGB')
         | 
| 95 | 
            -
                      w, h = img.size
         | 
| 96 | 
            -
                      if self.img_processor is not None:
         | 
| 97 | 
            -
                        img = self.img_processor(img)
         | 
| 98 | 
            -
             | 
| 99 | 
            -
                    except:
         | 
| 100 | 
            -
                      print('exception', imgpath)
         | 
| 101 | 
            -
                      return self.__getitem__(random.randint(0, self.length -1 ) )
         | 
| 102 | 
            -
                    return dict(captions=txt, images=img)
         | 
| 103 | 
            -
                def __len__(self):
         | 
| 104 | 
            -
                    return self.length
         | 
| 105 | 
            -
             | 
| 106 | 
            -
            class WurstCore(TrainingCore, DataCore, WarpCore):
         | 
| 107 | 
            -
                @dataclass(frozen=True)
         | 
| 108 | 
            -
                class Config(TrainingCore.Config, DataCore.Config, WarpCore.Config):
         | 
| 109 | 
            -
                    # TRAINING PARAMS
         | 
| 110 | 
            -
                    lr: float = EXPECTED_TRAIN
         | 
| 111 | 
            -
                    warmup_updates: int = EXPECTED_TRAIN
         | 
| 112 | 
            -
                    dtype: str = None
         | 
| 113 | 
            -
             | 
| 114 | 
            -
                    # MODEL VERSION
         | 
| 115 | 
            -
                    model_version: str = EXPECTED  # 3.6B or 1B
         | 
| 116 | 
            -
                    clip_image_model_name: str = 'openai/clip-vit-large-patch14'
         | 
| 117 | 
            -
                    clip_text_model_name: str = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k'
         | 
| 118 | 
            -
                 
         | 
| 119 | 
            -
                    # CHECKPOINT PATHS
         | 
| 120 | 
            -
                    effnet_checkpoint_path: str = EXPECTED
         | 
| 121 | 
            -
                    previewer_checkpoint_path: str = EXPECTED
         | 
| 122 | 
            -
                   
         | 
| 123 | 
            -
                    generator_checkpoint_path: str = None
         | 
| 124 | 
            -
             | 
| 125 | 
            -
                    # gdf customization
         | 
| 126 | 
            -
                    adaptive_loss_weight: str = None
         | 
| 127 | 
            -
                    use_ddp: bool=EXPECTED
         | 
| 128 | 
            -
                   
         | 
| 129 | 
            -
                   
         | 
| 130 | 
            -
                @dataclass(frozen=True)
         | 
| 131 | 
            -
                class Data(Base):
         | 
| 132 | 
            -
                    dataset: Dataset = EXPECTED
         | 
| 133 | 
            -
                    dataloader: DataLoader  = EXPECTED
         | 
| 134 | 
            -
                    iterator: any = EXPECTED
         | 
| 135 | 
            -
                    sampler: DistributedSampler = EXPECTED
         | 
| 136 | 
            -
             | 
| 137 | 
            -
                @dataclass(frozen=True)
         | 
| 138 | 
            -
                class Models(TrainingCore.Models, DataCore.Models, WarpCore.Models):
         | 
| 139 | 
            -
                    effnet: nn.Module = EXPECTED
         | 
| 140 | 
            -
                    previewer: nn.Module = EXPECTED
         | 
| 141 | 
            -
                    train_norm: nn.Module = EXPECTED
         | 
| 142 | 
            -
                   
         | 
| 143 | 
            -
             | 
| 144 | 
            -
                @dataclass(frozen=True)
         | 
| 145 | 
            -
                class Schedulers(WarpCore.Schedulers):
         | 
| 146 | 
            -
                    generator: any = None
         | 
| 147 | 
            -
             | 
| 148 | 
            -
                @dataclass(frozen=True)
         | 
| 149 | 
            -
                class Extras(TrainingCore.Extras, DataCore.Extras, WarpCore.Extras):
         | 
| 150 | 
            -
                    gdf: GDF = EXPECTED
         | 
| 151 | 
            -
                    sampling_configs: dict = EXPECTED
         | 
| 152 | 
            -
                    effnet_preprocess: torchvision.transforms.Compose = EXPECTED
         | 
| 153 | 
            -
             | 
| 154 | 
            -
                info: TrainingCore.Info
         | 
| 155 | 
            -
                config: Config
         | 
| 156 | 
            -
             | 
| 157 | 
            -
                def setup_extras_pre(self) -> Extras:
         | 
| 158 | 
            -
                    gdf = GDF(
         | 
| 159 | 
            -
                        schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]),
         | 
| 160 | 
            -
                        input_scaler=VPScaler(), target=EpsilonTarget(),
         | 
| 161 | 
            -
                        noise_cond=CosineTNoiseCond(),
         | 
| 162 | 
            -
                        loss_weight=AdaptiveLossWeight() if self.config.adaptive_loss_weight is True else P2LossWeight(),
         | 
| 163 | 
            -
                    )
         | 
| 164 | 
            -
                    sampling_configs = {"cfg": 5, "sampler": DDPMSampler(gdf), "shift": 1, "timesteps": 20}
         | 
| 165 | 
            -
             | 
| 166 | 
            -
                    if self.info.adaptive_loss is not None:
         | 
| 167 | 
            -
                        gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges'])
         | 
| 168 | 
            -
                        gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses'])
         | 
| 169 | 
            -
             | 
| 170 | 
            -
                    effnet_preprocess = torchvision.transforms.Compose([
         | 
| 171 | 
            -
                        torchvision.transforms.Normalize(
         | 
| 172 | 
            -
                            mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
         | 
| 173 | 
            -
                        )
         | 
| 174 | 
            -
                    ])
         | 
| 175 | 
            -
             | 
| 176 | 
            -
                    clip_preprocess = torchvision.transforms.Compose([
         | 
| 177 | 
            -
                        torchvision.transforms.Resize(224, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
         | 
| 178 | 
            -
                        torchvision.transforms.CenterCrop(224),
         | 
| 179 | 
            -
                        torchvision.transforms.Normalize(
         | 
| 180 | 
            -
                            mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)
         | 
| 181 | 
            -
                        )
         | 
| 182 | 
            -
                    ])
         | 
| 183 | 
            -
             | 
| 184 | 
            -
                    if self.config.training:
         | 
| 185 | 
            -
                        transforms = torchvision.transforms.Compose([
         | 
| 186 | 
            -
                            torchvision.transforms.ToTensor(),
         | 
| 187 | 
            -
                            torchvision.transforms.Resize(self.config.image_size[-1], interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True),
         | 
| 188 | 
            -
                            SmartCrop(self.config.image_size, randomize_p=0.3, randomize_q=0.2)
         | 
| 189 | 
            -
                        ])
         | 
| 190 | 
            -
                    else:
         | 
| 191 | 
            -
                        transforms = None
         | 
| 192 | 
            -
             | 
| 193 | 
            -
                    return self.Extras(
         | 
| 194 | 
            -
                        gdf=gdf,
         | 
| 195 | 
            -
                        sampling_configs=sampling_configs,
         | 
| 196 | 
            -
                        transforms=transforms,
         | 
| 197 | 
            -
                        effnet_preprocess=effnet_preprocess,
         | 
| 198 | 
            -
                        clip_preprocess=clip_preprocess
         | 
| 199 | 
            -
                    )
         | 
| 200 | 
            -
             | 
| 201 | 
            -
                def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False,
         | 
| 202 | 
            -
                                   eval_image_embeds=False, return_fields=None):
         | 
| 203 | 
            -
                    conditions = super().get_conditions(
         | 
| 204 | 
            -
                        batch, models, extras, is_eval, is_unconditional,
         | 
| 205 | 
            -
                        eval_image_embeds, return_fields=return_fields or ['clip_text', 'clip_text_pooled', 'clip_img']
         | 
| 206 | 
            -
                    )
         | 
| 207 | 
            -
                    return conditions
         | 
| 208 | 
            -
             | 
| 209 | 
            -
                def setup_models(self, extras: Extras) -> Models:   # configure model
         | 
| 210 | 
            -
             | 
| 211 | 
            -
                    dtype = getattr(torch, self.config.dtype) if self.config.dtype else torch.bfloat16
         | 
| 212 | 
            -
             | 
| 213 | 
            -
                    # EfficientNet encoderin
         | 
| 214 | 
            -
                    effnet = EfficientNetEncoder()
         | 
| 215 | 
            -
                    effnet_checkpoint = load_or_fail(self.config.effnet_checkpoint_path)
         | 
| 216 | 
            -
                    effnet.load_state_dict(effnet_checkpoint if 'state_dict' not in effnet_checkpoint else effnet_checkpoint['state_dict'])
         | 
| 217 | 
            -
                    effnet.eval().requires_grad_(False).to(self.device)
         | 
| 218 | 
            -
                    del effnet_checkpoint
         | 
| 219 | 
            -
             | 
| 220 | 
            -
                    # Previewer
         | 
| 221 | 
            -
                    previewer = Previewer()
         | 
| 222 | 
            -
                    previewer_checkpoint = load_or_fail(self.config.previewer_checkpoint_path)
         | 
| 223 | 
            -
                    previewer.load_state_dict(previewer_checkpoint if 'state_dict' not in previewer_checkpoint else previewer_checkpoint['state_dict'])
         | 
| 224 | 
            -
                    previewer.eval().requires_grad_(False).to(self.device)
         | 
| 225 | 
            -
                    del previewer_checkpoint
         | 
| 226 | 
            -
             | 
| 227 | 
            -
                    @contextmanager
         | 
| 228 | 
            -
                    def dummy_context():
         | 
| 229 | 
            -
                        yield None
         | 
| 230 | 
            -
             | 
| 231 | 
            -
                    loading_context = dummy_context if self.config.training else init_empty_weights
         | 
| 232 | 
            -
             | 
| 233 | 
            -
                    # Diffusion models
         | 
| 234 | 
            -
                    with loading_context():
         | 
| 235 | 
            -
                        generator_ema = None
         | 
| 236 | 
            -
                        if self.config.model_version == '3.6B':
         | 
| 237 | 
            -
                            generator = StageC()
         | 
| 238 | 
            -
                            if self.config.ema_start_iters is not None:  # default setting
         | 
| 239 | 
            -
                                generator_ema = StageC()
         | 
| 240 | 
            -
                        elif self.config.model_version == '1B':
         | 
| 241 | 
            -
                            print('in line 155 1b light model', self.config.model_version )
         | 
| 242 | 
            -
                            generator = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]])
         | 
| 243 | 
            -
                            
         | 
| 244 | 
            -
                            if self.config.ema_start_iters is not None and self.config.training:
         | 
| 245 | 
            -
                                generator_ema = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]])
         | 
| 246 | 
            -
                        else:
         | 
| 247 | 
            -
                            raise ValueError(f"Unknown model version {self.config.model_version}")
         | 
| 248 | 
            -
             | 
| 249 | 
            -
                    
         | 
| 250 | 
            -
                 
         | 
| 251 | 
            -
                    if loading_context is dummy_context:
         | 
| 252 | 
            -
                        generator.load_state_dict( load_or_fail(self.config.generator_checkpoint_path))
         | 
| 253 | 
            -
                    else:
         | 
| 254 | 
            -
                        for param_name, param in load_or_fail(self.config.generator_checkpoint_path).items():
         | 
| 255 | 
            -
                                set_module_tensor_to_device(generator, param_name, "cpu", value=param)
         | 
| 256 | 
            -
             | 
| 257 | 
            -
                    generator._init_extra_parameter()
         | 
| 258 | 
            -
                    generator = generator.to(torch.bfloat16).to(self.device)
         | 
| 259 | 
            -
                   
         | 
| 260 | 
            -
                    
         | 
| 261 | 
            -
                    train_norm = nn.ModuleList()
         | 
| 262 | 
            -
                    cnt_norm = 0
         | 
| 263 | 
            -
                    for mm in generator.modules():
         | 
| 264 | 
            -
                        if isinstance(mm,  GlobalResponseNorm):
         | 
| 265 | 
            -
                           
         | 
| 266 | 
            -
                            train_norm.append(Null_Model())
         | 
| 267 | 
            -
                            cnt_norm += 1
         | 
| 268 | 
            -
                         
         | 
| 269 | 
            -
                    train_norm.append(generator.agg_net)
         | 
| 270 | 
            -
                    train_norm.append(generator.agg_net_up)      
         | 
| 271 | 
            -
                    total = sum([ param.nelement()  for param in train_norm.parameters()])
         | 
| 272 | 
            -
                    print('Trainable parameter', total / 1048576)
         | 
| 273 | 
            -
                    
         | 
| 274 | 
            -
                    if os.path.exists(os.path.join(self.config.output_path, self.config.experiment_id, 'train_norm.safetensors')):
         | 
| 275 | 
            -
                        sdd = torch.load(os.path.join(self.config.output_path, self.config.experiment_id, 'train_norm.safetensors'), map_location='cpu')
         | 
| 276 | 
            -
                        collect_sd = {}
         | 
| 277 | 
            -
                        for k, v in sdd.items():
         | 
| 278 | 
            -
                            collect_sd[k[7:]] = v
         | 
| 279 | 
            -
                        train_norm.load_state_dict(collect_sd, strict=True)
         | 
| 280 | 
            -
                    
         | 
| 281 | 
            -
                   
         | 
| 282 | 
            -
                    train_norm.to(self.device).train().requires_grad_(True)
         | 
| 283 | 
            -
                     | 
| 284 | 
            -
                     | 
| 285 | 
            -
             | 
| 286 | 
            -
                        
         | 
| 287 | 
            -
                        generator_ema. | 
| 288 | 
            -
             | 
| 289 | 
            -
             | 
| 290 | 
            -
             | 
| 291 | 
            -
                         | 
| 292 | 
            -
             | 
| 293 | 
            -
                           | 
| 294 | 
            -
             | 
| 295 | 
            -
             | 
| 296 | 
            -
             | 
| 297 | 
            -
             | 
| 298 | 
            -
             | 
| 299 | 
            -
             | 
| 300 | 
            -
             | 
| 301 | 
            -
                    
         | 
| 302 | 
            -
             | 
| 303 | 
            -
             | 
| 304 | 
            -
                    
         | 
| 305 | 
            -
                    
         | 
| 306 | 
            -
             | 
| 307 | 
            -
             | 
| 308 | 
            -
                         | 
| 309 | 
            -
             | 
| 310 | 
            -
                     | 
| 311 | 
            -
                     | 
| 312 | 
            -
                     | 
| 313 | 
            -
                     | 
| 314 | 
            -
                    
         | 
| 315 | 
            -
             | 
| 316 | 
            -
                         | 
| 317 | 
            -
             | 
| 318 | 
            -
             | 
| 319 | 
            -
             | 
| 320 | 
            -
             | 
| 321 | 
            -
             | 
| 322 | 
            -
             
         | 
| 323 | 
            -
                    params  | 
| 324 | 
            -
             | 
| 325 | 
            -
             | 
| 326 | 
            -
             | 
| 327 | 
            -
             | 
| 328 | 
            -
             | 
| 329 | 
            -
             | 
| 330 | 
            -
             | 
| 331 | 
            -
             | 
| 332 | 
            -
                         | 
| 333 | 
            -
             | 
| 334 | 
            -
             | 
| 335 | 
            -
             | 
| 336 | 
            -
                         | 
| 337 | 
            -
             | 
| 338 | 
            -
             | 
| 339 | 
            -
             | 
| 340 | 
            -
             | 
| 341 | 
            -
             | 
| 342 | 
            -
             | 
| 343 | 
            -
                     | 
| 344 | 
            -
             | 
| 345 | 
            -
             | 
| 346 | 
            -
             | 
| 347 | 
            -
             | 
| 348 | 
            -
             | 
| 349 | 
            -
                    scheduler | 
| 350 | 
            -
                     | 
| 351 | 
            -
             | 
| 352 | 
            -
             | 
| 353 | 
            -
             | 
| 354 | 
            -
                     | 
| 355 | 
            -
                     | 
| 356 | 
            -
             | 
| 357 | 
            -
             | 
| 358 | 
            -
             | 
| 359 | 
            -
             | 
| 360 | 
            -
                     | 
| 361 | 
            -
             | 
| 362 | 
            -
             | 
| 363 | 
            -
                     | 
| 364 | 
            -
             | 
| 365 | 
            -
                         | 
| 366 | 
            -
                         | 
| 367 | 
            -
             | 
| 368 | 
            -
                     | 
| 369 | 
            -
             | 
| 370 | 
            -
             | 
| 371 | 
            -
             | 
| 372 | 
            -
             | 
| 373 | 
            -
                         | 
| 374 | 
            -
             | 
| 375 | 
            -
                                                         | 
| 376 | 
            -
             | 
| 377 | 
            -
             | 
| 378 | 
            -
             | 
| 379 | 
            -
             | 
| 380 | 
            -
             | 
| 381 | 
            -
             | 
| 382 | 
            -
             | 
| 383 | 
            -
             | 
| 384 | 
            -
             | 
| 385 | 
            -
             | 
| 386 | 
            -
             | 
| 387 | 
            -
             | 
| 388 | 
            -
             | 
| 389 | 
            -
                         | 
| 390 | 
            -
                         | 
| 391 | 
            -
             | 
| 392 | 
            -
             | 
| 393 | 
            -
                        self. | 
| 394 | 
            -
                        self. | 
| 395 | 
            -
                        self. | 
| 396 | 
            -
             | 
| 397 | 
            -
             | 
| 398 | 
            -
                        os.environ[' | 
| 399 | 
            -
                         | 
| 400 | 
            -
                         | 
| 401 | 
            -
             | 
| 402 | 
            -
                             | 
| 403 | 
            -
                             | 
| 404 | 
            -
             | 
| 405 | 
            -
                        )
         | 
| 406 | 
            -
             | 
| 407 | 
            -
             | 
| 408 | 
            -
                        self. | 
| 409 | 
            -
                        self. | 
| 410 | 
            -
                        self. | 
| 411 | 
            -
                         | 
| 412 | 
            -
             | 
| 413 | 
            -
                 | 
| 414 | 
            -
             | 
| 415 | 
            -
                     | 
| 416 | 
            -
             | 
| 417 | 
            -
             | 
| 418 | 
            -
                     | 
| 419 | 
            -
                     | 
| 420 | 
            -
                     | 
| 421 | 
            -
                    shape_lr  | 
| 422 | 
            -
                     | 
| 423 | 
            -
             | 
| 424 | 
            -
                         | 
| 425 | 
            -
                        
         | 
| 426 | 
            -
                         | 
| 427 | 
            -
                         | 
| 428 | 
            -
                        
         | 
| 429 | 
            -
                         | 
| 430 | 
            -
             | 
| 431 | 
            -
             | 
| 432 | 
            -
             | 
| 433 | 
            -
                         | 
| 434 | 
            -
             | 
| 435 | 
            -
             | 
| 436 | 
            -
             | 
| 437 | 
            -
             | 
| 438 | 
            -
                        
         | 
| 439 | 
            -
                        
         | 
| 440 | 
            -
                         | 
| 441 | 
            -
             | 
| 442 | 
            -
             | 
| 443 | 
            -
             | 
| 444 | 
            -
             | 
| 445 | 
            -
             | 
| 446 | 
            -
             | 
| 447 | 
            -
             | 
| 448 | 
            -
             | 
| 449 | 
            -
             | 
| 450 | 
            -
             | 
| 451 | 
            -
             | 
| 452 | 
            -
             | 
| 453 | 
            -
                    
         | 
| 454 | 
            -
             | 
| 455 | 
            -
             | 
| 456 | 
            -
                         | 
| 457 | 
            -
                         | 
| 458 | 
            -
                        
         | 
| 459 | 
            -
             | 
| 460 | 
            -
             | 
| 461 | 
            -
                         | 
| 462 | 
            -
             | 
| 463 | 
            -
             | 
| 464 | 
            -
             | 
| 465 | 
            -
                         | 
| 466 | 
            -
             | 
| 467 | 
            -
             | 
| 468 | 
            -
             | 
| 469 | 
            -
             | 
| 470 | 
            -
             | 
| 471 | 
            -
             | 
| 472 | 
            -
             | 
| 473 | 
            -
             | 
| 474 | 
            -
             | 
| 475 | 
            -
             | 
| 476 | 
            -
             | 
| 477 | 
            -
             | 
| 478 | 
            -
                    
         | 
| 479 | 
            -
             | 
| 480 | 
            -
             | 
| 481 | 
            -
             | 
| 482 | 
            -
             | 
| 483 | 
            -
                    
         | 
| 484 | 
            -
                     | 
| 485 | 
            -
             | 
| 486 | 
            -
                       | 
| 487 | 
            -
             | 
| 488 | 
            -
             | 
| 489 | 
            -
             | 
| 490 | 
            -
             | 
| 491 | 
            -
             | 
| 492 | 
            -
             | 
| 493 | 
            -
             | 
| 494 | 
            -
             | 
| 495 | 
            -
                    self. | 
| 496 | 
            -
                    self. | 
| 497 | 
            -
                    self. | 
| 498 | 
            -
                     | 
| 499 | 
            -
             | 
| 500 | 
            -
             | 
| 501 | 
            -
             | 
| 502 | 
            -
             | 
| 503 | 
            -
                    
         | 
| 504 | 
            -
             | 
| 505 | 
            -
                        torch.backends. | 
| 506 | 
            -
             | 
| 507 | 
            -
             | 
| 508 | 
            -
             | 
| 509 | 
            -
                        print()
         | 
| 510 | 
            -
                        print( | 
| 511 | 
            -
                        print( | 
| 512 | 
            -
                        print( | 
| 513 | 
            -
                        print()
         | 
| 514 | 
            -
                        print( | 
| 515 | 
            -
                        print( | 
| 516 | 
            -
                        print( | 
| 517 | 
            -
             | 
| 518 | 
            -
                    
         | 
| 519 | 
            -
                     | 
| 520 | 
            -
                    extras  | 
| 521 | 
            -
             | 
| 522 | 
            -
             | 
| 523 | 
            -
             | 
| 524 | 
            -
             | 
| 525 | 
            -
                    data  | 
| 526 | 
            -
                     | 
| 527 | 
            -
             | 
| 528 | 
            -
                        print( | 
| 529 | 
            -
                        print( | 
| 530 | 
            -
                        print( | 
| 531 | 
            -
             | 
| 532 | 
            -
             | 
| 533 | 
            -
                    models  | 
| 534 | 
            -
                     | 
| 535 | 
            -
             | 
| 536 | 
            -
                        print( | 
| 537 | 
            -
             | 
| 538 | 
            -
             | 
| 539 | 
            -
                         | 
| 540 | 
            -
                        print( | 
| 541 | 
            -
             | 
| 542 | 
            -
             | 
| 543 | 
            -
             | 
| 544 | 
            -
             | 
| 545 | 
            -
                    optimizers  | 
| 546 | 
            -
                     | 
| 547 | 
            -
             | 
| 548 | 
            -
                        print( | 
| 549 | 
            -
                        print( | 
| 550 | 
            -
                        print( | 
| 551 | 
            -
             | 
| 552 | 
            -
             | 
| 553 | 
            -
                    schedulers  | 
| 554 | 
            -
                     | 
| 555 | 
            -
             | 
| 556 | 
            -
                        print( | 
| 557 | 
            -
                        print( | 
| 558 | 
            -
                        print( | 
| 559 | 
            -
             | 
| 560 | 
            -
             | 
| 561 | 
            -
                    post_extras  | 
| 562 | 
            -
                     | 
| 563 | 
            -
                     | 
| 564 | 
            -
             | 
| 565 | 
            -
                        print(" | 
| 566 | 
            -
                        print( | 
| 567 | 
            -
                        print( | 
| 568 | 
            -
             | 
| 569 | 
            -
             | 
| 570 | 
            -
             | 
| 571 | 
            -
                     | 
| 572 | 
            -
             | 
| 573 | 
            -
             | 
| 574 | 
            -
             | 
| 575 | 
            -
             | 
| 576 | 
            -
             | 
| 577 | 
            -
                         | 
| 578 | 
            -
             | 
| 579 | 
            -
             | 
| 580 | 
            -
                        print()
         | 
| 581 | 
            -
                        print( | 
| 582 | 
            -
                        print()
         | 
| 583 | 
            -
             | 
| 584 | 
            -
             | 
| 585 | 
            -
             | 
| 586 | 
            -
             | 
| 587 | 
            -
             | 
| 588 | 
            -
             | 
| 589 | 
            -
                     | 
| 590 | 
            -
                     | 
| 591 | 
            -
             | 
| 592 | 
            -
             | 
| 593 | 
            -
             | 
| 594 | 
            -
             | 
| 595 | 
            -
             | 
| 596 | 
            -
             | 
| 597 | 
            -
                    
         | 
| 598 | 
            -
             | 
| 599 | 
            -
             | 
| 600 | 
            -
                     | 
| 601 | 
            -
                     | 
| 602 | 
            -
                     | 
| 603 | 
            -
             | 
| 604 | 
            -
                       | 
| 605 | 
            -
             | 
| 606 | 
            -
                        
         | 
| 607 | 
            -
             | 
| 608 | 
            -
             | 
| 609 | 
            -
                           | 
| 610 | 
            -
                           | 
| 611 | 
            -
             | 
| 612 | 
            -
                                     | 
| 613 | 
            -
             | 
| 614 | 
            -
             | 
| 615 | 
            -
             | 
| 616 | 
            -
                           | 
| 617 | 
            -
             | 
| 618 | 
            -
             | 
| 619 | 
            -
                           | 
| 620 | 
            -
             | 
| 621 | 
            -
             | 
| 622 | 
            -
                           | 
| 623 | 
            -
             | 
| 624 | 
            -
                                   | 
| 625 | 
            -
             | 
| 626 | 
            -
             | 
| 627 | 
            -
             | 
| 628 | 
            -
             | 
| 629 | 
            -
                                  ' | 
| 630 | 
            -
                                  ' | 
| 631 | 
            -
                                  ' | 
| 632 | 
            -
                                  ' | 
| 633 | 
            -
                                  ' | 
| 634 | 
            -
                                  ' | 
| 635 | 
            -
             | 
| 636 | 
            -
                               | 
| 637 | 
            -
             | 
| 638 | 
            -
             | 
| 639 | 
            -
             | 
| 640 | 
            -
                              
         | 
| 641 | 
            -
             | 
| 642 | 
            -
                          
         | 
| 643 | 
            -
             | 
| 644 | 
            -
             | 
| 645 | 
            -
                               | 
| 646 | 
            -
             | 
| 647 | 
            -
             | 
| 648 | 
            -
                                       | 
| 649 | 
            -
             | 
| 650 | 
            -
             | 
| 651 | 
            -
             | 
| 652 | 
            -
             | 
| 653 | 
            -
             | 
| 654 | 
            -
                                          ' | 
| 655 | 
            -
             | 
| 656 | 
            -
             | 
| 657 | 
            -
             | 
| 658 | 
            -
                                  
         | 
| 659 | 
            -
                                  
         | 
| 660 | 
            -
             | 
| 661 | 
            -
                                       | 
| 662 | 
            -
                                       | 
| 663 | 
            -
             | 
| 664 | 
            -
             | 
| 665 | 
            -
             | 
| 666 | 
            -
             | 
| 667 | 
            -
             | 
| 668 | 
            -
             | 
| 669 | 
            -
             | 
| 670 | 
            -
                              
         | 
| 671 | 
            -
             | 
| 672 | 
            -
                                 
         | 
| 673 | 
            -
             | 
| 674 | 
            -
             | 
| 675 | 
            -
             | 
| 676 | 
            -
             | 
| 677 | 
            -
                         | 
| 678 | 
            -
             | 
| 679 | 
            -
             | 
| 680 | 
            -
                   
         | 
| 681 | 
            -
             | 
| 682 | 
            -
                    models. | 
| 683 | 
            -
                     | 
| 684 | 
            -
             | 
| 685 | 
            -
                         | 
| 686 | 
            -
             | 
| 687 | 
            -
             | 
| 688 | 
            -
                         | 
| 689 | 
            -
                         | 
| 690 | 
            -
             | 
| 691 | 
            -
             | 
| 692 | 
            -
                         | 
| 693 | 
            -
             | 
| 694 | 
            -
             | 
| 695 | 
            -
                        
         | 
| 696 | 
            -
             | 
| 697 | 
            -
                            
         | 
| 698 | 
            -
             | 
| 699 | 
            -
                                
         | 
| 700 | 
            -
             | 
| 701 | 
            -
                                     | 
| 702 | 
            -
                                     | 
| 703 | 
            -
             | 
| 704 | 
            -
             | 
| 705 | 
            -
             | 
| 706 | 
            -
             | 
| 707 | 
            -
                        
         | 
| 708 | 
            -
                        
         | 
| 709 | 
            -
             | 
| 710 | 
            -
                             | 
| 711 | 
            -
             | 
| 712 | 
            -
             | 
| 713 | 
            -
                            
         | 
| 714 | 
            -
             | 
| 715 | 
            -
             | 
| 716 | 
            -
             | 
| 717 | 
            -
             | 
| 718 | 
            -
             | 
| 719 | 
            -
             | 
| 720 | 
            -
                            
         | 
| 721 | 
            -
             | 
| 722 | 
            -
             | 
| 723 | 
            -
             | 
| 724 | 
            -
                            images  | 
| 725 | 
            -
             | 
| 726 | 
            -
                                 | 
| 727 | 
            -
             | 
| 728 | 
            -
             | 
| 729 | 
            -
             | 
| 730 | 
            -
                                torch.cat([i for i in  | 
| 731 | 
            -
                                torch.cat([i for i in  | 
| 732 | 
            -
             | 
| 733 | 
            -
                             | 
| 734 | 
            -
                            
         | 
| 735 | 
            -
             | 
| 736 | 
            -
                                torch.cat([i for i in  | 
| 737 | 
            -
                                torch.cat([i for i in  | 
| 738 | 
            -
             | 
| 739 | 
            -
             | 
| 740 | 
            -
             | 
| 741 | 
            -
                            torchvision.utils.save_image( | 
| 742 | 
            -
             | 
| 743 | 
            -
             | 
| 744 | 
            -
             | 
| 745 | 
            -
                        models. | 
| 746 | 
            -
                         | 
| 747 | 
            -
             | 
| 748 | 
            -
                
         | 
| 749 | 
            -
                
         | 
| 750 | 
            -
                
         | 
| 751 | 
            -
             | 
| 752 | 
            -
             | 
| 753 | 
            -
             | 
| 754 | 
            -
                     | 
| 755 | 
            -
                    
         | 
| 756 | 
            -
             | 
| 757 | 
            -
             | 
| 758 | 
            -
             | 
| 759 | 
            -
                             | 
| 760 | 
            -
             | 
| 761 | 
            -
             | 
| 762 | 
            -
             | 
| 763 | 
            -
             | 
| 764 | 
            -
             | 
| 765 | 
            -
                                     | 
| 766 | 
            -
                                     | 
| 767 | 
            -
             | 
| 768 | 
            -
             | 
| 769 | 
            -
             | 
| 770 | 
            -
             | 
| 771 | 
            -
                                    
         | 
| 772 | 
            -
             | 
| 773 | 
            -
                                         | 
| 774 | 
            -
                                         | 
| 775 | 
            -
             | 
| 776 | 
            -
             | 
| 777 | 
            -
             | 
| 778 | 
            -
             | 
| 779 | 
            -
                                     | 
| 780 | 
            -
             | 
| 781 | 
            -
             | 
| 782 | 
            -
             | 
| 783 | 
            -
             | 
| 784 | 
            -
             | 
| 785 | 
            -
             | 
| 786 | 
            -
             | 
| 787 | 
            -
             | 
| 788 | 
            -
                 | 
| 789 | 
            -
             | 
| 790 | 
            -
             | 
| 791 | 
            -
                 | 
| 792 | 
            -
             | 
| 793 | 
            -
             | 
| 794 | 
            -
             | 
| 795 | 
            -
                 | 
| 796 | 
            -
                # os.environ[" | 
| 797 | 
            -
                # | 
| 798 | 
            -
                # | 
| 799 | 
            -
             | 
| 800 | 
            -
            # | 
| 801 | 
            -
             | 
| 802 | 
            -
             | 
| 803 | 
            -
             | 
| 804 | 
            -
                     | 
| 805 | 
            -
             | 
| 806 | 
            -
             | 
| 807 | 
            -
                    main_worker(0, sys.argv[1] if len(sys.argv) > 1 else None, )
         | 
|  | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import json
         | 
| 3 | 
            +
            import yaml
         | 
| 4 | 
            +
            import torchvision
         | 
| 5 | 
            +
            from torch import nn, optim
         | 
| 6 | 
            +
            from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPVisionModelWithProjection
         | 
| 7 | 
            +
            from warmup_scheduler import GradualWarmupScheduler
         | 
| 8 | 
            +
            import torch.multiprocessing as mp
         | 
| 9 | 
            +
            import numpy as np
         | 
| 10 | 
            +
            import os
         | 
| 11 | 
            +
            import sys
         | 
| 12 | 
            +
            sys.path.append(os.path.abspath('./'))
         | 
| 13 | 
            +
            from dataclasses import dataclass
         | 
| 14 | 
            +
            from torch.distributed import init_process_group, destroy_process_group, barrier
         | 
| 15 | 
            +
            from gdf import GDF_dual_fixlrt as GDF
         | 
| 16 | 
            +
            from gdf import EpsilonTarget, CosineSchedule
         | 
| 17 | 
            +
            from gdf import VPScaler, CosineTNoiseCond, DDPMSampler, P2LossWeight, AdaptiveLossWeight
         | 
| 18 | 
            +
            from torchtools.transforms import SmartCrop
         | 
| 19 | 
            +
            from fractions import Fraction
         | 
| 20 | 
            +
            from modules.effnet import EfficientNetEncoder
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            from modules.model_4stage_lite import StageC, ResBlock, AttnBlock, TimestepBlock, FeedForwardBlock
         | 
| 23 | 
            +
            from modules.previewer import Previewer
         | 
| 24 | 
            +
            from core.data import Bucketeer
         | 
| 25 | 
            +
            from train.base import DataCore, TrainingCore
         | 
| 26 | 
            +
            from tqdm import tqdm
         | 
| 27 | 
            +
            from core import WarpCore
         | 
| 28 | 
            +
            from core.utils import EXPECTED, EXPECTED_TRAIN, load_or_fail
         | 
| 29 | 
            +
             | 
| 30 | 
            +
            from accelerate import init_empty_weights
         | 
| 31 | 
            +
            from accelerate.utils import set_module_tensor_to_device
         | 
| 32 | 
            +
            from contextlib import contextmanager
         | 
| 33 | 
            +
            from train.dist_core import *
         | 
| 34 | 
            +
            import glob
         | 
| 35 | 
            +
            from torch.utils.data import DataLoader, Dataset
         | 
| 36 | 
            +
            from torch.nn.parallel import DistributedDataParallel as DDP
         | 
| 37 | 
            +
            from torch.utils.data.distributed import DistributedSampler
         | 
| 38 | 
            +
            from PIL import Image
         | 
| 39 | 
            +
            from core.utils import EXPECTED, EXPECTED_TRAIN, update_weights_ema, create_folder_if_necessary
         | 
| 40 | 
            +
            from core.utils import Base
         | 
| 41 | 
            +
            from modules.common_ckpt import LayerNorm2d, GlobalResponseNorm
         | 
| 42 | 
            +
            import torch.nn.functional as F
         | 
| 43 | 
            +
            import functools
         | 
| 44 | 
            +
            import math
         | 
| 45 | 
            +
            import copy
         | 
| 46 | 
            +
            import random
         | 
| 47 | 
            +
            from modules.lora import apply_lora, apply_retoken, LoRA, ReToken
         | 
| 48 | 
            +
            Image.MAX_IMAGE_PIXELS = None
         | 
| 49 | 
            +
            torch.manual_seed(23)
         | 
| 50 | 
            +
            random.seed(23)
         | 
| 51 | 
            +
            np.random.seed(23)
         | 
| 52 | 
            +
            #7978026
         | 
| 53 | 
            +
             | 
| 54 | 
            +
            class Null_Model(torch.nn.Module):
         | 
| 55 | 
            +
                def __init__(self):
         | 
| 56 | 
            +
                    super().__init__()
         | 
| 57 | 
            +
                def forward(self, x):
         | 
| 58 | 
            +
                    pass
         | 
| 59 | 
            +
             | 
| 60 | 
            +
             | 
| 61 | 
            +
             | 
| 62 | 
            +
             | 
| 63 | 
            +
            def identity(x):
         | 
| 64 | 
            +
                if isinstance(x, bytes):
         | 
| 65 | 
            +
                    x = x.decode('utf-8')
         | 
| 66 | 
            +
                return x
         | 
| 67 | 
            +
            def check_nan_inmodel(model, meta=''):
         | 
| 68 | 
            +
                    for name, param in model.named_parameters():
         | 
| 69 | 
            +
                        if torch.isnan(param).any():
         | 
| 70 | 
            +
                            print(f"nan detected in {name}", meta)
         | 
| 71 | 
            +
                            return True
         | 
| 72 | 
            +
                    print('no nan', meta)
         | 
| 73 | 
            +
                    return False  
         | 
| 74 | 
            +
            class mydist_dataset(Dataset):
         | 
| 75 | 
            +
                def __init__(self, rootpath, img_processor=None):
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                    self.img_pathlist = glob.glob(os.path.join(rootpath, '*', '*.jpg'))
         | 
| 78 | 
            +
                    self.img_processor = img_processor
         | 
| 79 | 
            +
                    self.length = len( self.img_pathlist)
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                  
         | 
| 82 | 
            +
                  
         | 
| 83 | 
            +
                def __getitem__(self, idx):
         | 
| 84 | 
            +
                    
         | 
| 85 | 
            +
                    imgpath = self.img_pathlist[idx]
         | 
| 86 | 
            +
                    json_file = imgpath.replace('.jpg', '.json') 
         | 
| 87 | 
            +
                   
         | 
| 88 | 
            +
                    with open(json_file, 'r') as file:
         | 
| 89 | 
            +
                        info = json.load(file)
         | 
| 90 | 
            +
                    txt = info['caption']
         | 
| 91 | 
            +
                    if txt is None:
         | 
| 92 | 
            +
                        txt = ' ' 
         | 
| 93 | 
            +
                    try:  
         | 
| 94 | 
            +
                      img = Image.open(imgpath).convert('RGB')
         | 
| 95 | 
            +
                      w, h = img.size
         | 
| 96 | 
            +
                      if self.img_processor is not None:
         | 
| 97 | 
            +
                        img = self.img_processor(img)
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                    except:
         | 
| 100 | 
            +
                      print('exception', imgpath)
         | 
| 101 | 
            +
                      return self.__getitem__(random.randint(0, self.length -1 ) )
         | 
| 102 | 
            +
                    return dict(captions=txt, images=img)
         | 
| 103 | 
            +
                def __len__(self):
         | 
| 104 | 
            +
                    return self.length
         | 
| 105 | 
            +
             | 
| 106 | 
            +
            class WurstCore(TrainingCore, DataCore, WarpCore):
         | 
| 107 | 
            +
                @dataclass(frozen=True)
         | 
| 108 | 
            +
                class Config(TrainingCore.Config, DataCore.Config, WarpCore.Config):
         | 
| 109 | 
            +
                    # TRAINING PARAMS
         | 
| 110 | 
            +
                    lr: float = EXPECTED_TRAIN
         | 
| 111 | 
            +
                    warmup_updates: int = EXPECTED_TRAIN
         | 
| 112 | 
            +
                    dtype: str = None
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                    # MODEL VERSION
         | 
| 115 | 
            +
                    model_version: str = EXPECTED  # 3.6B or 1B
         | 
| 116 | 
            +
                    clip_image_model_name: str = 'openai/clip-vit-large-patch14'
         | 
| 117 | 
            +
                    clip_text_model_name: str = 'laion/CLIP-ViT-bigG-14-laion2B-39B-b160k'
         | 
| 118 | 
            +
                 
         | 
| 119 | 
            +
                    # CHECKPOINT PATHS
         | 
| 120 | 
            +
                    effnet_checkpoint_path: str = EXPECTED
         | 
| 121 | 
            +
                    previewer_checkpoint_path: str = EXPECTED
         | 
| 122 | 
            +
                   
         | 
| 123 | 
            +
                    generator_checkpoint_path: str = None
         | 
| 124 | 
            +
             | 
| 125 | 
            +
                    # gdf customization
         | 
| 126 | 
            +
                    adaptive_loss_weight: str = None
         | 
| 127 | 
            +
                    use_ddp: bool=EXPECTED
         | 
| 128 | 
            +
                   
         | 
| 129 | 
            +
                   
         | 
| 130 | 
            +
                @dataclass(frozen=True)
         | 
| 131 | 
            +
                class Data(Base):
         | 
| 132 | 
            +
                    dataset: Dataset = EXPECTED
         | 
| 133 | 
            +
                    dataloader: DataLoader  = EXPECTED
         | 
| 134 | 
            +
                    iterator: any = EXPECTED
         | 
| 135 | 
            +
                    sampler: DistributedSampler = EXPECTED
         | 
| 136 | 
            +
             | 
| 137 | 
            +
                @dataclass(frozen=True)
         | 
| 138 | 
            +
                class Models(TrainingCore.Models, DataCore.Models, WarpCore.Models):
         | 
| 139 | 
            +
                    effnet: nn.Module = EXPECTED
         | 
| 140 | 
            +
                    previewer: nn.Module = EXPECTED
         | 
| 141 | 
            +
                    train_norm: nn.Module = EXPECTED
         | 
| 142 | 
            +
                   
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                @dataclass(frozen=True)
         | 
| 145 | 
            +
                class Schedulers(WarpCore.Schedulers):
         | 
| 146 | 
            +
                    generator: any = None
         | 
| 147 | 
            +
             | 
| 148 | 
            +
                @dataclass(frozen=True)
         | 
| 149 | 
            +
                class Extras(TrainingCore.Extras, DataCore.Extras, WarpCore.Extras):
         | 
| 150 | 
            +
                    gdf: GDF = EXPECTED
         | 
| 151 | 
            +
                    sampling_configs: dict = EXPECTED
         | 
| 152 | 
            +
                    effnet_preprocess: torchvision.transforms.Compose = EXPECTED
         | 
| 153 | 
            +
             | 
| 154 | 
            +
                info: TrainingCore.Info
         | 
| 155 | 
            +
                config: Config
         | 
| 156 | 
            +
             | 
| 157 | 
            +
                def setup_extras_pre(self) -> Extras:
         | 
| 158 | 
            +
                    gdf = GDF(
         | 
| 159 | 
            +
                        schedule=CosineSchedule(clamp_range=[0.0001, 0.9999]),
         | 
| 160 | 
            +
                        input_scaler=VPScaler(), target=EpsilonTarget(),
         | 
| 161 | 
            +
                        noise_cond=CosineTNoiseCond(),
         | 
| 162 | 
            +
                        loss_weight=AdaptiveLossWeight() if self.config.adaptive_loss_weight is True else P2LossWeight(),
         | 
| 163 | 
            +
                    )
         | 
| 164 | 
            +
                    sampling_configs = {"cfg": 5, "sampler": DDPMSampler(gdf), "shift": 1, "timesteps": 20}
         | 
| 165 | 
            +
             | 
| 166 | 
            +
                    if self.info.adaptive_loss is not None:
         | 
| 167 | 
            +
                        gdf.loss_weight.bucket_ranges = torch.tensor(self.info.adaptive_loss['bucket_ranges'])
         | 
| 168 | 
            +
                        gdf.loss_weight.bucket_losses = torch.tensor(self.info.adaptive_loss['bucket_losses'])
         | 
| 169 | 
            +
             | 
| 170 | 
            +
                    effnet_preprocess = torchvision.transforms.Compose([
         | 
| 171 | 
            +
                        torchvision.transforms.Normalize(
         | 
| 172 | 
            +
                            mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
         | 
| 173 | 
            +
                        )
         | 
| 174 | 
            +
                    ])
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                    clip_preprocess = torchvision.transforms.Compose([
         | 
| 177 | 
            +
                        torchvision.transforms.Resize(224, interpolation=torchvision.transforms.InterpolationMode.BICUBIC),
         | 
| 178 | 
            +
                        torchvision.transforms.CenterCrop(224),
         | 
| 179 | 
            +
                        torchvision.transforms.Normalize(
         | 
| 180 | 
            +
                            mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)
         | 
| 181 | 
            +
                        )
         | 
| 182 | 
            +
                    ])
         | 
| 183 | 
            +
             | 
| 184 | 
            +
                    if self.config.training:
         | 
| 185 | 
            +
                        transforms = torchvision.transforms.Compose([
         | 
| 186 | 
            +
                            torchvision.transforms.ToTensor(),
         | 
| 187 | 
            +
                            torchvision.transforms.Resize(self.config.image_size[-1], interpolation=torchvision.transforms.InterpolationMode.BILINEAR, antialias=True),
         | 
| 188 | 
            +
                            SmartCrop(self.config.image_size, randomize_p=0.3, randomize_q=0.2)
         | 
| 189 | 
            +
                        ])
         | 
| 190 | 
            +
                    else:
         | 
| 191 | 
            +
                        transforms = None
         | 
| 192 | 
            +
             | 
| 193 | 
            +
                    return self.Extras(
         | 
| 194 | 
            +
                        gdf=gdf,
         | 
| 195 | 
            +
                        sampling_configs=sampling_configs,
         | 
| 196 | 
            +
                        transforms=transforms,
         | 
| 197 | 
            +
                        effnet_preprocess=effnet_preprocess,
         | 
| 198 | 
            +
                        clip_preprocess=clip_preprocess
         | 
| 199 | 
            +
                    )
         | 
| 200 | 
            +
             | 
| 201 | 
            +
                def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False,
         | 
| 202 | 
            +
                                   eval_image_embeds=False, return_fields=None):
         | 
| 203 | 
            +
                    conditions = super().get_conditions(
         | 
| 204 | 
            +
                        batch, models, extras, is_eval, is_unconditional,
         | 
| 205 | 
            +
                        eval_image_embeds, return_fields=return_fields or ['clip_text', 'clip_text_pooled', 'clip_img']
         | 
| 206 | 
            +
                    )
         | 
| 207 | 
            +
                    return conditions
         | 
| 208 | 
            +
             | 
| 209 | 
            +
                def setup_models(self, extras: Extras) -> Models:   # configure model
         | 
| 210 | 
            +
             | 
| 211 | 
            +
                    dtype = getattr(torch, self.config.dtype) if self.config.dtype else torch.bfloat16
         | 
| 212 | 
            +
             | 
| 213 | 
            +
                    # EfficientNet encoderin
         | 
| 214 | 
            +
                    effnet = EfficientNetEncoder()
         | 
| 215 | 
            +
                    effnet_checkpoint = load_or_fail(self.config.effnet_checkpoint_path)
         | 
| 216 | 
            +
                    effnet.load_state_dict(effnet_checkpoint if 'state_dict' not in effnet_checkpoint else effnet_checkpoint['state_dict'])
         | 
| 217 | 
            +
                    effnet.eval().requires_grad_(False).to(self.device)
         | 
| 218 | 
            +
                    del effnet_checkpoint
         | 
| 219 | 
            +
             | 
| 220 | 
            +
                    # Previewer
         | 
| 221 | 
            +
                    previewer = Previewer()
         | 
| 222 | 
            +
                    previewer_checkpoint = load_or_fail(self.config.previewer_checkpoint_path)
         | 
| 223 | 
            +
                    previewer.load_state_dict(previewer_checkpoint if 'state_dict' not in previewer_checkpoint else previewer_checkpoint['state_dict'])
         | 
| 224 | 
            +
                    previewer.eval().requires_grad_(False).to(self.device)
         | 
| 225 | 
            +
                    del previewer_checkpoint
         | 
| 226 | 
            +
             | 
| 227 | 
            +
                    @contextmanager
         | 
| 228 | 
            +
                    def dummy_context():
         | 
| 229 | 
            +
                        yield None
         | 
| 230 | 
            +
             | 
| 231 | 
            +
                    loading_context = dummy_context if self.config.training else init_empty_weights
         | 
| 232 | 
            +
             | 
| 233 | 
            +
                    # Diffusion models
         | 
| 234 | 
            +
                    with loading_context():
         | 
| 235 | 
            +
                        generator_ema = None
         | 
| 236 | 
            +
                        if self.config.model_version == '3.6B':
         | 
| 237 | 
            +
                            generator = StageC()
         | 
| 238 | 
            +
                            if self.config.ema_start_iters is not None:  # default setting
         | 
| 239 | 
            +
                                generator_ema = StageC()
         | 
| 240 | 
            +
                        elif self.config.model_version == '1B':
         | 
| 241 | 
            +
                            print('in line 155 1b light model', self.config.model_version )
         | 
| 242 | 
            +
                            generator = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]])
         | 
| 243 | 
            +
                            
         | 
| 244 | 
            +
                            if self.config.ema_start_iters is not None and self.config.training:
         | 
| 245 | 
            +
                                generator_ema = StageC(c_cond=1536, c_hidden=[1536, 1536], nhead=[24, 24], blocks=[[4, 12], [12, 4]])
         | 
| 246 | 
            +
                        else:
         | 
| 247 | 
            +
                            raise ValueError(f"Unknown model version {self.config.model_version}")
         | 
| 248 | 
            +
             | 
| 249 | 
            +
                    
         | 
| 250 | 
            +
                 
         | 
| 251 | 
            +
                    if loading_context is dummy_context:
         | 
| 252 | 
            +
                        generator.load_state_dict( load_or_fail(self.config.generator_checkpoint_path))
         | 
| 253 | 
            +
                    else:
         | 
| 254 | 
            +
                        for param_name, param in load_or_fail(self.config.generator_checkpoint_path).items():
         | 
| 255 | 
            +
                                set_module_tensor_to_device(generator, param_name, "cpu", value=param)
         | 
| 256 | 
            +
             | 
| 257 | 
            +
                    generator._init_extra_parameter()
         | 
| 258 | 
            +
                    generator = generator.to(torch.bfloat16).to(self.device)
         | 
| 259 | 
            +
                   
         | 
| 260 | 
            +
                    
         | 
| 261 | 
            +
                    train_norm = nn.ModuleList()
         | 
| 262 | 
            +
                    cnt_norm = 0
         | 
| 263 | 
            +
                    for mm in generator.modules():
         | 
| 264 | 
            +
                        if isinstance(mm,  GlobalResponseNorm):
         | 
| 265 | 
            +
                           
         | 
| 266 | 
            +
                            train_norm.append(Null_Model())
         | 
| 267 | 
            +
                            cnt_norm += 1
         | 
| 268 | 
            +
                         
         | 
| 269 | 
            +
                    train_norm.append(generator.agg_net)
         | 
| 270 | 
            +
                    train_norm.append(generator.agg_net_up)      
         | 
| 271 | 
            +
                    total = sum([ param.nelement()  for param in train_norm.parameters()])
         | 
| 272 | 
            +
                    print('Trainable parameter', total / 1048576)
         | 
| 273 | 
            +
                    
         | 
| 274 | 
            +
                    if os.path.exists(os.path.join(self.config.output_path, self.config.experiment_id, 'train_norm.safetensors')):
         | 
| 275 | 
            +
                        sdd = torch.load(os.path.join(self.config.output_path, self.config.experiment_id, 'train_norm.safetensors'), map_location='cpu')
         | 
| 276 | 
            +
                        collect_sd = {}
         | 
| 277 | 
            +
                        for k, v in sdd.items():
         | 
| 278 | 
            +
                            collect_sd[k[7:]] = v
         | 
| 279 | 
            +
                        train_norm.load_state_dict(collect_sd, strict=True)
         | 
| 280 | 
            +
                    
         | 
| 281 | 
            +
                   
         | 
| 282 | 
            +
                    train_norm.to(self.device).train().requires_grad_(True)
         | 
| 283 | 
            +
                    
         | 
| 284 | 
            +
                    if generator_ema is not None:
         | 
| 285 | 
            +
                        
         | 
| 286 | 
            +
                        generator_ema.load_state_dict(load_or_fail(self.config.generator_checkpoint_path))
         | 
| 287 | 
            +
                        generator_ema._init_extra_parameter()
         | 
| 288 | 
            +
             | 
| 289 | 
            +
                      
         | 
| 290 | 
            +
                        pretrained_pth = os.path.join(self.config.output_path, self.config.experiment_id, 'generator.safetensors')
         | 
| 291 | 
            +
                        if os.path.exists(pretrained_pth):
         | 
| 292 | 
            +
                          print(pretrained_pth, 'exists')
         | 
| 293 | 
            +
                          generator_ema.load_state_dict(torch.load(pretrained_pth, map_location='cpu'))
         | 
| 294 | 
            +
                      
         | 
| 295 | 
            +
                       
         | 
| 296 | 
            +
                        generator_ema.eval().requires_grad_(False)
         | 
| 297 | 
            +
                  
         | 
| 298 | 
            +
                     
         | 
| 299 | 
            +
                        
         | 
| 300 | 
            +
                    
         | 
| 301 | 
            +
                    check_nan_inmodel(generator, 'generator')
         | 
| 302 | 
            +
                 
         | 
| 303 | 
            +
                    
         | 
| 304 | 
            +
                    
         | 
| 305 | 
            +
                    if self.config.use_ddp and self.config.training:
         | 
| 306 | 
            +
             | 
| 307 | 
            +
                        train_norm = DDP(train_norm, device_ids=[self.device], find_unused_parameters=True)
         | 
| 308 | 
            +
                        
         | 
| 309 | 
            +
                    # CLIP encoders     
         | 
| 310 | 
            +
                    tokenizer = AutoTokenizer.from_pretrained(self.config.clip_text_model_name)
         | 
| 311 | 
            +
                    text_model = CLIPTextModelWithProjection.from_pretrained( self.config.clip_text_model_name).requires_grad_(False).to(dtype).to(self.device)
         | 
| 312 | 
            +
                    image_model = CLIPVisionModelWithProjection.from_pretrained(self.config.clip_image_model_name).requires_grad_(False).to(dtype).to(self.device)
         | 
| 313 | 
            +
                    
         | 
| 314 | 
            +
                    return self.Models(
         | 
| 315 | 
            +
                        effnet=effnet, previewer=previewer, train_norm = train_norm,
         | 
| 316 | 
            +
                        generator=generator, tokenizer=tokenizer, text_model=text_model, image_model=image_model,
         | 
| 317 | 
            +
                    )
         | 
| 318 | 
            +
             | 
| 319 | 
            +
                def setup_optimizers(self, extras: Extras, models: Models) -> TrainingCore.Optimizers:
         | 
| 320 | 
            +
                    
         | 
| 321 | 
            +
             
         | 
| 322 | 
            +
                    params = []
         | 
| 323 | 
            +
                    params += list(models.train_norm.module.parameters())
         | 
| 324 | 
            +
                   
         | 
| 325 | 
            +
                    optimizer = optim.AdamW(params, lr=self.config.lr) 
         | 
| 326 | 
            +
             | 
| 327 | 
            +
                    return self.Optimizers(generator=optimizer)
         | 
| 328 | 
            +
             | 
| 329 | 
            +
                def ema_update(self, ema_model, source_model, beta):
         | 
| 330 | 
            +
                    for param_src, param_ema in zip(source_model.parameters(), ema_model.parameters()):
         | 
| 331 | 
            +
                        param_ema.data.mul_(beta).add_(param_src.data, alpha = 1 - beta)
         | 
| 332 | 
            +
                        
         | 
| 333 | 
            +
                def sync_ema(self, ema_model):
         | 
| 334 | 
            +
                    for param in ema_model.parameters():
         | 
| 335 | 
            +
                        torch.distributed.all_reduce(param.data, op=torch.distributed.ReduceOp.SUM)
         | 
| 336 | 
            +
                        param.data /= torch.distributed.get_world_size()
         | 
| 337 | 
            +
                def setup_optimizers_backup(self, extras: Extras, models: Models) -> TrainingCore.Optimizers:
         | 
| 338 | 
            +
                   
         | 
| 339 | 
            +
             | 
| 340 | 
            +
                    optimizer = optim.AdamW(
         | 
| 341 | 
            +
                        models.generator.up_blocks.parameters() , 
         | 
| 342 | 
            +
                    lr=self.config.lr)
         | 
| 343 | 
            +
                    optimizer = self.load_optimizer(optimizer, 'generator_optim',
         | 
| 344 | 
            +
                                                    fsdp_model=models.generator if self.config.use_fsdp else None)
         | 
| 345 | 
            +
                    return self.Optimizers(generator=optimizer)
         | 
| 346 | 
            +
             | 
| 347 | 
            +
                def setup_schedulers(self, extras: Extras, models: Models, optimizers: TrainingCore.Optimizers) -> Schedulers:
         | 
| 348 | 
            +
                    scheduler = GradualWarmupScheduler(optimizers.generator, multiplier=1, total_epoch=self.config.warmup_updates)
         | 
| 349 | 
            +
                    scheduler.last_epoch = self.info.total_steps
         | 
| 350 | 
            +
                    return self.Schedulers(generator=scheduler)
         | 
| 351 | 
            +
             | 
| 352 | 
            +
                def setup_data(self, extras: Extras) -> WarpCore.Data:
         | 
| 353 | 
            +
                    # SETUP DATASET
         | 
| 354 | 
            +
                    dataset_path = self.config.webdataset_path
         | 
| 355 | 
            +
                    dataset = mydist_dataset(dataset_path, \
         | 
| 356 | 
            +
                        torchvision.transforms.ToTensor() if self.config.multi_aspect_ratio is not None \
         | 
| 357 | 
            +
                            else extras.transforms)
         | 
| 358 | 
            +
             | 
| 359 | 
            +
                    # SETUP DATALOADER
         | 
| 360 | 
            +
                    real_batch_size = self.config.batch_size // (self.world_size * self.config.grad_accum_steps)
         | 
| 361 | 
            +
                   
         | 
| 362 | 
            +
                    sampler =  DistributedSampler(dataset, rank=self.process_id, num_replicas = self.world_size, shuffle=True)
         | 
| 363 | 
            +
                    dataloader = DataLoader(
         | 
| 364 | 
            +
                        dataset, batch_size=real_batch_size, num_workers=8, pin_memory=True,
         | 
| 365 | 
            +
                        collate_fn=identity if self.config.multi_aspect_ratio is not None else None,
         | 
| 366 | 
            +
                        sampler = sampler
         | 
| 367 | 
            +
                    )
         | 
| 368 | 
            +
                    if self.is_main_node:
         | 
| 369 | 
            +
                        print(f"Training with batch size {self.config.batch_size} ({real_batch_size}/GPU)")
         | 
| 370 | 
            +
             | 
| 371 | 
            +
                    if self.config.multi_aspect_ratio is not None:
         | 
| 372 | 
            +
                        aspect_ratios = [float(Fraction(f)) for f in self.config.multi_aspect_ratio]
         | 
| 373 | 
            +
                        dataloader_iterator = Bucketeer(dataloader, density=[ss*ss for ss in self.config.image_size] , factor=32,
         | 
| 374 | 
            +
                                                        ratios=aspect_ratios, p_random_ratio=self.config.bucketeer_random_ratio,
         | 
| 375 | 
            +
                                                        interpolate_nearest=False)  # , use_smartcrop=True)
         | 
| 376 | 
            +
                    else:
         | 
| 377 | 
            +
                       
         | 
| 378 | 
            +
                        dataloader_iterator = iter(dataloader)
         | 
| 379 | 
            +
             | 
| 380 | 
            +
                    return self.Data(dataset=dataset, dataloader=dataloader, iterator=dataloader_iterator, sampler=sampler)
         | 
| 381 | 
            +
             | 
| 382 | 
            +
             | 
| 383 | 
            +
                def  models_to_save(self):
         | 
| 384 | 
            +
                    pass
         | 
| 385 | 
            +
                def setup_ddp(self, experiment_id, single_gpu=False, rank=0):
         | 
| 386 | 
            +
             | 
| 387 | 
            +
                    if not single_gpu:
         | 
| 388 | 
            +
                        local_rank = rank
         | 
| 389 | 
            +
                        process_id = rank
         | 
| 390 | 
            +
                        world_size = get_world_size()
         | 
| 391 | 
            +
             | 
| 392 | 
            +
                        self.process_id = process_id
         | 
| 393 | 
            +
                        self.is_main_node = process_id == 0
         | 
| 394 | 
            +
                        self.device = torch.device(local_rank)
         | 
| 395 | 
            +
                        self.world_size = world_size
         | 
| 396 | 
            +
                      
         | 
| 397 | 
            +
                        os.environ['MASTER_ADDR'] = 'localhost'
         | 
| 398 | 
            +
                        os.environ['MASTER_PORT'] = '41443'
         | 
| 399 | 
            +
                        torch.cuda.set_device(local_rank)
         | 
| 400 | 
            +
                        init_process_group(
         | 
| 401 | 
            +
                            backend="nccl",
         | 
| 402 | 
            +
                            rank=local_rank,
         | 
| 403 | 
            +
                            world_size=world_size,
         | 
| 404 | 
            +
                        )
         | 
| 405 | 
            +
                        print(f"[GPU {process_id}] READY")
         | 
| 406 | 
            +
                    else:
         | 
| 407 | 
            +
                        self.is_main_node = rank == 0
         | 
| 408 | 
            +
                        self.process_id = rank
         | 
| 409 | 
            +
                        self.device = torch.device('cuda:0')
         | 
| 410 | 
            +
                        self.world_size = 1
         | 
| 411 | 
            +
                        print("Running in single thread, DDP not enabled.")
         | 
| 412 | 
            +
                # Training loop --------------------------------
         | 
| 413 | 
            +
                def get_target_lr_size(self, ratio, std_size=24):
         | 
| 414 | 
            +
                    w, h = int(std_size / math.sqrt(ratio)), int(std_size * math.sqrt(ratio)) 
         | 
| 415 | 
            +
                    return (h * 32 , w * 32) 
         | 
| 416 | 
            +
                def forward_pass(self, data: WarpCore.Data, extras: Extras, models: Models):
         | 
| 417 | 
            +
                    #batch = next(data.iterator)
         | 
| 418 | 
            +
                    batch = data
         | 
| 419 | 
            +
                    ratio = batch['images'].shape[-2] / batch['images'].shape[-1]
         | 
| 420 | 
            +
                    shape_lr = self.get_target_lr_size(ratio)
         | 
| 421 | 
            +
                    #print('in line 485', shape_lr, ratio, batch['images'].shape)
         | 
| 422 | 
            +
                    with torch.no_grad():
         | 
| 423 | 
            +
                        conditions = self.get_conditions(batch, models, extras)
         | 
| 424 | 
            +
                        
         | 
| 425 | 
            +
                        latents = self.encode_latents(batch, models, extras)
         | 
| 426 | 
            +
                        latents_lr = self.encode_latents(batch, models, extras,target_size=shape_lr)
         | 
| 427 | 
            +
                        
         | 
| 428 | 
            +
                        noised, noise, target, logSNR, noise_cond, loss_weight = extras.gdf.diffuse(latents, shift=1, loss_shift=1)
         | 
| 429 | 
            +
                        noised_lr, noise_lr, target_lr, logSNR_lr, noise_cond_lr, loss_weight_lr = extras.gdf.diffuse(latents_lr, shift=1, loss_shift=1, t=torch.ones(latents.shape[0]).to(latents.device)*0.05, )
         | 
| 430 | 
            +
             | 
| 431 | 
            +
                    with torch.cuda.amp.autocast(dtype=torch.bfloat16): 
         | 
| 432 | 
            +
                        # 768 1536
         | 
| 433 | 
            +
                        require_cond = True
         | 
| 434 | 
            +
                      
         | 
| 435 | 
            +
                        with torch.no_grad():
         | 
| 436 | 
            +
                            _, lr_enc_guide, lr_dec_guide = models.generator(noised_lr, noise_cond_lr, reuire_f=True, **conditions)
         | 
| 437 | 
            +
                        
         | 
| 438 | 
            +
                        
         | 
| 439 | 
            +
                        pred = models.generator(noised, noise_cond, reuire_f=False, lr_guide=(lr_enc_guide, lr_dec_guide) if require_cond else None , **conditions)             
         | 
| 440 | 
            +
                        loss = nn.functional.mse_loss(pred, target, reduction='none').mean(dim=[1, 2, 3]) 
         | 
| 441 | 
            +
                       
         | 
| 442 | 
            +
                        loss_adjusted = (loss * loss_weight ).mean() / self.config.grad_accum_steps 
         | 
| 443 | 
            +
                      
         | 
| 444 | 
            +
             | 
| 445 | 
            +
                    if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight):
         | 
| 446 | 
            +
                        extras.gdf.loss_weight.update_buckets(logSNR, loss)
         | 
| 447 | 
            +
             | 
| 448 | 
            +
                    return loss,  loss_adjusted
         | 
| 449 | 
            +
             | 
| 450 | 
            +
                def backward_pass(self, update, loss_adjusted, models: Models, optimizers: TrainingCore.Optimizers, schedulers: Schedulers):
         | 
| 451 | 
            +
                   
         | 
| 452 | 
            +
                    
         | 
| 453 | 
            +
                    if update:
         | 
| 454 | 
            +
                      
         | 
| 455 | 
            +
                        torch.distributed.barrier()
         | 
| 456 | 
            +
                        loss_adjusted.backward()
         | 
| 457 | 
            +
                        
         | 
| 458 | 
            +
                        grad_norm = nn.utils.clip_grad_norm_(models.train_norm.module.parameters(), 1.0)
         | 
| 459 | 
            +
                       
         | 
| 460 | 
            +
                        optimizers_dict = optimizers.to_dict()
         | 
| 461 | 
            +
                        for k in optimizers_dict:
         | 
| 462 | 
            +
                            if k != 'training':
         | 
| 463 | 
            +
                                optimizers_dict[k].step()
         | 
| 464 | 
            +
                        schedulers_dict = schedulers.to_dict()
         | 
| 465 | 
            +
                        for k in schedulers_dict:
         | 
| 466 | 
            +
                            if k != 'training':
         | 
| 467 | 
            +
                                schedulers_dict[k].step()
         | 
| 468 | 
            +
                        for k in optimizers_dict:
         | 
| 469 | 
            +
                            if k != 'training':
         | 
| 470 | 
            +
                                optimizers_dict[k].zero_grad(set_to_none=True)
         | 
| 471 | 
            +
                        self.info.total_steps += 1
         | 
| 472 | 
            +
                    else:
         | 
| 473 | 
            +
                       
         | 
| 474 | 
            +
                        loss_adjusted.backward()
         | 
| 475 | 
            +
                       
         | 
| 476 | 
            +
                        grad_norm = torch.tensor(0.0).to(self.device)
         | 
| 477 | 
            +
                    
         | 
| 478 | 
            +
                    return grad_norm
         | 
| 479 | 
            +
             | 
| 480 | 
            +
             | 
| 481 | 
            +
                def encode_latents(self, batch: dict, models: Models, extras: Extras, target_size=None) -> torch.Tensor:
         | 
| 482 | 
            +
                    
         | 
| 483 | 
            +
                    images = batch['images'].to(self.device)
         | 
| 484 | 
            +
                    if target_size is not None:
         | 
| 485 | 
            +
                      images = F.interpolate(images, target_size)
         | 
| 486 | 
            +
                      
         | 
| 487 | 
            +
                    return models.effnet(extras.effnet_preprocess(images))
         | 
| 488 | 
            +
             | 
| 489 | 
            +
                def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
         | 
| 490 | 
            +
                    return models.previewer(latents)
         | 
| 491 | 
            +
             | 
| 492 | 
            +
                def __init__(self, rank=0, config_file_path=None, config_dict=None, device="cpu", training=True, world_size=1, ):
         | 
| 493 | 
            +
             | 
| 494 | 
            +
                    self.is_main_node = (rank == 0)
         | 
| 495 | 
            +
                    self.config: self.Config = self.setup_config(config_file_path, config_dict, training)
         | 
| 496 | 
            +
                    self.setup_ddp(self.config.experiment_id, single_gpu=world_size <= 1, rank=rank)
         | 
| 497 | 
            +
                    self.info: self.Info = self.setup_info()
         | 
| 498 | 
            +
                    
         | 
| 499 | 
            +
                   
         | 
| 500 | 
            +
                    
         | 
| 501 | 
            +
                def __call__(self, single_gpu=False):
         | 
| 502 | 
            +
                    
         | 
| 503 | 
            +
                    if self.config.allow_tf32:
         | 
| 504 | 
            +
                        torch.backends.cuda.matmul.allow_tf32 = True
         | 
| 505 | 
            +
                        torch.backends.cudnn.allow_tf32 = True
         | 
| 506 | 
            +
             | 
| 507 | 
            +
                    if self.is_main_node:
         | 
| 508 | 
            +
                        print()
         | 
| 509 | 
            +
                        print("**STARTIG JOB WITH CONFIG:**")
         | 
| 510 | 
            +
                        print(yaml.dump(self.config.to_dict(), default_flow_style=False))
         | 
| 511 | 
            +
                        print("------------------------------------")
         | 
| 512 | 
            +
                        print()
         | 
| 513 | 
            +
                        print("**INFO:**")
         | 
| 514 | 
            +
                        print(yaml.dump(vars(self.info), default_flow_style=False))
         | 
| 515 | 
            +
                        print("------------------------------------")
         | 
| 516 | 
            +
                        print()
         | 
| 517 | 
            +
                    
         | 
| 518 | 
            +
                    # SETUP STUFF
         | 
| 519 | 
            +
                    extras = self.setup_extras_pre()
         | 
| 520 | 
            +
                    assert extras is not None, "setup_extras_pre() must return a DTO"
         | 
| 521 | 
            +
             | 
| 522 | 
            +
             | 
| 523 | 
            +
             | 
| 524 | 
            +
                    data = self.setup_data(extras)
         | 
| 525 | 
            +
                    assert data is not None, "setup_data() must return a DTO"
         | 
| 526 | 
            +
                    if self.is_main_node:
         | 
| 527 | 
            +
                        print("**DATA:**")
         | 
| 528 | 
            +
                        print(yaml.dump({k:type(v).__name__ for k, v in data.to_dict().items()}, default_flow_style=False))
         | 
| 529 | 
            +
                        print("------------------------------------")
         | 
| 530 | 
            +
                        print()
         | 
| 531 | 
            +
             | 
| 532 | 
            +
                    models = self.setup_models(extras)
         | 
| 533 | 
            +
                    assert models is not None, "setup_models() must return a DTO"
         | 
| 534 | 
            +
                    if self.is_main_node:
         | 
| 535 | 
            +
                        print("**MODELS:**")
         | 
| 536 | 
            +
                        print(yaml.dump({
         | 
| 537 | 
            +
                            k:f"{type(v).__name__} - {f'trainable params {sum(p.numel() for p in v.parameters() if p.requires_grad)}' if isinstance(v, nn.Module) else 'Not a nn.Module'}" for k, v in models.to_dict().items()
         | 
| 538 | 
            +
                        }, default_flow_style=False))
         | 
| 539 | 
            +
                        print("------------------------------------")
         | 
| 540 | 
            +
                        print()
         | 
| 541 | 
            +
             | 
| 542 | 
            +
             | 
| 543 | 
            +
             | 
| 544 | 
            +
                    optimizers = self.setup_optimizers(extras, models)
         | 
| 545 | 
            +
                    assert optimizers is not None, "setup_optimizers() must return a DTO"
         | 
| 546 | 
            +
                    if self.is_main_node:
         | 
| 547 | 
            +
                        print("**OPTIMIZERS:**")
         | 
| 548 | 
            +
                        print(yaml.dump({k:type(v).__name__ for k, v in optimizers.to_dict().items()}, default_flow_style=False))
         | 
| 549 | 
            +
                        print("------------------------------------")
         | 
| 550 | 
            +
                        print()
         | 
| 551 | 
            +
             | 
| 552 | 
            +
                    schedulers = self.setup_schedulers(extras, models, optimizers)
         | 
| 553 | 
            +
                    assert schedulers is not None, "setup_schedulers() must return a DTO"
         | 
| 554 | 
            +
                    if self.is_main_node:
         | 
| 555 | 
            +
                        print("**SCHEDULERS:**")
         | 
| 556 | 
            +
                        print(yaml.dump({k:type(v).__name__ for k, v in schedulers.to_dict().items()}, default_flow_style=False))
         | 
| 557 | 
            +
                        print("------------------------------------")
         | 
| 558 | 
            +
                        print()
         | 
| 559 | 
            +
             | 
| 560 | 
            +
                    post_extras =self.setup_extras_post(extras, models, optimizers, schedulers)
         | 
| 561 | 
            +
                    assert post_extras is not None, "setup_extras_post() must return a DTO"
         | 
| 562 | 
            +
                    extras = self.Extras.from_dict({ **extras.to_dict(),**post_extras.to_dict() })
         | 
| 563 | 
            +
                    if self.is_main_node:
         | 
| 564 | 
            +
                        print("**EXTRAS:**")
         | 
| 565 | 
            +
                        print(yaml.dump({k:f"{v}" for k, v in extras.to_dict().items()}, default_flow_style=False))
         | 
| 566 | 
            +
                        print("------------------------------------")
         | 
| 567 | 
            +
                        print()
         | 
| 568 | 
            +
                    # -------
         | 
| 569 | 
            +
             | 
| 570 | 
            +
                    # TRAIN
         | 
| 571 | 
            +
                    if self.is_main_node:
         | 
| 572 | 
            +
                        print("**TRAINING STARTING...**")
         | 
| 573 | 
            +
                    self.train(data, extras, models, optimizers, schedulers)
         | 
| 574 | 
            +
             | 
| 575 | 
            +
                    if single_gpu is False:
         | 
| 576 | 
            +
                        barrier()
         | 
| 577 | 
            +
                        destroy_process_group()
         | 
| 578 | 
            +
                    if self.is_main_node:
         | 
| 579 | 
            +
                        print()
         | 
| 580 | 
            +
                        print("------------------------------------")
         | 
| 581 | 
            +
                        print()
         | 
| 582 | 
            +
                        print("**TRAINING COMPLETE**")
         | 
| 583 | 
            +
                       
         | 
| 584 | 
            +
             | 
| 585 | 
            +
             | 
| 586 | 
            +
                def train(self, data: WarpCore.Data, extras: WarpCore.Extras, models: Models, optimizers: TrainingCore.Optimizers,
         | 
| 587 | 
            +
                          schedulers: WarpCore.Schedulers):
         | 
| 588 | 
            +
                    start_iter = self.info.iter + 1
         | 
| 589 | 
            +
                    max_iters = self.config.updates * self.config.grad_accum_steps
         | 
| 590 | 
            +
                    if self.is_main_node:
         | 
| 591 | 
            +
                        print(f"STARTING AT STEP: {start_iter}/{max_iters}")
         | 
| 592 | 
            +
             | 
| 593 | 
            +
                 
         | 
| 594 | 
            +
                    if self.is_main_node:
         | 
| 595 | 
            +
                        create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/')
         | 
| 596 | 
            +
                    
         | 
| 597 | 
            +
                    models.generator.train()
         | 
| 598 | 
            +
                 
         | 
| 599 | 
            +
                    iter_cnt = 0
         | 
| 600 | 
            +
                    epoch_cnt = 0
         | 
| 601 | 
            +
                    models.train_norm.train()
         | 
| 602 | 
            +
                    while True:
         | 
| 603 | 
            +
                      epoch_cnt += 1
         | 
| 604 | 
            +
                      if self.world_size > 1:
         | 
| 605 | 
            +
                        
         | 
| 606 | 
            +
                        data.sampler.set_epoch(epoch_cnt)  
         | 
| 607 | 
            +
                      for ggg in range(len(data.dataloader)):
         | 
| 608 | 
            +
                          iter_cnt += 1
         | 
| 609 | 
            +
                          loss, loss_adjusted = self.forward_pass(next(data.iterator), extras, models)
         | 
| 610 | 
            +
                          grad_norm = self.backward_pass(
         | 
| 611 | 
            +
                                    iter_cnt % self.config.grad_accum_steps == 0 or iter_cnt == max_iters, loss_adjusted,
         | 
| 612 | 
            +
                                    models, optimizers, schedulers
         | 
| 613 | 
            +
                                  )
         | 
| 614 | 
            +
             | 
| 615 | 
            +
                          self.info.iter = iter_cnt
         | 
| 616 | 
            +
                          
         | 
| 617 | 
            +
                         
         | 
| 618 | 
            +
                          # UPDATE LOSS METRICS
         | 
| 619 | 
            +
                          self.info.ema_loss = loss.mean().item() if self.info.ema_loss is None else self.info.ema_loss * 0.99 + loss.mean().item() * 0.01
         | 
| 620 | 
            +
              
         | 
| 621 | 
            +
                          #print('in line 666 after ema loss', grad_norm, loss.mean().item(), iter_cnt, self.info.ema_loss)
         | 
| 622 | 
            +
                          if self.is_main_node and  np.isnan(loss.mean().item()) or np.isnan(grad_norm.item()):
         | 
| 623 | 
            +
                                  print(f" NaN value encountered in training run {self.info.wandb_run_id}", \
         | 
| 624 | 
            +
                                  f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}")
         | 
| 625 | 
            +
              
         | 
| 626 | 
            +
                          if self.is_main_node:
         | 
| 627 | 
            +
                              logs = {
         | 
| 628 | 
            +
                                  'loss': self.info.ema_loss,
         | 
| 629 | 
            +
                                  'backward_loss': loss_adjusted.mean().item(),
         | 
| 630 | 
            +
                                  'ema_loss': self.info.ema_loss,
         | 
| 631 | 
            +
                                  'raw_ori_loss': loss.mean().item(),
         | 
| 632 | 
            +
                                  'grad_norm': grad_norm.item(),
         | 
| 633 | 
            +
                                  'lr': optimizers.generator.param_groups[0]['lr'] if optimizers.generator is not None else 0,
         | 
| 634 | 
            +
                                  'total_steps': self.info.total_steps,
         | 
| 635 | 
            +
                              }
         | 
| 636 | 
            +
                              if iter_cnt % (self.config.save_every) == 0:
         | 
| 637 | 
            +
                                    
         | 
| 638 | 
            +
                                  print(iter_cnt, max_iters, logs, epoch_cnt, )
         | 
| 639 | 
            +
                              
         | 
| 640 | 
            +
                              
         | 
| 641 | 
            +
                          
         | 
| 642 | 
            +
                          if iter_cnt == 1 or iter_cnt % (self.config.save_every  ) == 0 or iter_cnt == max_iters:
         | 
| 643 | 
            +
                         
         | 
| 644 | 
            +
                              # SAVE AND CHECKPOINT STUFF
         | 
| 645 | 
            +
                              if np.isnan(loss.mean().item()):
         | 
| 646 | 
            +
                                  if self.is_main_node and self.config.wandb_project is not None:
         | 
| 647 | 
            +
                                      print(f"NaN value encountered in training run {self.info.wandb_run_id}", \
         | 
| 648 | 
            +
                                      f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}")
         | 
| 649 | 
            +
                                 
         | 
| 650 | 
            +
                              else:
         | 
| 651 | 
            +
                                  if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight):
         | 
| 652 | 
            +
                                      self.info.adaptive_loss = {
         | 
| 653 | 
            +
                                          'bucket_ranges': extras.gdf.loss_weight.bucket_ranges.tolist(),
         | 
| 654 | 
            +
                                          'bucket_losses': extras.gdf.loss_weight.bucket_losses.tolist(),
         | 
| 655 | 
            +
                                      }
         | 
| 656 | 
            +
                                 
         | 
| 657 | 
            +
                                  
         | 
| 658 | 
            +
                                  
         | 
| 659 | 
            +
                                  if self.is_main_node and iter_cnt % (self.config.save_every * self.config.grad_accum_steps) == 0:
         | 
| 660 | 
            +
                                      print('save model', iter_cnt, iter_cnt % (self.config.save_every * self.config.grad_accum_steps), self.config.save_every, self.config.grad_accum_steps )
         | 
| 661 | 
            +
                                      torch.save(models.train_norm.state_dict(), \
         | 
| 662 | 
            +
                                      f'{self.config.output_path}/{self.config.experiment_id}/train_norm.safetensors')
         | 
| 663 | 
            +
             | 
| 664 | 
            +
                                      torch.save(models.train_norm.state_dict(), \
         | 
| 665 | 
            +
                                          f'{self.config.output_path}/{self.config.experiment_id}/train_norm_{iter_cnt}.safetensors')
         | 
| 666 | 
            +
                                      
         | 
| 667 | 
            +
                                   
         | 
| 668 | 
            +
                          if iter_cnt == 1 or iter_cnt % (self.config.save_every* self.config.grad_accum_steps) == 0 or iter_cnt == max_iters:
         | 
| 669 | 
            +
                              
         | 
| 670 | 
            +
                              if self.is_main_node:
         | 
| 671 | 
            +
                                 
         | 
| 672 | 
            +
                                 self.sample(models, data, extras)
         | 
| 673 | 
            +
                        
         | 
| 674 | 
            +
                     
         | 
| 675 | 
            +
                      if self.info.iter >= max_iters:
         | 
| 676 | 
            +
                        break
         | 
| 677 | 
            +
                        
         | 
| 678 | 
            +
                def sample(self, models: Models, data: WarpCore.Data, extras: Extras):
         | 
| 679 | 
            +
                   
         | 
| 680 | 
            +
                   
         | 
| 681 | 
            +
                    models.generator.eval()
         | 
| 682 | 
            +
                    models.train_norm.eval()
         | 
| 683 | 
            +
                    with torch.no_grad():
         | 
| 684 | 
            +
                        batch = next(data.iterator)
         | 
| 685 | 
            +
                        ratio = batch['images'].shape[-2] / batch['images'].shape[-1]
         | 
| 686 | 
            +
                       
         | 
| 687 | 
            +
                        shape_lr = self.get_target_lr_size(ratio)
         | 
| 688 | 
            +
                        conditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False)
         | 
| 689 | 
            +
                        unconditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
         | 
| 690 | 
            +
             | 
| 691 | 
            +
                        latents = self.encode_latents(batch, models, extras)
         | 
| 692 | 
            +
                        latents_lr = self.encode_latents(batch, models, extras, target_size = shape_lr)
         | 
| 693 | 
            +
                       
         | 
| 694 | 
            +
                        
         | 
| 695 | 
            +
                        if self.is_main_node:
         | 
| 696 | 
            +
                            
         | 
| 697 | 
            +
                            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
         | 
| 698 | 
            +
                                
         | 
| 699 | 
            +
                                *_, (sampled, _, _, sampled_lr) = extras.gdf.sample(
         | 
| 700 | 
            +
                                    models.generator, conditions,
         | 
| 701 | 
            +
                                    latents.shape, latents_lr.shape, 
         | 
| 702 | 
            +
                                    unconditions, device=self.device, **extras.sampling_configs
         | 
| 703 | 
            +
                                )
         | 
| 704 | 
            +
                
         | 
| 705 | 
            +
                               
         | 
| 706 | 
            +
                        
         | 
| 707 | 
            +
                        
         | 
| 708 | 
            +
                        if self.is_main_node:
         | 
| 709 | 
            +
                            print('sampling results hr latent shape', latents.shape, 'lr latent shape', latents_lr.shape, )
         | 
| 710 | 
            +
                            noised_images = torch.cat(
         | 
| 711 | 
            +
                                [self.decode_latents(latents[i:i + 1].float(), batch, models, extras) for i in range(len(latents))], dim=0)
         | 
| 712 | 
            +
                            
         | 
| 713 | 
            +
                            sampled_images = torch.cat(
         | 
| 714 | 
            +
                                [self.decode_latents(sampled[i:i + 1].float(), batch, models, extras) for i in range(len(sampled))], dim=0)
         | 
| 715 | 
            +
             | 
| 716 | 
            +
                                
         | 
| 717 | 
            +
                            noised_images_lr = torch.cat(
         | 
| 718 | 
            +
                                [self.decode_latents(latents_lr[i:i + 1].float(), batch, models, extras) for i in range(len(latents_lr))], dim=0)
         | 
| 719 | 
            +
                            
         | 
| 720 | 
            +
                            sampled_images_lr = torch.cat(
         | 
| 721 | 
            +
                                [self.decode_latents(sampled_lr[i:i + 1].float(), batch, models, extras) for i in range(len(sampled_lr))], dim=0)
         | 
| 722 | 
            +
             | 
| 723 | 
            +
                            images = batch['images']
         | 
| 724 | 
            +
                            if images.size(-1) != noised_images.size(-1) or images.size(-2) != noised_images.size(-2):
         | 
| 725 | 
            +
                                images = nn.functional.interpolate(images, size=noised_images.shape[-2:], mode='bicubic')
         | 
| 726 | 
            +
                                images_lr = nn.functional.interpolate(images, size=noised_images_lr.shape[-2:], mode='bicubic')
         | 
| 727 | 
            +
             | 
| 728 | 
            +
                            collage_img = torch.cat([
         | 
| 729 | 
            +
                                torch.cat([i for i in images.cpu()], dim=-1),
         | 
| 730 | 
            +
                                torch.cat([i for i in noised_images.cpu()], dim=-1),
         | 
| 731 | 
            +
                                torch.cat([i for i in sampled_images.cpu()], dim=-1),
         | 
| 732 | 
            +
                            ], dim=-2)
         | 
| 733 | 
            +
                            
         | 
| 734 | 
            +
                            collage_img_lr = torch.cat([
         | 
| 735 | 
            +
                                torch.cat([i for i in images_lr.cpu()], dim=-1),
         | 
| 736 | 
            +
                                torch.cat([i for i in noised_images_lr.cpu()], dim=-1),
         | 
| 737 | 
            +
                                torch.cat([i for i in sampled_images_lr.cpu()], dim=-1),
         | 
| 738 | 
            +
                            ], dim=-2)
         | 
| 739 | 
            +
             | 
| 740 | 
            +
                            torchvision.utils.save_image(collage_img, f'{self.config.output_path}/{self.config.experiment_id}/{self.info.total_steps:06d}.jpg')
         | 
| 741 | 
            +
                            torchvision.utils.save_image(collage_img_lr, f'{self.config.output_path}/{self.config.experiment_id}/{self.info.total_steps:06d}_lr.jpg')
         | 
| 742 | 
            +
                           
         | 
| 743 | 
            +
                       
         | 
| 744 | 
            +
                        models.generator.train()
         | 
| 745 | 
            +
                        models.train_norm.train()
         | 
| 746 | 
            +
                        print('finish sampling')
         | 
| 747 | 
            +
                
         | 
| 748 | 
            +
                
         | 
| 749 | 
            +
                
         | 
| 750 | 
            +
                def sample_fortest(self, models: Models, extras: Extras, hr_shape, lr_shape, batch, eval_image_embeds=False):
         | 
| 751 | 
            +
                   
         | 
| 752 | 
            +
                  
         | 
| 753 | 
            +
                    models.generator.eval()
         | 
| 754 | 
            +
                    
         | 
| 755 | 
            +
                    with torch.no_grad():
         | 
| 756 | 
            +
                       
         | 
| 757 | 
            +
                        if self.is_main_node:
         | 
| 758 | 
            +
                            conditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=eval_image_embeds)
         | 
| 759 | 
            +
                            unconditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
         | 
| 760 | 
            +
                           
         | 
| 761 | 
            +
                            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
         | 
| 762 | 
            +
                               
         | 
| 763 | 
            +
                                *_, (sampled, _, _, sampled_lr) = extras.gdf.sample(
         | 
| 764 | 
            +
                                    models.generator, conditions,
         | 
| 765 | 
            +
                                    hr_shape, lr_shape, 
         | 
| 766 | 
            +
                                    unconditions, device=self.device, **extras.sampling_configs
         | 
| 767 | 
            +
                                )
         | 
| 768 | 
            +
                
         | 
| 769 | 
            +
                                if models.generator_ema is not None:
         | 
| 770 | 
            +
                                    
         | 
| 771 | 
            +
                                    *_, (sampled_ema, _, _, sampled_ema_lr) = extras.gdf.sample(
         | 
| 772 | 
            +
                                        models.generator_ema,  conditions,
         | 
| 773 | 
            +
                                        latents.shape, latents_lr.shape, 
         | 
| 774 | 
            +
                                        unconditions, device=self.device, **extras.sampling_configs
         | 
| 775 | 
            +
                                    )
         | 
| 776 | 
            +
                                   
         | 
| 777 | 
            +
                                else:
         | 
| 778 | 
            +
                                    sampled_ema = sampled
         | 
| 779 | 
            +
                                    sampled_ema_lr = sampled_lr
         | 
| 780 | 
            +
             | 
| 781 | 
            +
                    return sampled, sampled_lr
         | 
| 782 | 
            +
            def main_worker(rank, cfg):
         | 
| 783 | 
            +
                print("Launching Script in main worker")
         | 
| 784 | 
            +
               
         | 
| 785 | 
            +
                warpcore = WurstCore(
         | 
| 786 | 
            +
                    config_file_path=cfg, rank=rank, world_size = get_world_size()
         | 
| 787 | 
            +
                )
         | 
| 788 | 
            +
                # core.fsdp_defaults['sharding_strategy'] = ShardingStrategy.NO_SHARD
         | 
| 789 | 
            +
             | 
| 790 | 
            +
                # RUN TRAINING
         | 
| 791 | 
            +
                warpcore(get_world_size()==1)
         | 
| 792 | 
            +
             | 
| 793 | 
            +
            if __name__ == '__main__':
         | 
| 794 | 
            +
                print('launch multi process')
         | 
| 795 | 
            +
                # os.environ["OMP_NUM_THREADS"] = "1" 
         | 
| 796 | 
            +
                # os.environ["MKL_NUM_THREADS"] = "1" 
         | 
| 797 | 
            +
                #dist.init_process_group(backend="nccl")
         | 
| 798 | 
            +
                #torch.backends.cudnn.benchmark = True
         | 
| 799 | 
            +
            #train/train_c_my.py
         | 
| 800 | 
            +
                #mp.set_sharing_strategy('file_system')
         | 
| 801 | 
            +
             | 
| 802 | 
            +
                if get_master_ip() == "127.0.0.1":
         | 
| 803 | 
            +
                    # manually launch distributed processes
         | 
| 804 | 
            +
                    mp.spawn(main_worker, nprocs=get_world_size(), args=(sys.argv[1] if len(sys.argv) > 1 else None, ))
         | 
| 805 | 
            +
                else:
         | 
| 806 | 
            +
                    main_worker(0, sys.argv[1] if len(sys.argv) > 1 else None, )
         | 
|  | 
