Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	added files
Browse files- sample_level_encoding.py +274 -0
    	
        sample_level_encoding.py
    ADDED
    
    | @@ -0,0 +1,274 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import argparse, os, sys, glob
         | 
| 2 | 
            +
            import torch
         | 
| 3 | 
            +
            import pickle
         | 
| 4 | 
            +
            import numpy as np
         | 
| 5 | 
            +
            from omegaconf import OmegaConf
         | 
| 6 | 
            +
            from PIL import Image
         | 
| 7 | 
            +
            from tqdm import tqdm, trange
         | 
| 8 | 
            +
            from einops import rearrange
         | 
| 9 | 
            +
            from torchvision.utils import make_grid
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            from ldm.util import instantiate_from_config
         | 
| 12 | 
            +
            from ldm.models.diffusion.ddim import DDIMSampler
         | 
| 13 | 
            +
            from ldm.models.diffusion.plms import PLMSSampler
         | 
| 14 | 
            +
             | 
| 15 | 
            +
             | 
| 16 | 
            +
             | 
| 17 | 
            +
            def load_model_from_config(config, ckpt, verbose=False):
         | 
| 18 | 
            +
                print(f"Loading model from {ckpt}")
         | 
| 19 | 
            +
                # pl_sd = torch.load(ckpt, map_location="cpu")
         | 
| 20 | 
            +
                pl_sd = torch.load(ckpt)#, map_location="cpu")
         | 
| 21 | 
            +
                sd = pl_sd["state_dict"]
         | 
| 22 | 
            +
                model = instantiate_from_config(config.model)
         | 
| 23 | 
            +
                m, u = model.load_state_dict(sd, strict=False)
         | 
| 24 | 
            +
                if len(m) > 0 and verbose:
         | 
| 25 | 
            +
                    print("missing keys:")
         | 
| 26 | 
            +
                    print(m)
         | 
| 27 | 
            +
                if len(u) > 0 and verbose:
         | 
| 28 | 
            +
                    print("unexpected keys:")
         | 
| 29 | 
            +
                    print(u)
         | 
| 30 | 
            +
             | 
| 31 | 
            +
                model.cuda()
         | 
| 32 | 
            +
                model.eval()
         | 
| 33 | 
            +
                return model
         | 
| 34 | 
            +
             | 
| 35 | 
            +
             | 
| 36 | 
            +
            if __name__ == "__main__":
         | 
| 37 | 
            +
                parser = argparse.ArgumentParser()
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                parser.add_argument(
         | 
| 40 | 
            +
                    "--prompt",
         | 
| 41 | 
            +
                    type=str,
         | 
| 42 | 
            +
                    nargs="?",
         | 
| 43 | 
            +
                    default="a painting of a virus monster playing guitar",
         | 
| 44 | 
            +
                    help="the prompt to render"
         | 
| 45 | 
            +
                )
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                parser.add_argument(
         | 
| 48 | 
            +
                    "--outdir",
         | 
| 49 | 
            +
                    type=str,
         | 
| 50 | 
            +
                    nargs="?",
         | 
| 51 | 
            +
                    help="dir to write results to",
         | 
| 52 | 
            +
                    default="outputs/txt2img-samples"
         | 
| 53 | 
            +
                )
         | 
| 54 | 
            +
                parser.add_argument(
         | 
| 55 | 
            +
                    "--ddim_steps",
         | 
| 56 | 
            +
                    type=int,
         | 
| 57 | 
            +
                    default=200,
         | 
| 58 | 
            +
                    help="number of ddim sampling steps",
         | 
| 59 | 
            +
                )
         | 
| 60 | 
            +
             | 
| 61 | 
            +
                parser.add_argument(
         | 
| 62 | 
            +
                    "--plms",
         | 
| 63 | 
            +
                    action='store_true',
         | 
| 64 | 
            +
                    help="use plms sampling",
         | 
| 65 | 
            +
                )
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                parser.add_argument(
         | 
| 68 | 
            +
                    "--ddim_eta",
         | 
| 69 | 
            +
                    type=float,
         | 
| 70 | 
            +
                    default=1.0,
         | 
| 71 | 
            +
                    help="ddim eta (eta=0.0 corresponds to deterministic sampling",
         | 
| 72 | 
            +
                )
         | 
| 73 | 
            +
                parser.add_argument(
         | 
| 74 | 
            +
                    "--n_iter",
         | 
| 75 | 
            +
                    type=int,
         | 
| 76 | 
            +
                    default=1,
         | 
| 77 | 
            +
                    help="sample this often",
         | 
| 78 | 
            +
                )
         | 
| 79 | 
            +
             | 
| 80 | 
            +
                parser.add_argument(
         | 
| 81 | 
            +
                    "--H",
         | 
| 82 | 
            +
                    type=int,
         | 
| 83 | 
            +
                    default=256,
         | 
| 84 | 
            +
                    help="image height, in pixel space",
         | 
| 85 | 
            +
                )
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                parser.add_argument(
         | 
| 88 | 
            +
                    "--W",
         | 
| 89 | 
            +
                    type=int,
         | 
| 90 | 
            +
                    default=256,
         | 
| 91 | 
            +
                    help="image width, in pixel space",
         | 
| 92 | 
            +
                )
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                parser.add_argument(
         | 
| 95 | 
            +
                    "--n_samples",
         | 
| 96 | 
            +
                    type=int,
         | 
| 97 | 
            +
                    default=4,
         | 
| 98 | 
            +
                    help="how many samples to produce for the given prompt",
         | 
| 99 | 
            +
                )
         | 
| 100 | 
            +
             | 
| 101 | 
            +
                parser.add_argument(
         | 
| 102 | 
            +
                    "--output_dir_name",
         | 
| 103 | 
            +
                    type=str,
         | 
| 104 | 
            +
                    default='default_file',
         | 
| 105 | 
            +
                    help="name of folder",
         | 
| 106 | 
            +
                )
         | 
| 107 | 
            +
             | 
| 108 | 
            +
                parser.add_argument(
         | 
| 109 | 
            +
                    "--postfix",
         | 
| 110 | 
            +
                    type=str,
         | 
| 111 | 
            +
                    default='',
         | 
| 112 | 
            +
                    help="name of folder",
         | 
| 113 | 
            +
                )
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                parser.add_argument(
         | 
| 116 | 
            +
                    "--scale",
         | 
| 117 | 
            +
                    type=float,
         | 
| 118 | 
            +
                    # default=5.0,
         | 
| 119 | 
            +
                    default=1.0,
         | 
| 120 | 
            +
                    help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
         | 
| 121 | 
            +
                )
         | 
| 122 | 
            +
                opt = parser.parse_args()
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                # --scale 1.0 --n_samples 3 --ddim_steps 20
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                # # #### CLIP f4
         | 
| 127 | 
            +
                # config_path = '/globalscratch/mridul/ldm/clip/2023-11-09T15-34-23_CLIP_f4_maxlen77_classname/configs/2023-11-09T15-34-23-project.yaml'
         | 
| 128 | 
            +
                # ckpt_path = '/globalscratch/mridul/ldm/clip/2023-11-09T15-34-23_CLIP_f4_maxlen77_classname/checkpoints/epoch=000158.ckpt'
         | 
| 129 | 
            +
                
         | 
| 130 | 
            +
                # # #### CLIP f8
         | 
| 131 | 
            +
                # config_path = '/globalscratch/mridul/ldm/clip/2023-11-09T15-30-05_CLIP_f8_maxlen77_classname/configs/2023-11-09T15-30-05-project.yaml'
         | 
| 132 | 
            +
                # ckpt_path = '/globalscratch/mridul/ldm/clip/2023-11-09T15-30-05_CLIP_f8_maxlen77_classname/checkpoints/epoch=000119.ckpt'
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                #### Label Encoding
         | 
| 135 | 
            +
                # config_path = '/globalscratch/mridul/ldm/test/test_bert/2023-11-13T23-08-55_TEST_f4_ancestral_label_encoding/configs/2023-11-13T23-08-55-project.yaml'
         | 
| 136 | 
            +
                # ckpt_path = '/globalscratch/mridul/ldm/test/test_bert/2023-11-13T23-08-55_TEST_f4_ancestral_label_encoding/checkpoints/epoch=000119.ckpt'
         | 
| 137 | 
            +
             | 
| 138 | 
            +
                 #### Label Encoding Leave one out
         | 
| 139 | 
            +
                # config_path = '/globalscratch/mridul/ldm/level_encoding/leave_out/2023-12-01T01-49-15_HLE_f4_label_encoding_leave_out/configs/2023-12-01T01-49-15-project.yaml'
         | 
| 140 | 
            +
                # ckpt_path = '/globalscratch/mridul/ldm/level_encoding/leave_out/2023-12-01T01-49-15_HLE_f4_label_encoding_leave_out/checkpoints/epoch=000131.ckpt'
         | 
| 141 | 
            +
             | 
| 142 | 
            +
                # ckpt_path = '/globalscratch/mridul/ldm/level_encoding/2023-12-03T09-33-45_HLE_f4_level_encoding_371/checkpoints/epoch=000119.ckpt'
         | 
| 143 | 
            +
                # config_path = '/globalscratch/mridul/ldm/level_encoding/2023-12-03T09-33-45_HLE_f4_level_encoding_371/configs/2023-12-03T09-33-45-project.yaml'
         | 
| 144 | 
            +
             | 
| 145 | 
            +
             | 
| 146 | 
            +
                # ### scale 1.25 - 137 epoch
         | 
| 147 | 
            +
                # ckpt_path = '/globalscratch/mridul/ldm/level_encoding/2024-01-29T21-52-36_HLE_f4_scale1.25/checkpoints/epoch=000119.ckpt'
         | 
| 148 | 
            +
                # config_path = '/globalscratch/mridul/ldm/level_encoding/2024-01-29T21-52-36_HLE_f4_scale1.25/configs/2024-01-29T21-52-36-project.yaml'
         | 
| 149 | 
            +
             | 
| 150 | 
            +
                ### scale 1.5 - 137 epoch
         | 
| 151 | 
            +
                # ckpt_path = '/globalscratch/mridul/ldm/level_encoding/2024-01-29T20-33-03_HLE_f4_scale1.5/checkpoints/epoch=000119.ckpt'
         | 
| 152 | 
            +
                # config_path = '/globalscratch/mridul/ldm/level_encoding/2024-01-29T20-33-03_HLE_f4_scale1.5/configs/2024-01-29T20-33-03-project.yaml'
         | 
| 153 | 
            +
             | 
| 154 | 
            +
             | 
| 155 | 
            +
                # ### scale 2 - 137 epoch
         | 
| 156 | 
            +
                # ckpt_path = '/globalscratch/mridul/ldm/level_encoding/2024-01-29T21-52-36_HLE_f4_scale2/checkpoints/epoch=000095.ckpt'
         | 
| 157 | 
            +
                # config_path = '/globalscratch/mridul/ldm/level_encoding/2024-01-29T21-52-36_HLE_f4_scale2/configs/2024-01-29T21-52-36-project.yaml'
         | 
| 158 | 
            +
             | 
| 159 | 
            +
                # ### scale 5 - 137 epoch
         | 
| 160 | 
            +
                # ckpt_path = '/globalscratch/mridul/ldm/level_encoding/2024-01-29T20-26-32_HLE_f4_scale5/checkpoints/epoch=000095.ckpt'
         | 
| 161 | 
            +
                # config_path = '/globalscratch/mridul/ldm/level_encoding/2024-01-29T20-26-32_HLE_f4_scale5/configs/2024-01-29T20-26-32-project.yaml'
         | 
| 162 | 
            +
             | 
| 163 | 
            +
                # ### scale 10 - 137 epoch
         | 
| 164 | 
            +
                # ckpt_path = '/globalscratch/mridul/ldm/level_encoding/2024-01-29T20-26-02_HLE_f4_scale10/checkpoints/epoch=000101.ckpt'
         | 
| 165 | 
            +
                # config_path = '/globalscratch/mridul/ldm/level_encoding/2024-01-29T20-26-02_HLE_f4_scale10/configs/2024-01-29T20-26-02-project.yaml'
         | 
| 166 | 
            +
             | 
| 167 | 
            +
                ###### hle 371, 
         | 
| 168 | 
            +
                ckpt_path = '/globalscratch/mridul/ldm/final_runs_eccv/fishes/2024-03-01T23-15-36_HLE_days3/checkpoints/epoch=000119.ckpt'
         | 
| 169 | 
            +
                config_path = '/globalscratch/mridul/ldm/final_runs_eccv/fishes/2024-03-01T23-15-36_HLE_days3/configs/2024-03-01T23-15-36-project.yaml'
         | 
| 170 | 
            +
             | 
| 171 | 
            +
             | 
| 172 | 
            +
                label_to_class_mapping = {0: 'Alosa-chrysochloris', 1: 'Carassius-auratus', 2: 'Cyprinus-carpio', 3: 'Esox-americanus', 
         | 
| 173 | 
            +
                4: 'Gambusia-affinis', 5: 'Lepisosteus-osseus', 6: 'Lepisosteus-platostomus', 7: 'Lepomis-auritus', 8: 'Lepomis-cyanellus', 
         | 
| 174 | 
            +
                9: 'Lepomis-gibbosus', 10: 'Lepomis-gulosus', 11: 'Lepomis-humilis', 12: 'Lepomis-macrochirus', 13: 'Lepomis-megalotis', 
         | 
| 175 | 
            +
                14: 'Lepomis-microlophus', 15: 'Morone-chrysops', 16: 'Morone-mississippiensis', 17: 'Notropis-atherinoides', 
         | 
| 176 | 
            +
                18: 'Notropis-blennius', 19: 'Notropis-boops', 20: 'Notropis-buccatus', 21: 'Notropis-buchanani', 22: 'Notropis-dorsalis', 
         | 
| 177 | 
            +
                23: 'Notropis-hudsonius', 24: 'Notropis-leuciodus', 25: 'Notropis-nubilus', 26: 'Notropis-percobromus', 
         | 
| 178 | 
            +
                27: 'Notropis-stramineus', 28: 'Notropis-telescopus', 29: 'Notropis-texanus', 30: 'Notropis-volucellus', 
         | 
| 179 | 
            +
                31: 'Notropis-wickliffi', 32: 'Noturus-exilis', 33: 'Noturus-flavus', 34: 'Noturus-gyrinus', 35: 'Noturus-miurus', 
         | 
| 180 | 
            +
                36: 'Noturus-nocturnus', 37: 'Phenacobius-mirabilis'}
         | 
| 181 | 
            +
             | 
| 182 | 
            +
                def get_label_from_class(class_name):
         | 
| 183 | 
            +
                    for key, value in label_to_class_mapping.items():
         | 
| 184 | 
            +
                        if value == class_name:
         | 
| 185 | 
            +
                            return key
         | 
| 186 | 
            +
             | 
| 187 | 
            +
                config = OmegaConf.load(config_path)  # TODO: Optionally download from same location as ckpt and chnage this logic
         | 
| 188 | 
            +
                model = load_model_from_config(config, ckpt_path)  # TODO: check path
         | 
| 189 | 
            +
             | 
| 190 | 
            +
                device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
         | 
| 191 | 
            +
                model = model.to(device)
         | 
| 192 | 
            +
             | 
| 193 | 
            +
                if opt.plms:
         | 
| 194 | 
            +
                    sampler = PLMSSampler(model)
         | 
| 195 | 
            +
                else:
         | 
| 196 | 
            +
                    sampler = DDIMSampler(model)
         | 
| 197 | 
            +
             | 
| 198 | 
            +
                os.makedirs(opt.outdir, exist_ok=True)
         | 
| 199 | 
            +
                outpath = opt.outdir
         | 
| 200 | 
            +
             | 
| 201 | 
            +
                prompt = opt.prompt
         | 
| 202 | 
            +
                all_images = []
         | 
| 203 | 
            +
                labels = []
         | 
| 204 | 
            +
             | 
| 205 | 
            +
                class_to_node = '/fastscratch/mridul/fishes/class_to_ancestral_label.pkl'
         | 
| 206 | 
            +
                with open(class_to_node, 'rb') as pickle_file:
         | 
| 207 | 
            +
                    class_to_node_dict = pickle.load(pickle_file)
         | 
| 208 | 
            +
             | 
| 209 | 
            +
                sample_path = os.path.join(outpath, opt.output_dir_name)
         | 
| 210 | 
            +
                os.makedirs(sample_path, exist_ok=True)
         | 
| 211 | 
            +
                base_count = len(os.listdir(sample_path))
         | 
| 212 | 
            +
             | 
| 213 | 
            +
                for class_name, node_representation in tqdm(class_to_node_dict.items()):
         | 
| 214 | 
            +
                    prompt = node_representation
         | 
| 215 | 
            +
                    all_samples=list()
         | 
| 216 | 
            +
                    with torch.no_grad():
         | 
| 217 | 
            +
                        with model.ema_scope():
         | 
| 218 | 
            +
                            uc = None
         | 
| 219 | 
            +
                            # if opt.scale != 1.0:
         | 
| 220 | 
            +
                            #     uc = model.get_learned_conditioning(opt.n_samples * [""])
         | 
| 221 | 
            +
                            for n in trange(opt.n_iter, desc="Sampling"):
         | 
| 222 | 
            +
             | 
| 223 | 
            +
                                all_prompts = opt.n_samples * (prompt)
         | 
| 224 | 
            +
                                all_prompts = [tuple(all_prompts)]
         | 
| 225 | 
            +
                                print(class_name, prompt)
         | 
| 226 | 
            +
                                c = model.get_learned_conditioning({'class_to_node': all_prompts})
         | 
| 227 | 
            +
                                shape = [3, 64, 64]
         | 
| 228 | 
            +
                                samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
         | 
| 229 | 
            +
                                                                conditioning=c,
         | 
| 230 | 
            +
                                                                batch_size=opt.n_samples,
         | 
| 231 | 
            +
                                                                shape=shape,
         | 
| 232 | 
            +
                                                                verbose=False,
         | 
| 233 | 
            +
                                                                unconditional_guidance_scale=opt.scale,
         | 
| 234 | 
            +
                                                                unconditional_conditioning=uc,
         | 
| 235 | 
            +
                                                                eta=opt.ddim_eta)
         | 
| 236 | 
            +
             | 
| 237 | 
            +
                                x_samples_ddim = model.decode_first_stage(samples_ddim)
         | 
| 238 | 
            +
                                x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, min=0.0, max=1.0)
         | 
| 239 | 
            +
             | 
| 240 | 
            +
                                all_samples.append(x_samples_ddim)
         | 
| 241 | 
            +
             | 
| 242 | 
            +
                    ###### to make grid
         | 
| 243 | 
            +
                    # additionally, save as grid
         | 
| 244 | 
            +
                    grid = torch.stack(all_samples, 0)
         | 
| 245 | 
            +
                    grid = rearrange(grid, 'n b c h w -> (n b) c h w')
         | 
| 246 | 
            +
                    grid = make_grid(grid, nrow=opt.n_samples)
         | 
| 247 | 
            +
             | 
| 248 | 
            +
                    # to image
         | 
| 249 | 
            +
                    grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
         | 
| 250 | 
            +
                    Image.fromarray(grid.astype(np.uint8)).save(os.path.join(sample_path, f'{class_name.replace(" ", "-")}.png'))
         | 
| 251 | 
            +
             | 
| 252 | 
            +
                #     # individual images
         | 
| 253 | 
            +
                #     grid = torch.stack(all_samples, 0)
         | 
| 254 | 
            +
                #     grid = rearrange(grid, 'n b c h w -> (n b) c h w')
         | 
| 255 | 
            +
             | 
| 256 | 
            +
                #     for i in range(opt.n_samples):
         | 
| 257 | 
            +
                #         sample = grid[i]
         | 
| 258 | 
            +
                #         img = 255. * rearrange(sample, 'c h w -> h w c').cpu().numpy()
         | 
| 259 | 
            +
                #         img_arr = img.astype(np.uint8)
         | 
| 260 | 
            +
                #         class_name = class_name.replace(" ", "-")
         | 
| 261 | 
            +
                #         all_images.append(img_arr)
         | 
| 262 | 
            +
                #         labels.append(get_label_from_class(class_name))
         | 
| 263 | 
            +
                #         Image.fromarray(img_arr).save(f'{sample_path}/{class_name}_{i}.png')
         | 
| 264 | 
            +
             | 
| 265 | 
            +
                # all_images = np.array(all_images)
         | 
| 266 | 
            +
                # labels = np.array(labels)
         | 
| 267 | 
            +
             | 
| 268 | 
            +
                # np.savez(sample_path + '.npz', all_images, labels)
         | 
| 269 | 
            +
             | 
| 270 | 
            +
             | 
| 271 | 
            +
                print(f"Your samples are ready and waiting four you here: \n{sample_path} \nEnjoy.")
         | 
| 272 | 
            +
             | 
| 273 | 
            +
             | 
| 274 | 
            +
            # python sample_text.py --outdir /home/mridul/sample_images_text --scale 1.0 --n_samples 3 --ddim_steps 200 --ddim_eta 1.0
         |