| """ Main training script """ | |
| import argparse | |
| import copy | |
| import glob | |
| import os | |
| import random | |
| import functools | |
| import numpy as np | |
| import torch | |
| # torch.multiprocessing.set_sharing_strategy('file_system') | |
| import wandb | |
| from data2 import get_data | |
| from distributed import init_distributed_device, world_info_from_env | |
| from torch.distributed.fsdp import ( | |
| FullyShardedDataParallel as FSDP, | |
| MixedPrecision, | |
| BackwardPrefetch, | |
| ShardingStrategy, | |
| FullStateDictConfig, | |
| CPUOffload, | |
| StateDictType, | |
| ) | |
| from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler | |
| from torch.distributed.fsdp.wrap import ( | |
| transformer_auto_wrap_policy, | |
| enable_wrap, | |
| wrap, | |
| ) | |
| from train_utils import train_one_epoch | |
| from transformers import ( | |
| get_constant_schedule_with_warmup, | |
| get_cosine_schedule_with_warmup, | |
| get_linear_schedule_with_warmup, | |
| ) | |
| from open_flamingo import create_model_and_transforms | |
| from torch.utils.tensorboard import SummaryWriter | |
| from torch.nn.parallel import DistributedDataParallel as DDP | |
| from torch.cuda.amp import GradScaler | |
| from torch.distributed.optim import ZeroRedundancyOptimizer | |
| import warnings | |
| warnings.filterwarnings("ignore") | |
| import logging | |
| logging.basicConfig( | |
| level=logging.INFO, | |
| format='%(asctime)s %(message)s', | |
| datefmt='%m/%d %I:%M:%S', | |
| ) | |
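# Placeholder used when no pile shards are provided: the training loop still
# receives an iterator for the "pile" stream, but every item it yields is None.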
class FakeDataloader:
    def __iter__(self):
        return self

    def __next__(self):
        return None


def random_seed(seed=42, rank=0):
    torch.manual_seed(seed + rank)
    np.random.seed(seed + rank)
    random.seed(seed + rank)


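# Split parameters into two optimizer groups: weight decay is applied only to
# parameters whose names do not contain norm/bn/bias/embed/wte/flat_param.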
def get_grouped_params(model, args):
    params_with_wd, params_without_wd = [], []

    def apply_decay(x):
        x = x.lower()
        return (
            "norm" not in x
            and "bn" not in x
            and "bias" not in x
            and "embed" not in x
            and "wte" not in x
            and "flat_param" not in x
        )

    for n, p in model.named_parameters():
        # if p.requires_grad:
        if apply_decay(n):
            if torch.distributed.get_rank() == 0:
                logging.info(f"with wd: {n}")
            params_with_wd.append(p)
        else:
            if torch.distributed.get_rank() == 0:
                logging.info(f"without wd: {n}")
            params_without_wd.append(p)
    return [
        {"params": params_with_wd, "weight_decay": args.weight_decay},
        {"params": params_without_wd, "weight_decay": 0.0},
    ]


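# Flag leaf modules that carry a trainable weight so they become their own FSDP
# units; this is combined with the transformer wrap policy in the LoRA branch of
# main() below, where most base parameters are frozen.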
def lambda_policy_fn(module):
    if (
        len(list(module.named_children())) == 0
        and getattr(module, "weight", None) is not None
        and module.weight.requires_grad
    ):
        return True
    return False


def lambda_auto_wrap_policy(
    module: torch.nn.Module, recurse: bool, nonwrapped_numel: int, lambda_fn,
) -> bool:
    """
    A convenient auto wrap policy to wrap submodules based on an arbitrary user
    function. If ``lambda_fn(submodule) == True``, the submodule will be wrapped
    as a ``wrapper_cls`` unit.

    Return if a module should be wrapped during auto wrapping.

    The first three parameters are required by :func:`_recursive_wrap`.

    Args:
        module (nn.Module): Current module being considered.
        recurse (bool): If ``False``, then this function must decide whether
            ``module`` should be wrapped as an FSDP instance or not. If
            ``True``, then the function is still recursing down the module
            tree as a part of the DFS.
        nonwrapped_numel (int): Parameter numel not yet wrapped.
        lambda_fn (Callable[[nn.Module], bool]): If this returns ``True``, then
            this module will be wrapped.
    """
    if recurse:
        return True  # always recurse
    return lambda_fn(module)


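# Illustrative sketch only (it mirrors the LoRA branch in main() below): this
# policy is composed with the transformer policy via torch's `_or_policy`, roughly
#
#     lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
#     wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, auto_wrap_policy])
#     model = FSDP(model, auto_wrap_policy=wrap_policy, ...)

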
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--vision_encoder_path", default="ViT-B-16", type=str)
    parser.add_argument("--vision_encoder_pretrained", default="laion2b_s34b_b88k", type=str)
    parser.add_argument("--lm_path", default="facebook/opt-1.3b", type=str)
    parser.add_argument(
        "--tokenizer_path",
        default="facebook/opt-1.3b",
        type=str,
        help="path to tokenizer",
    )
    parser.add_argument(
        "--run_name",
        type=str,
        default="openflamingo3B",
        help="used to name saving directory and wandb run",
    )
    parser.add_argument("--use_media_placement_augmentation", action="store_true")
    parser.add_argument("--offline", action="store_true")
    parser.add_argument("--num_steps", type=int, default=300000)
    parser.add_argument(
        "--logging_steps", type=int, default=10, help="log loss every n steps"
    )
    # Per-dataset batch sizes that together make up the gradient-optimization batch size.
| parser.add_argument("--batch_size_mmc4", type=int, default=128) | |
| parser.add_argument("--batch_size_laion", type=int, default=128) | |
| parser.add_argument("--batch_size_pile", type=int, default=128) | |
| parser.add_argument("--gradient_accumulation_steps", type=int, default=1) | |
| parser.add_argument( | |
| "--resume_from_checkpoint", | |
| type=str, | |
| help="path to checkpoint to resume from, this should contain model, optimizer, and lr_scheduler states", | |
| default=None, | |
| ) | |
| parser.add_argument( | |
| "--delete_previous_checkpoint", | |
| action="store_true", | |
| help="delete previous checkpoint when saving new checkpoint", | |
| ) | |
| parser.add_argument( | |
| "--laion_shards", | |
| type=str, | |
| help="path to laion shards, this should be a glob pattern such as /path/to/shards/shard-{0000..0999}.tar", | |
| ) | |
| parser.add_argument( | |
| "--mmc4_shards", | |
| type=str, | |
| help="path to c4 shards, this should be a glob pattern such as /path/to/shards/shard-{0000..0999}.tar", | |
| ) | |
| parser.add_argument( | |
| "--pile_shards", | |
| type=str, | |
| default=None, | |
| help="path to pile shards, this should be a glob pattern such as /path/to/shards/shard-{0000..0999}.tar", | |
| ) | |
| parser.add_argument("--seed", type=int, default=42) | |
| parser.add_argument("--learning_rate", default=1e-4, type=float) | |
| parser.add_argument( | |
| "--lr_scheduler", | |
| default="constant", | |
| type=str, | |
| help="constant, linear, or cosine", | |
| ) | |
| parser.add_argument("--loss_multiplier_mmc4", type=float, default=1.0) | |
| parser.add_argument("--loss_multiplier_laion", type=float, default=1.0) | |
| parser.add_argument("--loss_multiplier_pile", type=float, default=1.0) | |
| parser.add_argument("--loss_multiplier_det", type=float, default=1.0) | |
| parser.add_argument("--loss_multiplier_rel", type=float, default=1.0) | |
| parser.add_argument("--loss_multiplier_attn", type=float, default=1.0) | |
| parser.add_argument("--warmup_steps", default=5000, type=int) | |
    # Weight decay is only applied to the YOLOX head when using FSDP; see
    # https://medium.com/@huanghaian123/optimize-and-accelerate-yolox-with-rtmdet-hyps-in-mmyolo-80fc06d61159
| parser.add_argument("--weight_decay", default=0.05, type=float) | |
| parser.add_argument( | |
| "--precision", | |
| choices=["amp_fp16", "amp_bf16", "amp_bfloat16", "bf16", "fp16", "fp32"], | |
| default="fp32", | |
| help="Floating point precision.", | |
| ) | |
| # data args | |
| parser.add_argument("--workers", type=int, default=1) | |
| parser.add_argument("--dataset_resampled", action="store_true") | |
| # distributed training args | |
| parser.add_argument( | |
| "--dist-url", | |
| default="env://", | |
| type=str, | |
| help="url used to set up distributed training", | |
| ) | |
| parser.add_argument( | |
| "--dist-backend", default="nccl", type=str, help="distributed backend" | |
| ) | |
| parser.add_argument( | |
| "--horovod", | |
| default=False, | |
| action="store_true", | |
| help="Use horovod for distributed training.", | |
| ) | |
| parser.add_argument( | |
| "--no-set-device-rank", | |
| default=False, | |
| action="store_true", | |
| help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).", | |
| ) | |
| # wandb args | |
| parser.add_argument("--report_to_wandb", default=False, action="store_true") | |
| parser.add_argument( | |
| "--wandb_project", | |
| type=str, | |
| ) | |
| parser.add_argument( | |
| "--wandb_entity", | |
| type=str, | |
| ) | |
| parser.add_argument( | |
| "--save_checkpoints_to_wandb", | |
| default=False, | |
| action="store_true", | |
| help="save checkpoints to wandb", | |
| ) | |
| parser.add_argument( | |
| "--checkpoint_activations", | |
| default=False, | |
| action="store_true", | |
| ) | |
| parser.add_argument( | |
| "--freeze_vision_encoder", | |
| default=False, | |
| action="store_true", | |
| ) | |
| parser.add_argument( | |
| "--mmc4_textsim_threshold", | |
| default=30, | |
| type=float, | |
| help="threshold for filtering images in mmc4 based on image-text similarity", | |
| ) | |
| parser.add_argument( | |
| "--location_token_num", | |
| default=1000, | |
| type=int, | |
| ) | |
| parser.add_argument( | |
| "--vis_embed_size", | |
| type=int, | |
| required=False, | |
| ) | |
| parser.add_argument( | |
| "--save_interval", | |
| default=1000, | |
| type=int, | |
| required=False, | |
| ) | |
| parser.add_argument( | |
| "--skip_delete_pattern", | |
| default=1500, | |
| type=int, | |
| required=False, | |
| ) | |
| parser.add_argument( | |
| "--ddp", | |
| default=False, | |
| action="store_true", | |
| ) | |
| parser.add_argument( | |
| "--pile_freq", | |
| default=1, | |
| type=int, | |
| required=False, | |
| ) | |
| parser.add_argument( | |
| "--restart", | |
| default=False, | |
| action="store_true", | |
| ) | |
| parser.add_argument( | |
| "--lora", | |
| default=False, | |
| action="store_true", | |
| ) | |
| parser.add_argument( | |
| "--lora_r", | |
| default=16, | |
| type=int, | |
| required=False, | |
| ) | |
| parser.add_argument( | |
| "--single", | |
| default=False, | |
| action="store_true", | |
| ) | |
| # Finetune | |
| parser.add_argument( | |
| "--instruct", | |
| default=False, | |
| action="store_true", | |
| ) | |
| parser.add_argument( | |
| "--fix-ffn", | |
| default=False, | |
| action="store_true", | |
| ) | |
| parser.add_argument( | |
| "--prob_ground", | |
| default=1.0, | |
| type=float, | |
| required=False, | |
| ) | |
| parser.add_argument( | |
| "--optimizer", | |
| default="adamw", | |
| type=str, | |
| required=False, | |
| ) | |
| parser.add_argument( | |
| "--add_visual_token", | |
| default=False, | |
| action="store_true", | |
| ) | |
| parser.add_argument( | |
| "--use_format_v2", | |
| default=False, | |
| action="store_true", | |
| ) | |
| parser.add_argument( | |
| "--use_sam", | |
| default=None, | |
| type=str, | |
| required=False, | |
| ) | |
| parser.add_argument( | |
| "--max-length", | |
| default=608, | |
| type=int, | |
| required=False, | |
| ) | |
| parser.add_argument( | |
| "--image-size", | |
| default=256, | |
| type=int, | |
| required=False, | |
| ) | |
| parser.add_argument( | |
| "--reset_llm", | |
| default=False, | |
| action="store_true", | |
| ) | |
| parser.add_argument( | |
| "--add_box", | |
| default=False, | |
| action="store_true", | |
| ) | |
| parser.add_argument( | |
| "--add_pe", | |
| default=False, | |
| action="store_true", | |
| ) | |
| parser.add_argument( | |
| "--only_grounded_sample", | |
| default=False, | |
| action="store_true", | |
| ) | |
| parser.add_argument( | |
| "--expand", | |
| default=False, | |
| action="store_true", | |
| ) | |
| parser.add_argument( | |
| "--delete_contained", | |
| default=False, | |
| action="store_true", | |
| ) | |
| parser.add_argument( | |
| "--relation", | |
| default=False, | |
| action="store_true", | |
| ) | |
| parser.add_argument( | |
| "--attn_reg", | |
| default="l1", | |
| type=str, | |
| required=False, | |
| ) | |
| parser.add_argument( | |
| "--enhance_data", | |
| default=False, | |
| action="store_true", | |
| ) | |
| parser.add_argument( | |
| "--no_visual", | |
| default=False, | |
| action="store_true", | |
| ) | |
| parser.add_argument( | |
| "--no_previsual", | |
| default=False, | |
| action="store_true", | |
| ) | |
| parser.add_argument( | |
| "--roi_align", | |
| default=False, | |
| action="store_true", | |
| ) | |
| parser.add_argument( | |
| "--roi_output_size", | |
| default=4, | |
| type=int, | |
| required=False, | |
| ) | |
| parser.add_argument( | |
| "--apply_mask", | |
| default=False, | |
| action="store_true", | |
| ) | |
| parser.add_argument( | |
| "--longer_previsual", | |
| default=False, | |
| action="store_true", | |
| ) | |
| args = parser.parse_args() | |
| assert not args.use_media_placement_augmentation, "Do not enable use_media_placement_augmentation" | |
| if args.no_previsual: | |
| assert args.no_visual, "no_previsual MUST come with no_visual" | |
    assert not args.enhance_data, "do not enable enhance_data"
    if args.offline:
        os.environ["WANDB_MODE"] = "offline"
        os.environ["TRANSFORMERS_OFFLINE"] = "1"
    args.local_rank, args.rank, args.world_size = world_info_from_env()
    print(f"local_rank: {args.local_rank} rank: {args.rank} world_size: {args.world_size}")
    device_id = init_distributed_device(args)
    random_seed(args.seed)
    model, image_processor, tokenizer, args.vis_embed_size = create_model_and_transforms(
        args.vision_encoder_path,
        args.vision_encoder_pretrained,
        args.lm_path,
        args.tokenizer_path if args.tokenizer_path else args.lm_path,
        use_local_files=args.offline,
        use_media_placement_augmentation=args.use_media_placement_augmentation,
        checkpoint_activations=args.checkpoint_activations,
        freeze_vision_encoder=args.freeze_vision_encoder,
        location_token_num=args.location_token_num,
        lora=args.lora,
        lora_r=args.lora_r,
        fix_ffn=args.fix_ffn,
        add_visual_token=args.add_visual_token,
        add_box=args.add_box,
        add_pe=args.add_pe,
        add_relation=args.relation,
        use_format_v2=args.use_format_v2,
        use_sam=args.use_sam,
        enhance_data=args.enhance_data,
        roi_align=args.roi_align,
        roi_output_size=args.roi_output_size,
        apply_mask=args.apply_mask,
    )
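    # Note: create_model_and_transforms also returns the visual embedding size,
    # which overwrites whatever was passed via --vis_embed_size.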
    if args.reset_llm:
        llm_state_dict = model.lang_encoder.state_dict()
    if args.rank == 0:
        print(args)
        print(image_processor)
    random_seed(args.seed, args.rank)
    if args.rank == 0 and args.report_to_wandb:
        wandb.init(
            project=args.wandb_project,
            entity=args.wandb_entity,
            name=args.run_name,
            config=vars(args),
        )
    device_id = args.rank % torch.cuda.device_count()
    if args.ddp:
        print("use ddp mode")
        model = model.to(device_id)
        model = DDP(model)
    else:
        fpSixteen = MixedPrecision(
            param_dtype=torch.float16,
            # Gradient communication precision.
            reduce_dtype=torch.float16,
            # Buffer precision.
            # buffer_dtype=torch.float16,
        )
        # from transformers.models.opt.modeling_opt import OPTDecoderLayer
        from open_clip.transformer import ResidualAttentionBlock
        from open_flamingo.src.flamingo_lm import FlamingoLayer
        from transformers.models.opt.modeling_opt import OPTDecoderLayer, OPTAttention
        from segment_anything.modeling.image_encoder import Block
        transformer_layer_cls = [
            FlamingoLayer,
            ResidualAttentionBlock,
            Block,
        ]
        if args.fix_ffn:
            transformer_layer_cls.append(OPTAttention)
        auto_wrap_policy = functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls=transformer_layer_cls,
        )
        if args.lora:
            from torch.distributed.fsdp.wrap import _or_policy
            lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
            auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, auto_wrap_policy])
            ignored_modules = [model.vision_encoder]
            # ignored_modules = None
        else:
            ignored_modules = [model.detection_head]
            # ignored_modules = None
        if args.add_pe:
            ignored_modules += [model.pos_enc]
        # if args.use_format_v2:
        #     ignored_modules += [model.lang_encoder.visual_guided_lm_head]
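        # SHARD_GRAD_OP (ZeRO-2 style) shards gradients and optimizer state across
        # ranks while keeping full parameters during forward/backward; modules in
        # ignored_modules are left unwrapped and unsharded by FSDP.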
        model = FSDP(
            model,
            auto_wrap_policy=auto_wrap_policy,
            mixed_precision=fpSixteen,
            device_id=torch.cuda.current_device(),
            ignored_modules=ignored_modules,
            sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
        )
        model = model.to(device_id)
    pile_dataset = None
    if args.instruct:
        laion_dataset = get_data(args, image_processor, tokenizer, "instruct")
    else:
        laion_dataset = get_data(args, image_processor, tokenizer, "ground_image_text")
    if args.pile_shards is not None:
        pile_dataset = get_data(args, image_processor, tokenizer, "pile")
    optim_groups = get_grouped_params(model, args)
    # optimizer = torch.optim.AdamW(optim_groups, lr=args.learning_rate)
    if args.ddp:
        optimizer = torch.optim.AdamW(optim_groups, lr=args.learning_rate)
        # optimizer = ZeroRedundancyOptimizer(
        #     optim_groups,
        #     optimizer_class=torch.optim.AdamW,
        #     lr=args.learning_rate,
        #     parameters_as_bucket_view=True,
        # )
    else:
        if args.optimizer == "adamw":
            print("use adamw")
            optimizer = torch.optim.AdamW(optim_groups, lr=args.learning_rate)
        elif args.optimizer == "sgd":
            print("use sgd...")
            optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate)
        else:
            raise NotImplementedError
    total_training_steps = args.num_steps
    if args.rank == 0:
        logging.info(f"Total training steps: {total_training_steps}")
    if args.lr_scheduler == "linear":
        lr_scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=total_training_steps,
        )
    elif args.lr_scheduler == "cosine":
        lr_scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=total_training_steps,
        )
    else:
        lr_scheduler = get_constant_schedule_with_warmup(
            optimizer, num_warmup_steps=args.warmup_steps
        )
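    # The standard GradScaler is used for DDP; FSDP requires ShardedGradScaler so
    # that unscaling and inf/NaN checks are coordinated across parameter shards.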
    if args.ddp:
        scaler = GradScaler()
    else:
        scaler = ShardedGradScaler()
    total_laion_token = 0
    total_pile_token = 0
    total_laion_sample = 0
    total_step = 0

    # check if a checkpoint exists for this run
    if os.path.exists(f"{args.run_name}"):
        checkpoint_list = glob.glob(f"{args.run_name}/checkpoint_*.pt")
        if len(checkpoint_list) == 0:
            if args.rank == 0:
                logging.info(f"Found no checkpoints for run {args.run_name}.")
        else:
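            # Resume from the most recent checkpoint, i.e. the one with the
            # largest step index embedded in its filename.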
            args.resume_from_checkpoint = sorted(
                checkpoint_list, key=lambda x: int(x.split("_")[-1].split(".")[0])
            )[-1]
            if args.rank == 0:
                logging.info(f"Found checkpoint {args.resume_from_checkpoint} for run {args.run_name}.")
            args.restart = False
            if args.rank == 0:
                logging.info("not restarting because an existing checkpoint was found")
    if args.resume_from_checkpoint is not None:
        if args.rank == 0:
            logging.info(f"Loading checkpoint from {args.resume_from_checkpoint}")
        checkpoint = torch.load(args.resume_from_checkpoint, map_location="cpu")
        torch.distributed.barrier()
        if args.ddp:
            model.module.load_state_dict(checkpoint["model_state_dict"], strict=False)
            # sharded_osd = checkpoint['optimizer_state_dict']
        else:
            with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT):
                if args.reset_llm:
                    for key in checkpoint["model_state_dict"]:
                        if key.startswith("lang_encoder"):
                            if args.rank == 0:
                                logging.info(f"reset {key}")
                            llm_key = key.replace("lang_encoder.", "")
                            checkpoint["model_state_dict"][key] = llm_state_dict[llm_key]
                model_state_dict = model.state_dict()
                for key in checkpoint["model_state_dict"].keys():
                    if model_state_dict[key].shape != checkpoint["model_state_dict"][key].shape:
                        if args.rank == 0:
                            logging.info(f'{key}: shape mismatched! {model_state_dict[key].shape} vs {checkpoint["model_state_dict"][key].shape}')
                        checkpoint["model_state_dict"][key] = model_state_dict[key].clone()
                del model_state_dict
                model.load_state_dict(checkpoint["model_state_dict"], False)
            # sharded_osd = FSDP.shard_full_optim_state_dict(checkpoint['optimizer_state_dict'], model, optim_input=optim_groups)
        if not args.restart:
            # optimizer.load_state_dict(sharded_osd)
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
            # scaler.load_state_dict(checkpoint["scaler_state_dict"])
            total_laion_token = checkpoint.get("total_laion_token", 0)
            total_pile_token = checkpoint.get("total_pile_token", 0)
            total_laion_sample = checkpoint.get("total_laion_sample", 0)
            total_step = checkpoint.get("total_step", 0)
            if args.rank == 0:
                logging.info("load training statistics...")
        else:
            if args.rank == 0:
                logging.info("restarting training / finetuning; only loading model weights...")
        del checkpoint
        if args.reset_llm:
            del llm_state_dict
        torch.cuda.empty_cache()
        torch.distributed.barrier()

    model.train()
    if args.rank == 0:
        if not os.path.exists(args.run_name):
            os.makedirs(args.run_name)
        writer = SummaryWriter(log_dir=os.path.join(args.run_name, "tblog"))
    else:
        writer = None
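    # Assumption: set_epoch re-seeds shard shuffling, so passing the resumed
    # global step keeps a resumed run from replaying the same sample order.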
    laion_dataset.set_epoch(total_step)
    laion_loader = laion_dataset.dataloader
    if pile_dataset is not None:
        pile_dataset.set_epoch(total_step)
        pile_loader = pile_dataset.dataloader
    else:
        pile_loader = FakeDataloader()
    train_one_epoch(
        args=args,
        model=model,
        tokenizer=tokenizer,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        laion_loader=laion_loader,
        pile_loader=pile_loader,
        device_id=device_id,
        writer=writer,
        scaler=scaler,
        optim_groups=optim_groups,
        total_laion_token=total_laion_token,
        total_pile_token=total_pile_token,
        total_laion_sample=total_laion_sample,
        total_step=total_step,
    )


if __name__ == "__main__":
    main()