Spaces:
				
			
			
	
			
			
		Running
		
			on 
			
			Zero
	
	
	
			
			
	
	
	
	
		
		
		Running
		
			on 
			
			Zero
	Upload folder using huggingface_hub
Browse files- .ipynb_checkpoints/hf_demo_test-checkpoint.ipynb +336 -0
- README.md +3 -10
- __pycache__/inference.cpython-39.pyc +0 -0
- custom_datasets/__init__.py +141 -0
- custom_datasets/__pycache__/__init__.cpython-39.pyc +0 -0
- custom_datasets/__pycache__/coco.cpython-39.pyc +0 -0
- custom_datasets/__pycache__/imagepair.cpython-39.pyc +0 -0
- custom_datasets/__pycache__/mypath.cpython-39.pyc +0 -0
- custom_datasets/coco.py +307 -0
- custom_datasets/custom_caption.py +113 -0
- custom_datasets/filt/coco/filt.py +186 -0
- custom_datasets/filt/sam_filt.py +299 -0
- custom_datasets/imagepair.py +240 -0
- custom_datasets/lhq.py +127 -0
- custom_datasets/mypath.py +29 -0
- custom_datasets/sam.py +160 -0
- data/Art_adapters/albert-gleizes_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
- data/Art_adapters/andre-derain_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
- data/Art_adapters/andy_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
- data/Art_adapters/camille-corot_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
- data/Art_adapters/gerhard-richter_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
- data/Art_adapters/henri-matisse_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
- data/Art_adapters/jackson-pollock_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
- data/Art_adapters/joan-miro_subset2/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
- data/Art_adapters/kandinsky_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
- data/Art_adapters/katsushika-hokusai_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
- data/Art_adapters/klimt_subset3/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
- data/Art_adapters/m.c.-escher_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
- data/Art_adapters/monet_subset2/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
- data/Art_adapters/picasso_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
- data/Art_adapters/roy-lichtenstein_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
- data/Art_adapters/van_gogh_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
- data/Art_adapters/walter-battiss_subset2/adapter_alpha1.0_rank1_all_up_1000steps.pt +3 -0
- data/unsafe.png +0 -0
- hf_demo.py +147 -0
- hf_demo_test.ipynb +336 -0
- inference.py +657 -0
- utils/__init__.py +1 -0
- utils/__pycache__/__init__.cpython-39.pyc +0 -0
- utils/__pycache__/lora.cpython-39.pyc +0 -0
- utils/__pycache__/metrics.cpython-39.pyc +0 -0
- utils/__pycache__/train_util.cpython-39.pyc +0 -0
- utils/art_filter.py +210 -0
- utils/config_util.py +105 -0
- utils/debug_util.py +16 -0
- utils/lora.py +282 -0
- utils/metrics.py +577 -0
- utils/model_util.py +291 -0
- utils/prompt_util.py +174 -0
- utils/train_util.py +526 -0
    	
        .ipynb_checkpoints/hf_demo_test-checkpoint.ipynb
    ADDED
    
    | @@ -0,0 +1,336 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            {
         | 
| 2 | 
            +
             "cells": [
         | 
| 3 | 
            +
              {
         | 
| 4 | 
            +
               "cell_type": "code",
         | 
| 5 | 
            +
               "execution_count": 1,
         | 
| 6 | 
            +
               "id": "initial_id",
         | 
| 7 | 
            +
               "metadata": {
         | 
| 8 | 
            +
                "ExecuteTime": {
         | 
| 9 | 
            +
                 "end_time": "2024-12-09T09:44:30.641366Z",
         | 
| 10 | 
            +
                 "start_time": "2024-12-09T09:44:11.789050Z"
         | 
| 11 | 
            +
                }
         | 
| 12 | 
            +
               },
         | 
| 13 | 
            +
               "outputs": [],
         | 
| 14 | 
            +
               "source": [
         | 
| 15 | 
            +
                "import os\n",
         | 
| 16 | 
            +
                "\n",
         | 
| 17 | 
            +
                "import gradio as gr\n",
         | 
| 18 | 
            +
                "from diffusers import DiffusionPipeline\n",
         | 
| 19 | 
            +
                "import matplotlib.pyplot as plt\n",
         | 
| 20 | 
            +
                "import torch\n",
         | 
| 21 | 
            +
                "from PIL import Image\n"
         | 
| 22 | 
            +
               ]
         | 
| 23 | 
            +
              },
         | 
| 24 | 
            +
              {
         | 
| 25 | 
            +
               "cell_type": "code",
         | 
| 26 | 
            +
               "execution_count": 2,
         | 
| 27 | 
            +
               "id": "ddf33e0d3abacc2c",
         | 
| 28 | 
            +
               "metadata": {},
         | 
| 29 | 
            +
               "outputs": [],
         | 
| 30 | 
            +
               "source": [
         | 
| 31 | 
            +
                "import sys\n",
         | 
| 32 | 
            +
                "#append current path\n",
         | 
| 33 | 
            +
                "sys.path.extend(\"/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/release/hf_demo\")"
         | 
| 34 | 
            +
               ]
         | 
| 35 | 
            +
              },
         | 
| 36 | 
            +
              {
         | 
| 37 | 
            +
               "cell_type": "code",
         | 
| 38 | 
            +
               "execution_count": 3,
         | 
| 39 | 
            +
               "id": "643e49fd601daf8f",
         | 
| 40 | 
            +
               "metadata": {
         | 
| 41 | 
            +
                "ExecuteTime": {
         | 
| 42 | 
            +
                 "end_time": "2024-12-09T09:44:35.790962Z",
         | 
| 43 | 
            +
                 "start_time": "2024-12-09T09:44:35.779496Z"
         | 
| 44 | 
            +
                }
         | 
| 45 | 
            +
               },
         | 
| 46 | 
            +
               "outputs": [],
         | 
| 47 | 
            +
               "source": [
         | 
| 48 | 
            +
                "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\""
         | 
| 49 | 
            +
               ]
         | 
| 50 | 
            +
              },
         | 
| 51 | 
            +
              {
         | 
| 52 | 
            +
               "cell_type": "code",
         | 
| 53 | 
            +
               "execution_count": 4,
         | 
| 54 | 
            +
               "id": "e03aae2a4e5676dd",
         | 
| 55 | 
            +
               "metadata": {
         | 
| 56 | 
            +
                "ExecuteTime": {
         | 
| 57 | 
            +
                 "end_time": "2024-12-09T09:44:44.157412Z",
         | 
| 58 | 
            +
                 "start_time": "2024-12-09T09:44:37.138452Z"
         | 
| 59 | 
            +
                }
         | 
| 60 | 
            +
               },
         | 
| 61 | 
            +
               "outputs": [
         | 
| 62 | 
            +
                {
         | 
| 63 | 
            +
                 "name": "stderr",
         | 
| 64 | 
            +
                 "output_type": "stream",
         | 
| 65 | 
            +
                 "text": [
         | 
| 66 | 
            +
                  "/data/vision/torralba/selfmanaged/torralba/scratch/jomat/sam_dataset/miniforge3/envs/diffusion/lib/python3.9/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
         | 
| 67 | 
            +
                  "  warnings.warn(\n"
         | 
| 68 | 
            +
                 ]
         | 
| 69 | 
            +
                },
         | 
| 70 | 
            +
                {
         | 
| 71 | 
            +
                 "data": {
         | 
| 72 | 
            +
                  "application/vnd.jupyter.widget-view+json": {
         | 
| 73 | 
            +
                   "model_id": "9df8347307674ba8afb0250e23109aa1",
         | 
| 74 | 
            +
                   "version_major": 2,
         | 
| 75 | 
            +
                   "version_minor": 0
         | 
| 76 | 
            +
                  },
         | 
| 77 | 
            +
                  "text/plain": [
         | 
| 78 | 
            +
                   "Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]"
         | 
| 79 | 
            +
                  ]
         | 
| 80 | 
            +
                 },
         | 
| 81 | 
            +
                 "metadata": {},
         | 
| 82 | 
            +
                 "output_type": "display_data"
         | 
| 83 | 
            +
                }
         | 
| 84 | 
            +
               ],
         | 
| 85 | 
            +
               "source": [
         | 
| 86 | 
            +
                "pipe = DiffusionPipeline.from_pretrained(\"rhfeiyang/art-free-diffusion-v1\",).to(\"cuda\")\n",
         | 
| 87 | 
            +
                "device = \"cuda\""
         | 
| 88 | 
            +
               ]
         | 
| 89 | 
            +
              },
         | 
| 90 | 
            +
              {
         | 
| 91 | 
            +
               "cell_type": "code",
         | 
| 92 | 
            +
               "execution_count": 5,
         | 
| 93 | 
            +
               "id": "83916bc68ff5d914",
         | 
| 94 | 
            +
               "metadata": {
         | 
| 95 | 
            +
                "ExecuteTime": {
         | 
| 96 | 
            +
                 "end_time": "2024-12-09T09:44:52.694399Z",
         | 
| 97 | 
            +
                 "start_time": "2024-12-09T09:44:44.210695Z"
         | 
| 98 | 
            +
                }
         | 
| 99 | 
            +
               },
         | 
| 100 | 
            +
               "outputs": [],
         | 
| 101 | 
            +
               "source": [
         | 
| 102 | 
            +
                "from inference import get_lora_network, inference, get_validation_dataloader\n",
         | 
| 103 | 
            +
                "lora_map = {\n",
         | 
| 104 | 
            +
                "    \"None\": \"None\",\n",
         | 
| 105 | 
            +
                "    \"Andre Derain\": \"andre-derain_subset1\",\n",
         | 
| 106 | 
            +
                "    \"Vincent van Gogh\": \"van_gogh_subset1\",\n",
         | 
| 107 | 
            +
                "    \"Andy Warhol\": \"andy_subset1\",\n",
         | 
| 108 | 
            +
                "    \"Walter Battiss\": \"walter-battiss_subset2\",\n",
         | 
| 109 | 
            +
                "    \"Camille Corot\": \"camille-corot_subset1\",\n",
         | 
| 110 | 
            +
                "    \"Claude Monet\": \"monet_subset2\",\n",
         | 
| 111 | 
            +
                "    \"Pablo Picasso\": \"picasso_subset1\",\n",
         | 
| 112 | 
            +
                "    \"Jackson Pollock\": \"jackson-pollock_subset1\",\n",
         | 
| 113 | 
            +
                "    \"Gerhard Richter\": \"gerhard-richter_subset1\",\n",
         | 
| 114 | 
            +
                "    \"M.C. Escher\": \"m.c.-escher_subset1\",\n",
         | 
| 115 | 
            +
                "    \"Albert Gleizes\": \"albert-gleizes_subset1\",\n",
         | 
| 116 | 
            +
                "    \"Hokusai\": \"katsushika-hokusai_subset1\",\n",
         | 
| 117 | 
            +
                "    \"Wassily Kandinsky\": \"kandinsky_subset1\",\n",
         | 
| 118 | 
            +
                "    \"Gustav Klimt\": \"klimt_subset3\",\n",
         | 
| 119 | 
            +
                "    \"Roy Lichtenstein\": \"roy-lichtenstein_subset1\",\n",
         | 
| 120 | 
            +
                "    \"Henri Matisse\": \"henri-matisse_subset1\",\n",
         | 
| 121 | 
            +
                "    \"Joan Miro\": \"joan-miro_subset2\",\n",
         | 
| 122 | 
            +
                "}\n",
         | 
| 123 | 
            +
                "\n",
         | 
| 124 | 
            +
                "def demo_inference_gen(adapter_choice:str, prompt:str, samples:int=1,seed:int=0, steps=50, guidance_scale=7.5):\n",
         | 
| 125 | 
            +
                "    adapter_path = lora_map[adapter_choice]\n",
         | 
| 126 | 
            +
                "    if adapter_path not in [None, \"None\"]:\n",
         | 
| 127 | 
            +
                "        adapter_path = f\"data/Art_adapters/{adapter_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt\"\n",
         | 
| 128 | 
            +
                "\n",
         | 
| 129 | 
            +
                "    prompts = [prompt]*samples\n",
         | 
| 130 | 
            +
                "    infer_loader = get_validation_dataloader(prompts)\n",
         | 
| 131 | 
            +
                "    network = get_lora_network(pipe.unet, adapter_path)[\"network\"]\n",
         | 
| 132 | 
            +
                "    pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,\n",
         | 
| 133 | 
            +
                "                            height=512, width=512, scales=[1.0],\n",
         | 
| 134 | 
            +
                "                            save_dir=None, seed=seed,steps=steps, guidance_scale=guidance_scale,\n",
         | 
| 135 | 
            +
                "                            start_noise=-1, show=False, style_prompt=\"sks art\", no_load=True,\n",
         | 
| 136 | 
            +
                "                            from_scratch=True)[0][1.0]\n",
         | 
| 137 | 
            +
                "    return pred_images\n",
         | 
| 138 | 
            +
                "\n",
         | 
| 139 | 
            +
                "def demo_inference_stylization(adapter_path:str, prompts:list, image:list, start_noise=800,seed:int=0):\n",
         | 
| 140 | 
            +
                "    infer_loader = get_validation_dataloader(prompts, image)\n",
         | 
| 141 | 
            +
                "    network = get_lora_network(pipe.unet, adapter_path,\"all_up\")[\"network\"]\n",
         | 
| 142 | 
            +
                "    pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,\n",
         | 
| 143 | 
            +
                "                            height=512, width=512, scales=[0.,1.],\n",
         | 
| 144 | 
            +
                "                            save_dir=None, seed=seed,steps=20, guidance_scale=7.5,\n",
         | 
| 145 | 
            +
                "                            start_noise=start_noise, show=True, style_prompt=\"sks art\", no_load=True,\n",
         | 
| 146 | 
            +
                "                            from_scratch=False)\n",
         | 
| 147 | 
            +
                "    return pred_images\n",
         | 
| 148 | 
            +
                "\n",
         | 
| 149 | 
            +
                "# def infer(prompt, samples, steps, scale, seed):\n",
         | 
| 150 | 
            +
                "#     generator = torch.Generator(device=device).manual_seed(seed)\n",
         | 
| 151 | 
            +
                "#     images_list = pipe(  # type: ignore\n",
         | 
| 152 | 
            +
                "#         [prompt] * samples,\n",
         | 
| 153 | 
            +
                "#         num_inference_steps=steps,\n",
         | 
| 154 | 
            +
                "#         guidance_scale=scale,\n",
         | 
| 155 | 
            +
                "#         generator=generator,\n",
         | 
| 156 | 
            +
                "#     )\n",
         | 
| 157 | 
            +
                "#     images = []\n",
         | 
| 158 | 
            +
                "#     safe_image = Image.open(r\"data/unsafe.png\")\n",
         | 
| 159 | 
            +
                "#     print(images_list)\n",
         | 
| 160 | 
            +
                "#     for i, image in enumerate(images_list[\"images\"]):  # type: ignore\n",
         | 
| 161 | 
            +
                "#         if images_list[\"nsfw_content_detected\"][i]:  # type: ignore\n",
         | 
| 162 | 
            +
                "#             images.append(safe_image)\n",
         | 
| 163 | 
            +
                "#         else:\n",
         | 
| 164 | 
            +
                "#             images.append(image)\n",
         | 
| 165 | 
            +
                "#     return images\n"
         | 
| 166 | 
            +
               ]
         | 
| 167 | 
            +
              },
         | 
| 168 | 
            +
              {
         | 
| 169 | 
            +
               "cell_type": "code",
         | 
| 170 | 
            +
               "execution_count": 6,
         | 
| 171 | 
            +
               "id": "aa33e9d104023847",
         | 
| 172 | 
            +
               "metadata": {
         | 
| 173 | 
            +
                "ExecuteTime": {
         | 
| 174 | 
            +
                 "end_time": "2024-12-09T12:09:39.339583Z",
         | 
| 175 | 
            +
                 "start_time": "2024-12-09T12:09:38.953936Z"
         | 
| 176 | 
            +
                }
         | 
| 177 | 
            +
               },
         | 
| 178 | 
            +
               "outputs": [
         | 
| 179 | 
            +
                {
         | 
| 180 | 
            +
                 "name": "stdout",
         | 
| 181 | 
            +
                 "output_type": "stream",
         | 
| 182 | 
            +
                 "text": [
         | 
| 183 | 
            +
                  "<gradio.components.slider.Slider object at 0x7fa12d3a5280>\n",
         | 
| 184 | 
            +
                  "Running on local URL:  http://127.0.0.1:7876\n",
         | 
| 185 | 
            +
                  "Running on public URL: https://be7cce8fec75395c82.gradio.live\n",
         | 
| 186 | 
            +
                  "\n",
         | 
| 187 | 
            +
                  "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n"
         | 
| 188 | 
            +
                 ]
         | 
| 189 | 
            +
                },
         | 
| 190 | 
            +
                {
         | 
| 191 | 
            +
                 "data": {
         | 
| 192 | 
            +
                  "text/html": [
         | 
| 193 | 
            +
                   "<div><iframe src=\"https://be7cce8fec75395c82.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
         | 
| 194 | 
            +
                  ],
         | 
| 195 | 
            +
                  "text/plain": [
         | 
| 196 | 
            +
                   "<IPython.core.display.HTML object>"
         | 
| 197 | 
            +
                  ]
         | 
| 198 | 
            +
                 },
         | 
| 199 | 
            +
                 "metadata": {},
         | 
| 200 | 
            +
                 "output_type": "display_data"
         | 
| 201 | 
            +
                },
         | 
| 202 | 
            +
                {
         | 
| 203 | 
            +
                 "data": {
         | 
| 204 | 
            +
                  "text/plain": []
         | 
| 205 | 
            +
                 },
         | 
| 206 | 
            +
                 "execution_count": 6,
         | 
| 207 | 
            +
                 "metadata": {},
         | 
| 208 | 
            +
                 "output_type": "execute_result"
         | 
| 209 | 
            +
                },
         | 
| 210 | 
            +
                {
         | 
| 211 | 
            +
                 "name": "stdout",
         | 
| 212 | 
            +
                 "output_type": "stream",
         | 
| 213 | 
            +
                 "text": [
         | 
| 214 | 
            +
                  "Train method: None\n",
         | 
| 215 | 
            +
                  "Rank: 1, Alpha: 1\n",
         | 
| 216 | 
            +
                  "create LoRA for U-Net: 0 modules.\n",
         | 
| 217 | 
            +
                  "save dir: None\n",
         | 
| 218 | 
            +
                  "['Park with cherry blossom trees, picnicker’s and a clear blue pond in the style of sks art'], seed=949192390\n"
         | 
| 219 | 
            +
                 ]
         | 
| 220 | 
            +
                },
         | 
| 221 | 
            +
                {
         | 
| 222 | 
            +
                 "name": "stderr",
         | 
| 223 | 
            +
                 "output_type": "stream",
         | 
| 224 | 
            +
                 "text": [
         | 
| 225 | 
            +
                  "/data/vision/torralba/selfmanaged/torralba/scratch/jomat/sam_dataset/miniforge3/envs/diffusion/lib/python3.9/site-packages/torch/nn/modules/conv.py:456: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at /opt/conda/conda-bld/pytorch_1712608883701/work/aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)\n",
         | 
| 226 | 
            +
                  "  return F.conv2d(input, weight, bias, self.stride,\n",
         | 
| 227 | 
            +
                  "\n",
         | 
| 228 | 
            +
                  "00%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 21/21 [00:03<00:00,  6.90it/s]"
         | 
| 229 | 
            +
                 ]
         | 
| 230 | 
            +
                },
         | 
| 231 | 
            +
                {
         | 
| 232 | 
            +
                 "name": "stdout",
         | 
| 233 | 
            +
                 "output_type": "stream",
         | 
| 234 | 
            +
                 "text": [
         | 
| 235 | 
            +
                  "Time taken for one batch, Art Adapter scale=1.0: 3.2747044563293457\n"
         | 
| 236 | 
            +
                 ]
         | 
| 237 | 
            +
                }
         | 
| 238 | 
            +
               ],
         | 
| 239 | 
            +
               "source": [
         | 
| 240 | 
            +
                "block = gr.Blocks()\n",
         | 
| 241 | 
            +
                "# Direct infer\n",
         | 
| 242 | 
            +
                "with block:\n",
         | 
| 243 | 
            +
                "    with gr.Group():\n",
         | 
| 244 | 
            +
                "        with gr.Row():\n",
         | 
| 245 | 
            +
                "            text = gr.Textbox(\n",
         | 
| 246 | 
            +
                "                label=\"Enter your prompt\",\n",
         | 
| 247 | 
            +
                "                max_lines=2,\n",
         | 
| 248 | 
            +
                "                placeholder=\"Enter your prompt\",\n",
         | 
| 249 | 
            +
                "                container=False,\n",
         | 
| 250 | 
            +
                "                value=\"Park with cherry blossom trees, picnicker’s and a clear blue pond.\",\n",
         | 
| 251 | 
            +
                "            )\n",
         | 
| 252 | 
            +
                "            \n",
         | 
| 253 | 
            +
                "\n",
         | 
| 254 | 
            +
                "            \n",
         | 
| 255 | 
            +
                "            btn = gr.Button(\"Run\", scale=0)\n",
         | 
| 256 | 
            +
                "        gallery = gr.Gallery(\n",
         | 
| 257 | 
            +
                "            label=\"Generated images\",\n",
         | 
| 258 | 
            +
                "            show_label=False,\n",
         | 
| 259 | 
            +
                "            elem_id=\"gallery\",\n",
         | 
| 260 | 
            +
                "            columns=[2],\n",
         | 
| 261 | 
            +
                "        )\n",
         | 
| 262 | 
            +
                "\n",
         | 
| 263 | 
            +
                "        advanced_button = gr.Button(\"Advanced options\", elem_id=\"advanced-btn\")\n",
         | 
| 264 | 
            +
                "\n",
         | 
| 265 | 
            +
                "        with gr.Row(elem_id=\"advanced-options\"):\n",
         | 
| 266 | 
            +
                "            adapter_choice = gr.Dropdown(\n",
         | 
| 267 | 
            +
                "                label=\"Choose adapter\",\n",
         | 
| 268 | 
            +
                "                choices=[\"None\", \"Andre Derain\",\"Vincent van Gogh\",\"Andy Warhol\", \"Walter Battiss\",\n",
         | 
| 269 | 
            +
                "                         \"Camille Corot\", \"Claude Monet\", \"Pablo Picasso\",\n",
         | 
| 270 | 
            +
                "                         \"Jackson Pollock\", \"Gerhard Richter\", \"M.C. Escher\",\n",
         | 
| 271 | 
            +
                "                         \"Albert Gleizes\", \"Hokusai\", \"Wassily Kandinsky\", \"Gustav Klimt\", \"Roy Lichtenstein\",\n",
         | 
| 272 | 
            +
                "                         \"Henri Matisse\", \"Joan Miro\"\n",
         | 
| 273 | 
            +
                "                         ],\n",
         | 
| 274 | 
            +
                "                value=\"None\"\n",
         | 
| 275 | 
            +
                "            )\n",
         | 
| 276 | 
            +
                "            # print(adapter_choice[0])\n",
         | 
| 277 | 
            +
                "            # lora_path = lora_map[adapter_choice.value]\n",
         | 
| 278 | 
            +
                "            # if lora_path is not None:\n",
         | 
| 279 | 
            +
                "            #     lora_path = f\"data/Art_adapters/{lora_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt\"\n",
         | 
| 280 | 
            +
                "\n",
         | 
| 281 | 
            +
                "            samples = gr.Slider(label=\"Images\", minimum=1, maximum=4, value=1, step=1)\n",
         | 
| 282 | 
            +
                "            steps = gr.Slider(label=\"Steps\", minimum=1, maximum=50, value=20, step=1)\n",
         | 
| 283 | 
            +
                "            scale = gr.Slider(\n",
         | 
| 284 | 
            +
                "                label=\"Guidance Scale\", minimum=0, maximum=50, value=7.5, step=0.1\n",
         | 
| 285 | 
            +
                "            )\n",
         | 
| 286 | 
            +
                "            print(scale)\n",
         | 
| 287 | 
            +
                "            seed = gr.Slider(\n",
         | 
| 288 | 
            +
                "                label=\"Seed\",\n",
         | 
| 289 | 
            +
                "                minimum=0,\n",
         | 
| 290 | 
            +
                "                maximum=2147483647,\n",
         | 
| 291 | 
            +
                "                step=1,\n",
         | 
| 292 | 
            +
                "                randomize=True,\n",
         | 
| 293 | 
            +
                "            )\n",
         | 
| 294 | 
            +
                "\n",
         | 
| 295 | 
            +
                "        gr.on([text.submit, btn.click], demo_inference_gen, inputs=[adapter_choice, text, samples, seed, steps, scale], outputs=gallery)\n",
         | 
| 296 | 
            +
                "        advanced_button.click(\n",
         | 
| 297 | 
            +
                "            None,\n",
         | 
| 298 | 
            +
                "            [],\n",
         | 
| 299 | 
            +
                "            text,\n",
         | 
| 300 | 
            +
                "        )\n",
         | 
| 301 | 
            +
                "\n",
         | 
| 302 | 
            +
                "\n",
         | 
| 303 | 
            +
                "block.launch(share=True)"
         | 
| 304 | 
            +
               ]
         | 
| 305 | 
            +
              },
         | 
| 306 | 
            +
              {
         | 
| 307 | 
            +
               "cell_type": "code",
         | 
| 308 | 
            +
               "execution_count": null,
         | 
| 309 | 
            +
               "id": "3239c12167a5f2cd",
         | 
| 310 | 
            +
               "metadata": {},
         | 
| 311 | 
            +
               "outputs": [],
         | 
| 312 | 
            +
               "source": []
         | 
| 313 | 
            +
              }
         | 
| 314 | 
            +
             ],
         | 
| 315 | 
            +
             "metadata": {
         | 
| 316 | 
            +
              "kernelspec": {
         | 
| 317 | 
            +
               "display_name": "Python 3 (ipykernel)",
         | 
| 318 | 
            +
               "language": "python",
         | 
| 319 | 
            +
               "name": "python3"
         | 
| 320 | 
            +
              },
         | 
| 321 | 
            +
              "language_info": {
         | 
| 322 | 
            +
               "codemirror_mode": {
         | 
| 323 | 
            +
                "name": "ipython",
         | 
| 324 | 
            +
                "version": 3
         | 
| 325 | 
            +
               },
         | 
| 326 | 
            +
               "file_extension": ".py",
         | 
| 327 | 
            +
               "mimetype": "text/x-python",
         | 
| 328 | 
            +
               "name": "python",
         | 
| 329 | 
            +
               "nbconvert_exporter": "python",
         | 
| 330 | 
            +
               "pygments_lexer": "ipython3",
         | 
| 331 | 
            +
               "version": "3.9.18"
         | 
| 332 | 
            +
              }
         | 
| 333 | 
            +
             },
         | 
| 334 | 
            +
             "nbformat": 4,
         | 
| 335 | 
            +
             "nbformat_minor": 5
         | 
| 336 | 
            +
            }
         | 
    	
        README.md
    CHANGED
    
    | @@ -1,13 +1,6 @@ | |
| 1 | 
             
            ---
         | 
| 2 | 
            -
            title: Art | 
| 3 | 
            -
             | 
| 4 | 
            -
            colorFrom: purple
         | 
| 5 | 
            -
            colorTo: red
         | 
| 6 | 
             
            sdk: gradio
         | 
| 7 | 
            -
            sdk_version:  | 
| 8 | 
            -
            app_file: app.py
         | 
| 9 | 
            -
            pinned: false
         | 
| 10 | 
            -
            short_description: Demo for Art Free Diffusion
         | 
| 11 | 
             
            ---
         | 
| 12 | 
            -
             | 
| 13 | 
            -
            Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
         | 
|  | |
| 1 | 
             
            ---
         | 
| 2 | 
            +
            title: Art-Free-Diffusion
         | 
| 3 | 
            +
            app_file: hf_demo.py
         | 
|  | |
|  | |
| 4 | 
             
            sdk: gradio
         | 
| 5 | 
            +
            sdk_version: 4.44.1
         | 
|  | |
|  | |
|  | |
| 6 | 
             
            ---
         | 
|  | |
|  | 
    	
        __pycache__/inference.cpython-39.pyc
    ADDED
    
    | Binary file (19.8 kB). View file | 
|  | 
    	
        custom_datasets/__init__.py
    ADDED
    
    | @@ -0,0 +1,141 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from .mypath import MyPath
         | 
| 2 | 
            +
            from copy import deepcopy
         | 
| 3 | 
            +
            from datasets import load_dataset
         | 
| 4 | 
            +
            from torch.utils.data import Dataset
         | 
| 5 | 
            +
            import numpy as np
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            def get_dataset(dataset_name, transformation=None , train_subsample:int =None, val_subsample:int = 10000, get_val=True):
         | 
| 8 | 
            +
                if train_subsample is not None and train_subsample<val_subsample and train_subsample!=-1:
         | 
| 9 | 
            +
                    print(f"Warning: train_subsample is smaller than val_subsample. val_subsample will be set to train_subsample: {train_subsample}")
         | 
| 10 | 
            +
                    val_subsample = train_subsample
         | 
| 11 | 
            +
             | 
| 12 | 
            +
                if dataset_name == "imagenet":
         | 
| 13 | 
            +
                    from .imagenet import Imagenet1k
         | 
| 14 | 
            +
                    train_set = Imagenet1k(data_dir = MyPath.db_root_dir(dataset_name), transform = transformation, split="train", prompt_transform=Label_prompt_transform(real=True))
         | 
| 15 | 
            +
                elif dataset_name == "coco_train":
         | 
| 16 | 
            +
                    # raise NotImplementedError("Use coco_filtered instead")
         | 
| 17 | 
            +
                    from .coco import CocoCaptions
         | 
| 18 | 
            +
                    train_set = CocoCaptions(root=MyPath.db_root_dir("coco_train"), annFile=MyPath.db_root_dir("coco_caption_train"))
         | 
| 19 | 
            +
                elif dataset_name == "coco_val":
         | 
| 20 | 
            +
                    from .coco import CocoCaptions
         | 
| 21 | 
            +
                    train_set = CocoCaptions(root=MyPath.db_root_dir("coco_val"), annFile=MyPath.db_root_dir("coco_caption_val"))
         | 
| 22 | 
            +
                    return {"val": train_set}
         | 
| 23 | 
            +
             | 
| 24 | 
            +
                elif dataset_name == "coco_clip_filtered":
         | 
| 25 | 
            +
                    from .coco import CocoCaptions_clip_filtered
         | 
| 26 | 
            +
                    train_set = CocoCaptions_clip_filtered(root=MyPath.db_root_dir("coco_train"), annFile=MyPath.db_root_dir("coco_caption_train"))
         | 
| 27 | 
            +
                elif dataset_name == "coco_filtered_sub100":
         | 
| 28 | 
            +
                    from .coco import CocoCaptions_clip_filtered
         | 
| 29 | 
            +
                    train_set = CocoCaptions_clip_filtered(root=MyPath.db_root_dir("coco_train"), annFile=MyPath.db_root_dir("coco_caption_train"), id_file=MyPath.db_root_dir("coco_clip_filtered_ids_sub100"),)
         | 
| 30 | 
            +
                elif dataset_name == "cifar10":
         | 
| 31 | 
            +
                    from .cifar import CIFAR10
         | 
| 32 | 
            +
                    train_set = CIFAR10(root=MyPath.db_root_dir("cifar10"), train=True, transform=transformation, prompt_transform=Label_prompt_transform(real=True))
         | 
| 33 | 
            +
                elif dataset_name == "cifar100":
         | 
| 34 | 
            +
                    from .cifar import CIFAR100
         | 
| 35 | 
            +
                    train_set = CIFAR100(root=MyPath.db_root_dir("cifar100"), train=True, transform=transformation, prompt_transform=Label_prompt_transform(real=True))
         | 
| 36 | 
            +
                elif "wikiart" in dataset_name and "/" not in dataset_name:
         | 
| 37 | 
            +
                    from .wikiart.wikiart import Wikiart_caption
         | 
| 38 | 
            +
                    dataset = Wikiart_caption(data_path=MyPath.db_root_dir(dataset_name))
         | 
| 39 | 
            +
                    return {"train": dataset.subsample(train_subsample).get_dataset(), "val": deepcopy(dataset).subsample(val_subsample).get_dataset() if get_val else None}
         | 
| 40 | 
            +
                elif "imagepair" in dataset_name:
         | 
| 41 | 
            +
                    from .imagepair import ImagePair
         | 
| 42 | 
            +
                    train_set = ImagePair(folder1=MyPath.db_root_dir(dataset_name)[0], folder2=MyPath.db_root_dir(dataset_name)[1], transform=transformation).subsample(train_subsample)
         | 
| 43 | 
            +
                # elif dataset_name == "sam_clip_filtered":
         | 
| 44 | 
            +
                #     from .sam import SamDataset
         | 
| 45 | 
            +
                #     train_set = SamDataset(image_folder_path=MyPath.db_root_dir("sam_images"), caption_folder_path=MyPath.db_root_dir("sam_captions"), id_file=MyPath.db_root_dir("sam_ids"), transforms=transformation).subsample(train_subsample)
         | 
| 46 | 
            +
                elif dataset_name == "sam_whole_filtered":
         | 
| 47 | 
            +
                    from .sam import SamDataset
         | 
| 48 | 
            +
                    train_set = SamDataset(image_folder_path=MyPath.db_root_dir("sam_images"), caption_folder_path=MyPath.db_root_dir("sam_captions"), id_file=MyPath.db_root_dir("sam_whole_filtered_ids_train"), id_dict_file=MyPath.db_root_dir("sam_id_dict"), transforms=transformation).subsample(train_subsample)
         | 
| 49 | 
            +
                elif dataset_name == "sam_whole_filtered_val":
         | 
| 50 | 
            +
                    from .sam import SamDataset
         | 
| 51 | 
            +
                    train_set = SamDataset(image_folder_path=MyPath.db_root_dir("sam_images"), caption_folder_path=MyPath.db_root_dir("sam_captions"), id_file=MyPath.db_root_dir("sam_whole_filtered_ids_val"), id_dict_file=MyPath.db_root_dir("sam_id_dict"), transforms=transformation).subsample(train_subsample)
         | 
| 52 | 
            +
                    return {"val": train_set}
         | 
| 53 | 
            +
                elif dataset_name == "lhq_sub100":
         | 
| 54 | 
            +
                    from .lhq import LhqDataset
         | 
| 55 | 
            +
                    train_set = LhqDataset(image_folder_path=MyPath.db_root_dir("lhq_images"), caption_folder_path=MyPath.db_root_dir("lhq_captions"), id_file=MyPath.db_root_dir("lhq_ids_sub100"), transforms=transformation)
         | 
| 56 | 
            +
                elif dataset_name == "lhq_sub500":
         | 
| 57 | 
            +
                    from .lhq import LhqDataset
         | 
| 58 | 
            +
                    train_set = LhqDataset(image_folder_path=MyPath.db_root_dir("lhq_images"), caption_folder_path=MyPath.db_root_dir("lhq_captions"), id_file=MyPath.db_root_dir("lhq_ids_sub500"), transforms=transformation)
         | 
| 59 | 
            +
                elif dataset_name == "lhq_sub9":
         | 
| 60 | 
            +
                    from .lhq import LhqDataset
         | 
| 61 | 
            +
                    train_set = LhqDataset(image_folder_path=MyPath.db_root_dir("lhq_images"), caption_folder_path=MyPath.db_root_dir("lhq_captions"), id_file=MyPath.db_root_dir("lhq_ids_sub9"), transforms=transformation)
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                elif dataset_name == "custom_coco100":
         | 
| 64 | 
            +
                    from .coco import CustomCocoCaptions
         | 
| 65 | 
            +
                    train_set = CustomCocoCaptions(root=MyPath.db_root_dir("coco_val"), annFile=MyPath.db_root_dir("coco_caption_val"),
         | 
| 66 | 
            +
                                       custom_file=MyPath.db_root_dir("custom_coco100_captions"), transforms=transformation)
         | 
| 67 | 
            +
                elif dataset_name == "custom_coco500":
         | 
| 68 | 
            +
                    from .coco import CustomCocoCaptions
         | 
| 69 | 
            +
                    train_set = CustomCocoCaptions(root=MyPath.db_root_dir("coco_val"), annFile=MyPath.db_root_dir("coco_caption_val"),
         | 
| 70 | 
            +
                                       custom_file=MyPath.db_root_dir("custom_coco500_captions"), transforms=transformation)
         | 
| 71 | 
            +
                elif dataset_name == "laion_pop500":
         | 
| 72 | 
            +
                    from .custom_caption import Laion_pop
         | 
| 73 | 
            +
                    train_set = Laion_pop(anno_file=MyPath.db_root_dir("laion_pop500"), image_root=MyPath.db_root_dir("laion_images"), transform=transformation)
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                elif dataset_name == "laion_pop500_first_sentence":
         | 
| 76 | 
            +
                    from .custom_caption import Laion_pop
         | 
| 77 | 
            +
                    train_set = Laion_pop(anno_file=MyPath.db_root_dir("laion_pop500_first_sentence"), image_root=MyPath.db_root_dir("laion_images"), transform=transformation)
         | 
| 78 | 
            +
             | 
| 79 | 
            +
             | 
| 80 | 
            +
                else:
         | 
| 81 | 
            +
                    try:
         | 
| 82 | 
            +
                        train_set = load_dataset('imagefolder', data_dir = dataset_name, split="train")
         | 
| 83 | 
            +
                        val_set = deepcopy(train_set)
         | 
| 84 | 
            +
                        if val_subsample is not None and val_subsample != -1:
         | 
| 85 | 
            +
                            val_set = val_set.shuffle(seed=0).select(range(val_subsample))
         | 
| 86 | 
            +
                        return {"train": train_set, "val": val_set if get_val else None}
         | 
| 87 | 
            +
                    except:
         | 
| 88 | 
            +
                        raise ValueError(f"dataset_name {dataset_name} not found.")
         | 
| 89 | 
            +
                return {"train": train_set, "val": deepcopy(train_set).subsample(val_subsample) if get_val else None}
         | 
| 90 | 
            +
             | 
| 91 | 
            +
             | 
| 92 | 
            +
            class MergeDataset(Dataset):
         | 
| 93 | 
            +
                @staticmethod
         | 
| 94 | 
            +
                def get_merged_dataset(dataset_names:list, transformation=None, train_subsample:int =None, val_subsample:int = 10000):
         | 
| 95 | 
            +
                    train_datasets = []
         | 
| 96 | 
            +
                    val_datasets = []
         | 
| 97 | 
            +
                    for dataset_name in dataset_names:
         | 
| 98 | 
            +
                        datasets = get_dataset(dataset_name, transformation, train_subsample, val_subsample)
         | 
| 99 | 
            +
                        train_datasets.append(datasets["train"])
         | 
| 100 | 
            +
                        val_datasets.append(datasets["val"])
         | 
| 101 | 
            +
                    train_datasets = MergeDataset(train_datasets).subsample(train_subsample)
         | 
| 102 | 
            +
                    val_datasets = MergeDataset(val_datasets).subsample(val_subsample)
         | 
| 103 | 
            +
                    return {"train": train_datasets, "val": val_datasets}
         | 
| 104 | 
            +
             | 
| 105 | 
            +
                def __init__(self, datasets:list):
         | 
| 106 | 
            +
                    self.datasets = datasets
         | 
| 107 | 
            +
                    self.column_names = self.datasets[0].column_names
         | 
| 108 | 
            +
                    # self.ids = []
         | 
| 109 | 
            +
                    # start = 0
         | 
| 110 | 
            +
                    # for dataset in self.datasets:
         | 
| 111 | 
            +
                    #     self.ids += [i+start for i in dataset.ids]
         | 
| 112 | 
            +
                def define_resolution(self, resolution: int):
         | 
| 113 | 
            +
                    for dataset in self.datasets:
         | 
| 114 | 
            +
                        dataset.define_resolution(resolution)
         | 
| 115 | 
            +
             | 
| 116 | 
            +
                def __len__(self):
         | 
| 117 | 
            +
                    return sum([len(dataset) for dataset in self.datasets])
         | 
| 118 | 
            +
                def __getitem__(self, index):
         | 
| 119 | 
            +
                    for i,dataset in enumerate(self.datasets):
         | 
| 120 | 
            +
                        if index < len(dataset):
         | 
| 121 | 
            +
                            ret = dataset[index]
         | 
| 122 | 
            +
                            ret["id"] = index
         | 
| 123 | 
            +
                            ret["dataset"] = i
         | 
| 124 | 
            +
                            return ret
         | 
| 125 | 
            +
                        index -= len(dataset)
         | 
| 126 | 
            +
                    raise IndexError
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                def subsample(self, num:int):
         | 
| 129 | 
            +
                    if num is None:
         | 
| 130 | 
            +
                        return self
         | 
| 131 | 
            +
                    dataset_ratio = np.array([len(dataset) for dataset in self.datasets]) / len(self)
         | 
| 132 | 
            +
                    new_datasets = []
         | 
| 133 | 
            +
                    for i, dataset in enumerate(self.datasets):
         | 
| 134 | 
            +
                        new_datasets.append(dataset.subsample(int(num*dataset_ratio[i])))
         | 
| 135 | 
            +
                    return MergeDataset(new_datasets)
         | 
| 136 | 
            +
             | 
| 137 | 
            +
                def with_transform(self, transform):
         | 
| 138 | 
            +
                    for dataset in self.datasets:
         | 
| 139 | 
            +
                        dataset.with_transform(transform)
         | 
| 140 | 
            +
                    return self
         | 
| 141 | 
            +
             | 
    	
        custom_datasets/__pycache__/__init__.cpython-39.pyc
    ADDED
    
    | Binary file (5.8 kB). View file | 
|  | 
    	
        custom_datasets/__pycache__/coco.cpython-39.pyc
    ADDED
    
    | Binary file (10.4 kB). View file | 
|  | 
    	
        custom_datasets/__pycache__/imagepair.cpython-39.pyc
    ADDED
    
    | Binary file (8.93 kB). View file | 
|  | 
    	
        custom_datasets/__pycache__/mypath.cpython-39.pyc
    ADDED
    
    | Binary file (1.49 kB). View file | 
|  | 
    	
        custom_datasets/coco.py
    ADDED
    
    | @@ -0,0 +1,307 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os.path
         | 
| 2 | 
            +
            from typing import Any, Callable, List, Optional, Tuple
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            from PIL import Image
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            from torchvision.datasets.vision import VisionDataset
         | 
| 7 | 
            +
            import pickle
         | 
| 8 | 
            +
            import csv
         | 
| 9 | 
            +
            import pandas as pd
         | 
| 10 | 
            +
            import torch
         | 
| 11 | 
            +
            import torchvision
         | 
| 12 | 
            +
            import re
         | 
| 13 | 
            +
            # from torchvision.datasets import CocoDetection
         | 
| 14 | 
            +
            # from utils.clip_filter import Clip_filter
         | 
| 15 | 
            +
            from tqdm import tqdm
         | 
| 16 | 
            +
            from .mypath import MyPath
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            class CocoDetection(VisionDataset):
         | 
| 19 | 
            +
                """`MS Coco Detection <https://cocodataset.org/#detection-2016>`_ Dataset.
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                It requires the `COCO API to be installed <https://github.com/pdollar/coco/tree/master/PythonAPI>`_.
         | 
| 22 | 
            +
             | 
| 23 | 
            +
                Args:
         | 
| 24 | 
            +
                    root (string): Root directory where images are downloaded to.
         | 
| 25 | 
            +
                    annFile (string): Path to json annotation file.
         | 
| 26 | 
            +
                    transform (callable, optional): A function/transform that  takes in an PIL image
         | 
| 27 | 
            +
                        and returns a transformed version. E.g, ``transforms.PILToTensor``
         | 
| 28 | 
            +
                    target_transform (callable, optional): A function/transform that takes in the
         | 
| 29 | 
            +
                        target and transforms it.
         | 
| 30 | 
            +
                    transforms (callable, optional): A function/transform that takes input sample and its target as entry
         | 
| 31 | 
            +
                        and returns a transformed version.
         | 
| 32 | 
            +
                """
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                def __init__(
         | 
| 35 | 
            +
                        self,
         | 
| 36 | 
            +
                        root: str ,
         | 
| 37 | 
            +
                        annFile: str,
         | 
| 38 | 
            +
                        transform: Optional[Callable] = None,
         | 
| 39 | 
            +
                        target_transform: Optional[Callable] = None,
         | 
| 40 | 
            +
                        transforms: Optional[Callable] = None,
         | 
| 41 | 
            +
                        get_img=True,
         | 
| 42 | 
            +
                        get_cap=True
         | 
| 43 | 
            +
                ) -> None:
         | 
| 44 | 
            +
                    super().__init__(root, transforms, transform, target_transform)
         | 
| 45 | 
            +
                    from pycocotools.coco import COCO
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                    self.coco = COCO(annFile)
         | 
| 48 | 
            +
                    self.ids = list(sorted(self.coco.imgs.keys()))
         | 
| 49 | 
            +
                    self.column_names = ["image", "text"]
         | 
| 50 | 
            +
                    self.get_img = get_img
         | 
| 51 | 
            +
                    self.get_cap = get_cap
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                def _load_image(self, id: int) -> Image.Image:
         | 
| 54 | 
            +
                    path = self.coco.loadImgs(id)[0]["file_name"]
         | 
| 55 | 
            +
                    with open(os.path.join(self.root, path), 'rb') as f:
         | 
| 56 | 
            +
                        img = Image.open(f).convert("RGB")
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                    return img
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                def _load_target(self, id: int) -> List[Any]:
         | 
| 61 | 
            +
                    return self.coco.loadAnns(self.coco.getAnnIds(id))
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                def __getitem__(self, index: int) -> Tuple[Any, Any]:
         | 
| 64 | 
            +
                    id = self.ids[index]
         | 
| 65 | 
            +
                    ret={"id":id}
         | 
| 66 | 
            +
                    if self.get_img:
         | 
| 67 | 
            +
                        image = self._load_image(id)
         | 
| 68 | 
            +
                        ret["image"] = image
         | 
| 69 | 
            +
                    if self.get_cap:
         | 
| 70 | 
            +
                        target = self._load_target(id)
         | 
| 71 | 
            +
                        ret["caption"] = [target]
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                    if self.transforms is not None:
         | 
| 74 | 
            +
                        ret = self.transforms(ret)
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                    return ret
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                def subsample(self, n: int = 10000):
         | 
| 79 | 
            +
                    if n is None or n == -1:
         | 
| 80 | 
            +
                        return self
         | 
| 81 | 
            +
                    ori_len = len(self)
         | 
| 82 | 
            +
                    assert n <= ori_len
         | 
| 83 | 
            +
                    # equal interval subsample
         | 
| 84 | 
            +
                    ids = self.ids[::ori_len // n][:n]
         | 
| 85 | 
            +
                    self.ids = ids
         | 
| 86 | 
            +
                    print(f"COCO dataset subsampled from {ori_len} to {len(self)}")
         | 
| 87 | 
            +
                    return self
         | 
| 88 | 
            +
             | 
| 89 | 
            +
             | 
| 90 | 
            +
                def with_transform(self, transform):
         | 
| 91 | 
            +
                    self.transforms = transform
         | 
| 92 | 
            +
                    return self
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                def __len__(self) -> int:
         | 
| 95 | 
            +
                    # return 100
         | 
| 96 | 
            +
                    return len(self.ids)
         | 
| 97 | 
            +
             | 
| 98 | 
            +
             | 
| 99 | 
            +
            class CocoCaptions(CocoDetection):
         | 
| 100 | 
            +
                """`MS Coco Captions <https://cocodataset.org/#captions-2015>`_ Dataset.
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                It requires the `COCO API to be installed <https://github.com/pdollar/coco/tree/master/PythonAPI>`_.
         | 
| 103 | 
            +
             | 
| 104 | 
            +
                Args:
         | 
| 105 | 
            +
                    root (string): Root directory where images are downloaded to.
         | 
| 106 | 
            +
                    annFile (string): Path to json annotation file.
         | 
| 107 | 
            +
                    transform (callable, optional): A function/transform that  takes in an PIL image
         | 
| 108 | 
            +
                        and returns a transformed version. E.g, ``transforms.PILToTensor``
         | 
| 109 | 
            +
                    target_transform (callable, optional): A function/transform that takes in the
         | 
| 110 | 
            +
                        target and transforms it.
         | 
| 111 | 
            +
                    transforms (callable, optional): A function/transform that takes input sample and its target as entry
         | 
| 112 | 
            +
                        and returns a transformed version.
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                Example:
         | 
| 115 | 
            +
             | 
| 116 | 
            +
                    .. code:: python
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                        import torchvision.datasets as dset
         | 
| 119 | 
            +
                        import torchvision.transforms as transforms
         | 
| 120 | 
            +
                        cap = dset.CocoCaptions(root = 'dir where images are',
         | 
| 121 | 
            +
                                                annFile = 'json annotation file',
         | 
| 122 | 
            +
                                                transform=transforms.PILToTensor())
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                        print('Number of samples: ', len(cap))
         | 
| 125 | 
            +
                        img, target = cap[3] # load 4th sample
         | 
| 126 | 
            +
             | 
| 127 | 
            +
                        print("Image Size: ", img.size())
         | 
| 128 | 
            +
                        print(target)
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                    Output: ::
         | 
| 131 | 
            +
             | 
| 132 | 
            +
                        Number of samples: 82783
         | 
| 133 | 
            +
                        Image Size: (3L, 427L, 640L)
         | 
| 134 | 
            +
                        [u'A plane emitting smoke stream flying over a mountain.',
         | 
| 135 | 
            +
                        u'A plane darts across a bright blue sky behind a mountain covered in snow',
         | 
| 136 | 
            +
                        u'A plane leaves a contrail above the snowy mountain top.',
         | 
| 137 | 
            +
                        u'A mountain that has a plane flying overheard in the distance.',
         | 
| 138 | 
            +
                        u'A mountain view with a plume of smoke in the background']
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                """
         | 
| 141 | 
            +
             | 
| 142 | 
            +
                def _load_target(self, id: int) -> List[str]:
         | 
| 143 | 
            +
                    return [ann["caption"] for ann in super()._load_target(id)]
         | 
| 144 | 
            +
             | 
| 145 | 
            +
             | 
| 146 | 
            +
            class CocoCaptions_clip_filtered(CocoCaptions):
         | 
| 147 | 
            +
                positive_prompt=["painting", "drawing", "graffiti",]
         | 
| 148 | 
            +
                def __init__(
         | 
| 149 | 
            +
                        self,
         | 
| 150 | 
            +
                        root: str ,
         | 
| 151 | 
            +
                        annFile: str,
         | 
| 152 | 
            +
                        transform: Optional[Callable] = None,
         | 
| 153 | 
            +
                        target_transform: Optional[Callable] = None,
         | 
| 154 | 
            +
                        transforms: Optional[Callable] = None,
         | 
| 155 | 
            +
                        regenerate: bool = False,
         | 
| 156 | 
            +
                        id_file: Optional[str] = "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/data/coco/coco_clip_filtered_ids.pickle"
         | 
| 157 | 
            +
                ) -> None:
         | 
| 158 | 
            +
                    super().__init__(root, annFile, transform, target_transform, transforms)
         | 
| 159 | 
            +
                    os.makedirs(os.path.dirname(id_file), exist_ok=True)
         | 
| 160 | 
            +
                    if os.path.exists(id_file) and not regenerate:
         | 
| 161 | 
            +
                        with open(id_file, "rb") as f:
         | 
| 162 | 
            +
                            self.ids = pickle.load(f)
         | 
| 163 | 
            +
                    else:
         | 
| 164 | 
            +
                        self.ids, naive_filtered_num = self.naive_filter()
         | 
| 165 | 
            +
                        self.ids, clip_filtered_num = self.clip_filter(0.7)
         | 
| 166 | 
            +
             | 
| 167 | 
            +
                        print(f"naive Filtered {naive_filtered_num} images")
         | 
| 168 | 
            +
                        print(f"Clip Filtered {clip_filtered_num} images")
         | 
| 169 | 
            +
             | 
| 170 | 
            +
                        with open(id_file, "wb") as f:
         | 
| 171 | 
            +
                            pickle.dump(self.ids, f)
         | 
| 172 | 
            +
                            print(f"Filtered ids saved to {id_file}")
         | 
| 173 | 
            +
                    print(f"COCO filtered dataset size: {len(self)}")
         | 
| 174 | 
            +
             | 
| 175 | 
            +
                def naive_filter(self, filter_prompt="painting"):
         | 
| 176 | 
            +
                    new_ids = []
         | 
| 177 | 
            +
                    naive_filtered_num = 0
         | 
| 178 | 
            +
                    for id in self.ids:
         | 
| 179 | 
            +
                        target = self._load_target(id)
         | 
| 180 | 
            +
                        filtered = False
         | 
| 181 | 
            +
                        for prompt in target:
         | 
| 182 | 
            +
                            if filter_prompt in prompt.lower():
         | 
| 183 | 
            +
                                filtered = True
         | 
| 184 | 
            +
                                naive_filtered_num += 1
         | 
| 185 | 
            +
                                break
         | 
| 186 | 
            +
                            # if "artwork" in prompt.lower():
         | 
| 187 | 
            +
                            #     pass
         | 
| 188 | 
            +
                        if not filtered:
         | 
| 189 | 
            +
                            new_ids.append(id)
         | 
| 190 | 
            +
                    return new_ids, naive_filtered_num
         | 
| 191 | 
            +
             | 
| 192 | 
            +
                # def clip_filter(self, threshold=0.7):
         | 
| 193 | 
            +
                #
         | 
| 194 | 
            +
                #     def collate_fn(examples):
         | 
| 195 | 
            +
                #         # {"image": image, "text": [target], "id":id}
         | 
| 196 | 
            +
                #         pixel_values = [example["image"] for example in examples]
         | 
| 197 | 
            +
                #         prompts = [example["text"] for example in examples]
         | 
| 198 | 
            +
                #         id = [example["id"] for example in examples]
         | 
| 199 | 
            +
                #         return {"images": pixel_values, "prompts": prompts, "ids": id}
         | 
| 200 | 
            +
                #
         | 
| 201 | 
            +
                #
         | 
| 202 | 
            +
                #     clip_filtered_num = 0
         | 
| 203 | 
            +
                #     clip_filter = Clip_filter(positive_prompt=self.positive_prompt)
         | 
| 204 | 
            +
                #     clip_logs={"positive_prompt":clip_filter.positive_prompt, "negative_prompt":clip_filter.negative_prompt,
         | 
| 205 | 
            +
                #                "ids":torch.Tensor([]),"logits":torch.Tensor([])}
         | 
| 206 | 
            +
                #     clip_log_file = "data/coco/clip_logs.pth"
         | 
| 207 | 
            +
                #     new_ids = []
         | 
| 208 | 
            +
                #     batch_size = 128
         | 
| 209 | 
            +
                #     dataloader = torch.utils.data.DataLoader(self, batch_size=batch_size, num_workers=10, shuffle=False,
         | 
| 210 | 
            +
                #                                              collate_fn=collate_fn)
         | 
| 211 | 
            +
                #     for i, batch in enumerate(tqdm(dataloader)):
         | 
| 212 | 
            +
                #         images = batch["images"]
         | 
| 213 | 
            +
                #         filter_result, logits = clip_filter.filter(images, threshold=threshold)
         | 
| 214 | 
            +
                #         ids = torch.IntTensor(batch["ids"])
         | 
| 215 | 
            +
                #         clip_logs["ids"] = torch.cat([clip_logs["ids"], ids])
         | 
| 216 | 
            +
                #         clip_logs["logits"] = torch.cat([clip_logs["logits"], logits])
         | 
| 217 | 
            +
                #
         | 
| 218 | 
            +
                #         new_ids.extend(ids[~filter_result].tolist())
         | 
| 219 | 
            +
                #         clip_filtered_num += filter_result.sum().item()
         | 
| 220 | 
            +
                #         if i % 50 == 0:
         | 
| 221 | 
            +
                #             torch.save(clip_logs, clip_log_file)
         | 
| 222 | 
            +
                #     torch.save(clip_logs, clip_log_file)
         | 
| 223 | 
            +
                #
         | 
| 224 | 
            +
                #     return new_ids, clip_filtered_num
         | 
| 225 | 
            +
             | 
| 226 | 
            +
             | 
| 227 | 
            +
            class CustomCocoCaptions(CocoCaptions):
         | 
| 228 | 
            +
                def __init__(self, root: str=MyPath.db_root_dir("coco_val"), annFile: str=MyPath.db_root_dir("coco_caption_val"), custom_file:str="/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/jomat-code/filtering/ms_coco_captions_testset100.txt",transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, transforms: Optional[Callable] = None) -> None:
         | 
| 229 | 
            +
             | 
| 230 | 
            +
                    super().__init__(root, annFile, transform, target_transform, transforms)
         | 
| 231 | 
            +
                    self.column_names = ["image", "text"]
         | 
| 232 | 
            +
                    self.custom_file = custom_file
         | 
| 233 | 
            +
                    self.load_custom_data(custom_file)
         | 
| 234 | 
            +
                    self.transforms = transforms
         | 
| 235 | 
            +
             | 
| 236 | 
            +
                def load_custom_data(self, custom_file):
         | 
| 237 | 
            +
                    self.custom_data = []
         | 
| 238 | 
            +
                    with open(custom_file, "r") as f:
         | 
| 239 | 
            +
                        data = f.readlines()
         | 
| 240 | 
            +
                    head = data[0].strip().split(",")
         | 
| 241 | 
            +
                    self.head = head
         | 
| 242 | 
            +
                    for line in data[1:]:
         | 
| 243 | 
            +
                        sub_data = line.strip().split(",")
         | 
| 244 | 
            +
                        if len(sub_data) > len(head):
         | 
| 245 | 
            +
                            sub_data_new = [sub_data[0]]
         | 
| 246 | 
            +
                            sub_data_new+=[",".join(sub_data[1:-1])]
         | 
| 247 | 
            +
                            sub_data_new.append(sub_data[-1])
         | 
| 248 | 
            +
                            sub_data = sub_data_new
         | 
| 249 | 
            +
                        assert len(sub_data) == len(head)
         | 
| 250 | 
            +
                        self.custom_data.append(sub_data)
         | 
| 251 | 
            +
                    # to pd
         | 
| 252 | 
            +
                    self.custom_data = pd.DataFrame(self.custom_data, columns=head)
         | 
| 253 | 
            +
             | 
| 254 | 
            +
                def __len__(self) -> int:
         | 
| 255 | 
            +
                    return len(self.custom_data)
         | 
| 256 | 
            +
             | 
| 257 | 
            +
                def __getitem__(self, index: int) -> Tuple[Any, Any]:
         | 
| 258 | 
            +
                    data = self.custom_data.iloc[index]
         | 
| 259 | 
            +
                    id = int(data["image_id"])
         | 
| 260 | 
            +
                    ret={"id":id}
         | 
| 261 | 
            +
                    if self.get_img:
         | 
| 262 | 
            +
                        image = self._load_image(id)
         | 
| 263 | 
            +
                        ret["image"] = image
         | 
| 264 | 
            +
                    if self.get_cap:
         | 
| 265 | 
            +
                        caption = data["caption"]
         | 
| 266 | 
            +
                        ret["caption"] = [caption]
         | 
| 267 | 
            +
                    ret["seed"] = int(data["random_seed"])
         | 
| 268 | 
            +
             | 
| 269 | 
            +
                    if self.transforms is not None:
         | 
| 270 | 
            +
                        ret = self.transforms(ret)
         | 
| 271 | 
            +
             | 
| 272 | 
            +
                    return ret
         | 
| 273 | 
            +
             | 
| 274 | 
            +
             | 
| 275 | 
            +
             | 
| 276 | 
            +
            def get_validation_set():
         | 
| 277 | 
            +
                coco_instance = CocoDetection(root="/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/.datasets/coco_2017/train2017/", annFile="/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/.datasets/coco_2017/annotations/instances_train2017.json")
         | 
| 278 | 
            +
                discard_cat_id = coco_instance.coco.getCatIds(supNms=["person", "animal"])
         | 
| 279 | 
            +
                discard_img_id = []
         | 
| 280 | 
            +
                for cat_id in discard_cat_id:
         | 
| 281 | 
            +
                    discard_img_id += coco_instance.coco.catToImgs[cat_id]
         | 
| 282 | 
            +
             | 
| 283 | 
            +
                coco_clip_filtered = CocoCaptions_clip_filtered(root=MyPath.db_root_dir("coco_train"), annFile=MyPath.db_root_dir("coco_caption_train"),
         | 
| 284 | 
            +
                                            regenerate=False)
         | 
| 285 | 
            +
                coco_clip_filtered_ids = coco_clip_filtered.ids
         | 
| 286 | 
            +
                new_ids = set(coco_clip_filtered_ids) - set(discard_img_id)
         | 
| 287 | 
            +
                new_ids = list(new_ids)
         | 
| 288 | 
            +
                new_ids = random.sample(new_ids, 100)
         | 
| 289 | 
            +
                with open("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/data/coco/coco_clip_filtered_subset100.pickle", "wb") as f:
         | 
| 290 | 
            +
                    pickle.dump(new_ids, f)
         | 
| 291 | 
            +
             | 
| 292 | 
            +
            if __name__ == "__main__":
         | 
| 293 | 
            +
                from mypath import MyPath
         | 
| 294 | 
            +
                import random
         | 
| 295 | 
            +
                # get_validation_set()
         | 
| 296 | 
            +
                # coco_filtered_remian_id = pickle.load(open("data/coco/coco_clip_filtered_ids.pickle", "rb"))
         | 
| 297 | 
            +
                #
         | 
| 298 | 
            +
                # coco_filtered_subset100 = random.sample(coco_filtered_remian_id, 100)
         | 
| 299 | 
            +
                # save_path = "data/coco/coco_clip_filtered_subset100.pickle"
         | 
| 300 | 
            +
                # with open(save_path, "wb") as f:
         | 
| 301 | 
            +
                #     pickle.dump(coco_filtered_subset100, f)
         | 
| 302 | 
            +
             | 
| 303 | 
            +
                # dataset = CocoCaptions_clip_filtered(root=MyPath.db_root_dir("coco_train"), annFile=MyPath.db_root_dir("coco_caption_train"),
         | 
| 304 | 
            +
                #                                 regenerate=False)
         | 
| 305 | 
            +
                dataset = CustomCocoCaptions(root=MyPath.db_root_dir("coco_val"), annFile=MyPath.db_root_dir("coco_caption_val"),
         | 
| 306 | 
            +
                                             custom_file="/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/jomat-code/filtering/ms_coco_captions_testset100.txt")
         | 
| 307 | 
            +
                dataset[0]
         | 
    	
        custom_datasets/custom_caption.py
    ADDED
    
    | @@ -0,0 +1,113 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Authors: Hui Ren (rhfeiyang.github.io)
         | 
| 2 | 
            +
            import torch
         | 
| 3 | 
            +
            import pandas as pd
         | 
| 4 | 
            +
            import numpy as np
         | 
| 5 | 
            +
            import os
         | 
| 6 | 
            +
            from PIL import Image
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            class Caption_set(torch.utils.data.Dataset):
         | 
| 9 | 
            +
             | 
| 10 | 
            +
                style_set_names=[
         | 
| 11 | 
            +
                    "andre-derain_subset1",
         | 
| 12 | 
            +
                    "andy_subset1",
         | 
| 13 | 
            +
                    "camille-corot_subset1",
         | 
| 14 | 
            +
                    "gerhard-richter_subset1",
         | 
| 15 | 
            +
                    "henri-matisse_subset1",
         | 
| 16 | 
            +
                    "katsushika-hokusai_subset1",
         | 
| 17 | 
            +
                    "klimt_subset3",
         | 
| 18 | 
            +
                    "monet_subset2",
         | 
| 19 | 
            +
                    "picasso_subset1",
         | 
| 20 | 
            +
                    "van_gogh_subset1",
         | 
| 21 | 
            +
                ]
         | 
| 22 | 
            +
                style_set_map={f"{name}":f"/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/Style_captions/{name}/style_captions.csv" for name in style_set_names}
         | 
| 23 | 
            +
             | 
| 24 | 
            +
                def __init__(self, prompts_path=None, set_name=None, transform=None):
         | 
| 25 | 
            +
                    assert prompts_path is not None or set_name is not None, "Either prompts_path or set_name should be provided"
         | 
| 26 | 
            +
                    if prompts_path is None:
         | 
| 27 | 
            +
                        prompts_path = self.style_set_map[set_name]
         | 
| 28 | 
            +
             | 
| 29 | 
            +
                    self.prompts = pd.read_csv(prompts_path, delimiter=';')
         | 
| 30 | 
            +
                    self.transform = transform
         | 
| 31 | 
            +
                def __len__(self):
         | 
| 32 | 
            +
                    return len(self.prompts)
         | 
| 33 | 
            +
                def __getitem__(self, idx):
         | 
| 34 | 
            +
                    ret={}
         | 
| 35 | 
            +
                    ret["id"] = idx
         | 
| 36 | 
            +
                    info = self.prompts.iloc[idx]
         | 
| 37 | 
            +
                    ret.update(info)
         | 
| 38 | 
            +
                    for k,v in ret.items():
         | 
| 39 | 
            +
                        if isinstance(v,np.int64):
         | 
| 40 | 
            +
                            ret[k] = int(v)
         | 
| 41 | 
            +
                    ret["caption"] = [ret["caption"]]
         | 
| 42 | 
            +
                    if self.transform:
         | 
| 43 | 
            +
                        ret = self.transform(ret)
         | 
| 44 | 
            +
                    return ret
         | 
| 45 | 
            +
             | 
| 46 | 
            +
                def with_transform(self, transform):
         | 
| 47 | 
            +
                    self.transform = transform
         | 
| 48 | 
            +
                    return self
         | 
| 49 | 
            +
             | 
| 50 | 
            +
             | 
| 51 | 
            +
            class HRS_caption(Caption_set):
         | 
| 52 | 
            +
                def __init__(self, prompts_path="/vision-nfs/torralba/projects/jomat/hui/stable_diffusion/clip_dissection/Style_captions/andre-derain_subset1/style_captions.csv", transform=None, delimiter=','):
         | 
| 53 | 
            +
                    self.prompts = pd.read_csv(prompts_path, delimiter=delimiter)
         | 
| 54 | 
            +
                    self.transform = transform
         | 
| 55 | 
            +
                    self.caption_key = "original_prompts"
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                def __getitem__(self, idx):
         | 
| 58 | 
            +
                    ret={}
         | 
| 59 | 
            +
                    ret["id"] = idx
         | 
| 60 | 
            +
                    info = self.prompts.iloc[idx]
         | 
| 61 | 
            +
                    ret["caption"] = [info[self.caption_key]]
         | 
| 62 | 
            +
                    ret["seed"] = idx
         | 
| 63 | 
            +
                    if self.transform:
         | 
| 64 | 
            +
                        ret = self.transform(ret)
         | 
| 65 | 
            +
                    return ret
         | 
| 66 | 
            +
             | 
| 67 | 
            +
            class Laion_pop(torch.utils.data.Dataset):
         | 
| 68 | 
            +
                def __init__(self, anno_file="/vision-nfs/torralba/projects/jomat/hui/stable_diffusion/custom_datasets/laion_pop500.csv",image_root="/vision-nfs/torralba/scratch/jomat/sam_dataset/laion_pop",transform=None):
         | 
| 69 | 
            +
                    self.transform = transform
         | 
| 70 | 
            +
                    self.info = pd.read_csv(anno_file, delimiter=";")
         | 
| 71 | 
            +
                    self.caption_key = "caption"
         | 
| 72 | 
            +
                    self.image_root = image_root
         | 
| 73 | 
            +
                    self.get_img=True
         | 
| 74 | 
            +
                    self.get_caption=True
         | 
| 75 | 
            +
                def __len__(self):
         | 
| 76 | 
            +
                    return len(self.info)
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                # def subsample(self, num:int):
         | 
| 79 | 
            +
                #     self.data = self.data.select(range(num))
         | 
| 80 | 
            +
                #     return self
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                def load_image(self, key):
         | 
| 83 | 
            +
                    image_path = os.path.join(self.image_root, f"{key:09}.jpg")
         | 
| 84 | 
            +
                    with open(image_path, "rb") as f:
         | 
| 85 | 
            +
                        image = Image.open(f).convert("RGB")
         | 
| 86 | 
            +
                    return image
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                def __getitem__(self, idx):
         | 
| 89 | 
            +
                    info = self.info.iloc[idx]
         | 
| 90 | 
            +
                    ret = {}
         | 
| 91 | 
            +
                    key = info["key"]
         | 
| 92 | 
            +
                    ret["id"] = key
         | 
| 93 | 
            +
                    if self.get_caption:
         | 
| 94 | 
            +
                        ret["caption"] = [info[self.caption_key]]
         | 
| 95 | 
            +
                    ret["seed"] = int(key)
         | 
| 96 | 
            +
                    if self.get_img:
         | 
| 97 | 
            +
                        ret["image"] = self.load_image(key)
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                    if self.transform:
         | 
| 100 | 
            +
                        ret = self.transform(ret)
         | 
| 101 | 
            +
                    return ret
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                def with_transform(self, transform):
         | 
| 104 | 
            +
                    self.transform = transform
         | 
| 105 | 
            +
                    return self
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                def subset(self, ids:list):
         | 
| 108 | 
            +
                    self.info = self.info[self.info["key"].isin(ids)]
         | 
| 109 | 
            +
                    return self
         | 
| 110 | 
            +
             | 
| 111 | 
            +
            if __name__ == "__main__":
         | 
| 112 | 
            +
                dataset = Caption_set("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/Style_captions/andre-derain_subset1/style_captions.csv")
         | 
| 113 | 
            +
                dataset[0]
         | 
    	
        custom_datasets/filt/coco/filt.py
    ADDED
    
    | @@ -0,0 +1,186 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Authors: Hui Ren (rhfeiyang.github.io)
         | 
| 2 | 
            +
            import os
         | 
| 3 | 
            +
            import sys
         | 
| 4 | 
            +
            import numpy as np
         | 
| 5 | 
            +
            from PIL import Image
         | 
| 6 | 
            +
            import pickle
         | 
| 7 | 
            +
            sys.path.append(os.path.join(os.path.dirname(__file__), "../../../"))
         | 
| 8 | 
            +
            from custom_datasets import get_dataset
         | 
| 9 | 
            +
            from utils.art_filter import Art_filter
         | 
| 10 | 
            +
            import torch
         | 
| 11 | 
            +
            from matplotlib import pyplot as plt
         | 
| 12 | 
            +
            import math
         | 
| 13 | 
            +
            import argparse
         | 
| 14 | 
            +
            import socket
         | 
| 15 | 
            +
            import time
         | 
| 16 | 
            +
            from tqdm import tqdm
         | 
| 17 | 
            +
            import torch
         | 
| 18 | 
            +
            def parse_args():
         | 
| 19 | 
            +
                parser = argparse.ArgumentParser(description="Filter the coco dataset")
         | 
| 20 | 
            +
                parser.add_argument("--check", action="store_true", help="Check the complete")
         | 
| 21 | 
            +
                parser.add_argument("--mode", default="clip_logit", help="Filter mode: clip_logit, clip_filt, caption_filt")
         | 
| 22 | 
            +
                parser.add_argument("--split" , default="val", help="Dataset split, val/train")
         | 
| 23 | 
            +
                # parser.add_argument("--start_idx", default=0, type=int, help="Start index")
         | 
| 24 | 
            +
                args = parser.parse_args()
         | 
| 25 | 
            +
                return args
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            def get_feat(save_path, dataloader, filter):
         | 
| 28 | 
            +
                clip_feat_file = save_path
         | 
| 29 | 
            +
                # compute_new = False
         | 
| 30 | 
            +
                clip_feat={}
         | 
| 31 | 
            +
                if os.path.exists(clip_feat_file):
         | 
| 32 | 
            +
                    with open(clip_feat_file, 'rb') as f:
         | 
| 33 | 
            +
                        clip_feat = pickle.load(f)
         | 
| 34 | 
            +
                else:
         | 
| 35 | 
            +
                    print(f"computing clip feat",flush=True)
         | 
| 36 | 
            +
                    clip_feature_ret = filter.clip_feature(dataloader)
         | 
| 37 | 
            +
                    clip_feat["image_features"] = clip_feature_ret["clip_features"]
         | 
| 38 | 
            +
                    clip_feat["ids"] = clip_feature_ret["ids"]
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                    with open(clip_feat_file, 'wb') as f:
         | 
| 41 | 
            +
                        pickle.dump(clip_feat, f)
         | 
| 42 | 
            +
                    print(f"clip_feat_result saved to {clip_feat_file}",flush=True)
         | 
| 43 | 
            +
                return clip_feat
         | 
| 44 | 
            +
             | 
| 45 | 
            +
            def get_clip_logit(save_root, dataloader, filter):
         | 
| 46 | 
            +
                feat_path = os.path.join(save_root, "clip_feat.pickle")
         | 
| 47 | 
            +
                clip_feat = get_feat(feat_path, dataloader, filter)
         | 
| 48 | 
            +
                clip_logits_file = os.path.join(save_root, "clip_logits.pickle")
         | 
| 49 | 
            +
                # if clip_logit:
         | 
| 50 | 
            +
                if os.path.exists(clip_logits_file):
         | 
| 51 | 
            +
                    with open(clip_logits_file, 'rb') as f:
         | 
| 52 | 
            +
                        clip_logits = pickle.load(f)
         | 
| 53 | 
            +
                else:
         | 
| 54 | 
            +
                    clip_logits = filter.clip_logit_by_feat(clip_feat["image_features"])
         | 
| 55 | 
            +
                    clip_logits["ids"] = clip_feat["ids"]
         | 
| 56 | 
            +
                    with open(clip_logits_file, 'wb') as f:
         | 
| 57 | 
            +
                        pickle.dump(clip_logits, f)
         | 
| 58 | 
            +
                    print(f"clip_logits_result saved to {clip_logits_file}",flush=True)
         | 
| 59 | 
            +
                return clip_logits
         | 
| 60 | 
            +
             | 
| 61 | 
            +
            def clip_filt(save_root, dataloader, filter):
         | 
| 62 | 
            +
                clip_filt_file = os.path.join(save_root, "clip_filt_result.pickle")
         | 
| 63 | 
            +
                if os.path.exists(clip_filt_file):
         | 
| 64 | 
            +
                    with open(clip_filt_file, 'rb') as f:
         | 
| 65 | 
            +
                        clip_filt_result = pickle.load(f)
         | 
| 66 | 
            +
                else:
         | 
| 67 | 
            +
                    clip_logits = get_clip_logit(save_root, dataloader, filter)
         | 
| 68 | 
            +
                    clip_filt_result = filter.clip_filt(clip_logits)
         | 
| 69 | 
            +
                    with open(clip_filt_file, 'wb') as f:
         | 
| 70 | 
            +
                        pickle.dump(clip_filt_result, f)
         | 
| 71 | 
            +
                    print(f"clip_filt_result saved to {clip_filt_file}",flush=True)
         | 
| 72 | 
            +
                return clip_filt_result
         | 
| 73 | 
            +
             | 
| 74 | 
            +
            def caption_filt(save_root, dataloader, filter):
         | 
| 75 | 
            +
                caption_filt_file = os.path.join(save_root, "caption_filt_result.pickle")
         | 
| 76 | 
            +
                if os.path.exists(caption_filt_file):
         | 
| 77 | 
            +
                    with open(caption_filt_file, 'rb') as f:
         | 
| 78 | 
            +
                        caption_filt_result = pickle.load(f)
         | 
| 79 | 
            +
                else:
         | 
| 80 | 
            +
                    caption_filt_result = filter.caption_filt(dataloader)
         | 
| 81 | 
            +
                    with open(caption_filt_file, 'wb') as f:
         | 
| 82 | 
            +
                        pickle.dump(caption_filt_result, f)
         | 
| 83 | 
            +
                    print(f"caption_filt_result saved to {caption_filt_file}",flush=True)
         | 
| 84 | 
            +
                return caption_filt_result
         | 
| 85 | 
            +
             | 
| 86 | 
            +
            def gather_result(save_dir, dataloader, filter):
         | 
| 87 | 
            +
                all_remain_ids=[]
         | 
| 88 | 
            +
                all_remain_ids_train=[]
         | 
| 89 | 
            +
                all_remain_ids_val=[]
         | 
| 90 | 
            +
                all_filtered_id_num = 0
         | 
| 91 | 
            +
             | 
| 92 | 
            +
                clip_filt_result = clip_filt(save_dir, dataloader, filter)
         | 
| 93 | 
            +
                caption_filt_result = caption_filt(save_dir, dataloader, filter)
         | 
| 94 | 
            +
             | 
| 95 | 
            +
                caption_filtered_ids = [i[0] for i in caption_filt_result["filtered_ids"]]
         | 
| 96 | 
            +
                all_filtered_id_num += len(set(clip_filt_result["filtered_ids"]) | set(caption_filtered_ids) )
         | 
| 97 | 
            +
                remain_ids = set(clip_filt_result["remain_ids"]) & set(caption_filt_result["remain_ids"])
         | 
| 98 | 
            +
                remain_ids = list(remain_ids)
         | 
| 99 | 
            +
                remain_ids.sort()
         | 
| 100 | 
            +
                with open(os.path.join(save_dir, "remain_ids.pickle"), 'wb') as f:
         | 
| 101 | 
            +
                    pickle.dump(remain_ids, f)
         | 
| 102 | 
            +
                print(f"remain_ids saved to {save_dir}/remain_ids.pickle",flush=True)
         | 
| 103 | 
            +
                return remain_ids
         | 
| 104 | 
            +
             | 
| 105 | 
            +
            @torch.no_grad()
         | 
| 106 | 
            +
            def main(args):
         | 
| 107 | 
            +
                filter = Art_filter()
         | 
| 108 | 
            +
                if args.mode == "caption_filt" or args.mode == "gather_result":
         | 
| 109 | 
            +
                    filter.clip_filter = None
         | 
| 110 | 
            +
                    torch.cuda.empty_cache()
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                # caption_folder_path = "/vision-nfs/torralba/scratch/jomat/sam_dataset/PixArt-alpha/captions"
         | 
| 113 | 
            +
                # image_folder_path = "/vision-nfs/torralba/scratch/jomat/sam_dataset/images"
         | 
| 114 | 
            +
                # id_dict_dir = "/vision-nfs/torralba/scratch/jomat/sam_dataset/images/id_dict"
         | 
| 115 | 
            +
                # filt_dir = "/vision-nfs/torralba/scratch/jomat/sam_dataset/filt_result"
         | 
| 116 | 
            +
             | 
| 117 | 
            +
                def collate_fn(examples):
         | 
| 118 | 
            +
                    # {"image": image, "id":id}
         | 
| 119 | 
            +
                    ret = {}
         | 
| 120 | 
            +
                    if "image" in examples[0]:
         | 
| 121 | 
            +
                        pixel_values = [example["image"] for example in examples]
         | 
| 122 | 
            +
                        ret["images"] = pixel_values
         | 
| 123 | 
            +
                    if "caption" in examples[0]:
         | 
| 124 | 
            +
                        # prompts = [example["caption"] for example in examples]
         | 
| 125 | 
            +
                        prompts = []
         | 
| 126 | 
            +
                        for example in examples:
         | 
| 127 | 
            +
                            if isinstance(example["caption"][0], list):
         | 
| 128 | 
            +
                                prompts.append([" ".join(example["caption"][0])])
         | 
| 129 | 
            +
                            else:
         | 
| 130 | 
            +
                                prompts.append(example["caption"])
         | 
| 131 | 
            +
                        ret["text"] = prompts
         | 
| 132 | 
            +
                    id = [example["id"] for example in examples]
         | 
| 133 | 
            +
                    ret["ids"] = id
         | 
| 134 | 
            +
                    return ret
         | 
| 135 | 
            +
                if args.split == "val":
         | 
| 136 | 
            +
                    dataset = get_dataset("coco_val")["val"]
         | 
| 137 | 
            +
                elif args.split == "train":
         | 
| 138 | 
            +
                    dataset = get_dataset("coco_train", get_val=False)["train"]
         | 
| 139 | 
            +
                dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=False, num_workers=8, collate_fn=collate_fn)
         | 
| 140 | 
            +
             | 
| 141 | 
            +
                error_files=[]
         | 
| 142 | 
            +
             | 
| 143 | 
            +
             | 
| 144 | 
            +
             | 
| 145 | 
            +
                save_root = f"/vision-nfs/torralba/scratch/jomat/sam_dataset/coco/filt/{args.split}"
         | 
| 146 | 
            +
                os.makedirs(save_root, exist_ok=True)
         | 
| 147 | 
            +
             | 
| 148 | 
            +
                if args.mode == "clip_feat":
         | 
| 149 | 
            +
                    feat_path = os.path.join(save_root, "clip_feat.pickle")
         | 
| 150 | 
            +
                    clip_feat = get_feat(feat_path, dataloader, filter)
         | 
| 151 | 
            +
             | 
| 152 | 
            +
                if args.mode == "clip_logit":
         | 
| 153 | 
            +
                    clip_logit = get_clip_logit(save_root, dataloader, filter)
         | 
| 154 | 
            +
             | 
| 155 | 
            +
                if args.mode == "clip_filt":
         | 
| 156 | 
            +
                    # if os.path.exists(clip_filt_file):
         | 
| 157 | 
            +
                    #     with open(clip_filt_file, 'rb') as f:
         | 
| 158 | 
            +
                    #         ret = pickle.load(f)
         | 
| 159 | 
            +
                    # else:
         | 
| 160 | 
            +
                    clip_filt_result = clip_filt(save_root, dataloader, filter)
         | 
| 161 | 
            +
             | 
| 162 | 
            +
                if args.mode == "caption_filt":
         | 
| 163 | 
            +
                    caption_filt_result = caption_filt(save_root, dataloader, filter)
         | 
| 164 | 
            +
             | 
| 165 | 
            +
                if args.mode == "gather_result":
         | 
| 166 | 
            +
                    filtered_result = gather_result(save_root, dataloader, filter)
         | 
| 167 | 
            +
             | 
| 168 | 
            +
                print("finished",flush=True)
         | 
| 169 | 
            +
                for file in error_files:
         | 
| 170 | 
            +
                    # os.remove(file)
         | 
| 171 | 
            +
                    print(file,flush=True)
         | 
| 172 | 
            +
             | 
| 173 | 
            +
            if __name__ == "__main__":
         | 
| 174 | 
            +
                args = parse_args()
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                log_file = "sam_filt"
         | 
| 177 | 
            +
                idx=0
         | 
| 178 | 
            +
                hostname = socket.gethostname()
         | 
| 179 | 
            +
                now_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
         | 
| 180 | 
            +
                while os.path.exists(f"{log_file}_{hostname}_check{args.check}_{now_time}_{idx}.log"):
         | 
| 181 | 
            +
                    idx+=1
         | 
| 182 | 
            +
             | 
| 183 | 
            +
                main(args)
         | 
| 184 | 
            +
                # clip_logits_analysis()
         | 
| 185 | 
            +
             | 
| 186 | 
            +
             | 
    	
        custom_datasets/filt/sam_filt.py
    ADDED
    
    | @@ -0,0 +1,299 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Authors: Hui Ren (rhfeiyang.github.io)
         | 
| 2 | 
            +
            import os
         | 
| 3 | 
            +
            import sys
         | 
| 4 | 
            +
            import numpy as np
         | 
| 5 | 
            +
            from PIL import Image
         | 
| 6 | 
            +
            import pickle
         | 
| 7 | 
            +
            sys.path.append(os.path.join(os.path.dirname(__file__), "../../"))
         | 
| 8 | 
            +
            from custom_datasets.sam import SamDataset
         | 
| 9 | 
            +
            from utils.art_filter import Art_filter
         | 
| 10 | 
            +
            import torch
         | 
| 11 | 
            +
            from matplotlib import pyplot as plt
         | 
| 12 | 
            +
            import math
         | 
| 13 | 
            +
            import argparse
         | 
| 14 | 
            +
            import socket
         | 
| 15 | 
            +
            import time
         | 
| 16 | 
            +
            from tqdm import tqdm
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            def parse_args():
         | 
| 19 | 
            +
                parser = argparse.ArgumentParser(description="Filter the sam dataset")
         | 
| 20 | 
            +
                parser.add_argument("--check", action="store_true", help="Check the complete")
         | 
| 21 | 
            +
                parser.add_argument("--mode", default="clip_logit",  choices=["clip_logit_update","clip_logit", "clip_filt", "caption_filt", "gather_result","caption_flit_append"])
         | 
| 22 | 
            +
                parser.add_argument("--start_idx", default=0, type=int, help="Start index")
         | 
| 23 | 
            +
                parser.add_argument("--end_idx", default=9e10, type=int, help="Start index")
         | 
| 24 | 
            +
                args = parser.parse_args()
         | 
| 25 | 
            +
                return args
         | 
| 26 | 
            +
            @torch.no_grad()
         | 
| 27 | 
            +
            def main(args):
         | 
| 28 | 
            +
                filter = Art_filter()
         | 
| 29 | 
            +
                if args.mode == "caption_filt" or args.mode == "gather_result":
         | 
| 30 | 
            +
                    filter.clip_filter = None
         | 
| 31 | 
            +
                    torch.cuda.empty_cache()
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                caption_folder_path = "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/SAM/subset/captions"
         | 
| 34 | 
            +
                image_folder_path = "/vision-nfs/torralba/scratch/jomat/sam_dataset/nfs-data/sam/images"
         | 
| 35 | 
            +
                id_dict_dir = "/vision-nfs/torralba/scratch/jomat/sam_dataset/sam_ids/8.16/id_dict"
         | 
| 36 | 
            +
                filt_dir = "/vision-nfs/torralba/scratch/jomat/sam_dataset/filt_result"
         | 
| 37 | 
            +
                def collate_fn(examples):
         | 
| 38 | 
            +
                    # {"image": image, "id":id}
         | 
| 39 | 
            +
                    ret = {}
         | 
| 40 | 
            +
                    if "image" in examples[0]:
         | 
| 41 | 
            +
                        pixel_values = [example["image"] for example in examples]
         | 
| 42 | 
            +
                        ret["images"] = pixel_values
         | 
| 43 | 
            +
                    if "text" in examples[0]:
         | 
| 44 | 
            +
                        prompts = [example["text"] for example in examples]
         | 
| 45 | 
            +
                        ret["text"] = prompts
         | 
| 46 | 
            +
                    id = [example["id"] for example in examples]
         | 
| 47 | 
            +
                    ret["ids"] = id
         | 
| 48 | 
            +
                    return ret
         | 
| 49 | 
            +
                error_files=[]
         | 
| 50 | 
            +
                val_set = ["sa_000000"]
         | 
| 51 | 
            +
                result_check_set = ["sa_000020"]
         | 
| 52 | 
            +
                all_remain_ids=[]
         | 
| 53 | 
            +
                all_remain_ids_train=[]
         | 
| 54 | 
            +
                all_remain_ids_val=[]
         | 
| 55 | 
            +
                all_filtered_id_num = 0
         | 
| 56 | 
            +
                remain_feat_num = 0
         | 
| 57 | 
            +
                remain_caption_num = 0
         | 
| 58 | 
            +
                filter_feat_num = 0
         | 
| 59 | 
            +
                filter_caption_num = 0
         | 
| 60 | 
            +
                for idx,file in tqdm(enumerate(sorted(os.listdir(id_dict_dir)))):
         | 
| 61 | 
            +
                    if idx < args.start_idx or idx >= args.end_idx:
         | 
| 62 | 
            +
                        continue
         | 
| 63 | 
            +
                    if file.endswith(".pickle") and not file.startswith("all"):
         | 
| 64 | 
            +
                        print("=====================================")
         | 
| 65 | 
            +
                        print(file,flush=True)
         | 
| 66 | 
            +
                        save_dir = os.path.join(filt_dir, file.replace("_id_dict.pickle", ""))
         | 
| 67 | 
            +
                        if not os.path.exists(save_dir):
         | 
| 68 | 
            +
                            os.makedirs(save_dir, exist_ok=True)
         | 
| 69 | 
            +
                        id_dict_file = os.path.join(id_dict_dir, file)
         | 
| 70 | 
            +
                        with open(id_dict_file, 'rb') as f:
         | 
| 71 | 
            +
                            id_dict = pickle.load(f)
         | 
| 72 | 
            +
                        ids = list(id_dict.keys())
         | 
| 73 | 
            +
                        dataset = SamDataset(image_folder_path, caption_folder_path, id_file=ids, id_dict_file=id_dict_file)
         | 
| 74 | 
            +
                        # dataset = SamDataset(image_folder_path, caption_folder_path, id_file=[10061410, 10076945, 10310013,1042012, 4487809, 4541052], id_dict_file="/vision-nfs/torralba/scratch/jomat/sam_dataset/images/id_dict/all_id_dict.pickle")
         | 
| 75 | 
            +
                        dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=False, num_workers=8, collate_fn=collate_fn)
         | 
| 76 | 
            +
                        clip_logits = None
         | 
| 77 | 
            +
                        clip_logits_file = os.path.join(save_dir, "clip_logits_result.pickle")
         | 
| 78 | 
            +
                        clip_filt_file = os.path.join(save_dir, "clip_filt_result.pickle")
         | 
| 79 | 
            +
                        caption_filt_file = os.path.join(save_dir, "caption_filt_result.pickle")
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                        if args.mode == "clip_feat":
         | 
| 82 | 
            +
                            compute_new = False
         | 
| 83 | 
            +
                            clip_logits = {}
         | 
| 84 | 
            +
                            if os.path.exists(clip_logits_file):
         | 
| 85 | 
            +
                                with open(clip_logits_file, 'rb') as f:
         | 
| 86 | 
            +
                                    clip_logits = pickle.load(f)
         | 
| 87 | 
            +
                                if "image_features" not in clip_logits:
         | 
| 88 | 
            +
                                    compute_new = True
         | 
| 89 | 
            +
                            else:
         | 
| 90 | 
            +
                                compute_new=True
         | 
| 91 | 
            +
                            if compute_new:
         | 
| 92 | 
            +
                                if clip_logits == '':
         | 
| 93 | 
            +
                                    clip_logits = {}
         | 
| 94 | 
            +
                                print(f"compute clip_feat {file}",flush=True)
         | 
| 95 | 
            +
                                clip_feature_ret = filter.clip_feature(dataloader)
         | 
| 96 | 
            +
                                clip_logits["image_features"] = clip_feature_ret["clip_features"]
         | 
| 97 | 
            +
                                if "ids" in clip_logits:
         | 
| 98 | 
            +
                                    assert clip_feature_ret["ids"] == clip_logits["ids"]
         | 
| 99 | 
            +
                                else:
         | 
| 100 | 
            +
                                    clip_logits["ids"] = clip_feature_ret["ids"]
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                                with open(clip_logits_file, 'wb') as f:
         | 
| 103 | 
            +
                                    pickle.dump(clip_logits, f)
         | 
| 104 | 
            +
                                print(f"clip_feat_result saved to {clip_logits_file}",flush=True)
         | 
| 105 | 
            +
                            else:
         | 
| 106 | 
            +
                                print(f"skip {clip_logits_file}",flush=True)
         | 
| 107 | 
            +
             | 
| 108 | 
            +
                        if args.mode == "clip_logit":
         | 
| 109 | 
            +
                        # if clip_logit:
         | 
| 110 | 
            +
                            if os.path.exists(clip_logits_file):
         | 
| 111 | 
            +
                                try:
         | 
| 112 | 
            +
                                    with open(clip_logits_file, 'rb') as f:
         | 
| 113 | 
            +
                                        clip_logits = pickle.load(f)
         | 
| 114 | 
            +
                                except:
         | 
| 115 | 
            +
                                    continue
         | 
| 116 | 
            +
                                skip = True
         | 
| 117 | 
            +
                                if args.check and clip_logits=="":
         | 
| 118 | 
            +
                                    skip = False
         | 
| 119 | 
            +
             | 
| 120 | 
            +
                            else:
         | 
| 121 | 
            +
                                skip = False
         | 
| 122 | 
            +
                            # skip = False
         | 
| 123 | 
            +
                            if not skip:
         | 
| 124 | 
            +
                                # os.makedirs(os.path.join(save_dir, "tmp"), exist_ok=True)
         | 
| 125 | 
            +
                                with open(clip_logits_file, 'wb') as f:
         | 
| 126 | 
            +
                                    pickle.dump("", f)
         | 
| 127 | 
            +
                                try:
         | 
| 128 | 
            +
                                    clip_logits = filter.clip_logit(dataloader)
         | 
| 129 | 
            +
                                except:
         | 
| 130 | 
            +
                                    print(f"Error in clip_logit {file}",flush=True)
         | 
| 131 | 
            +
                                    continue
         | 
| 132 | 
            +
                                with open(clip_logits_file, 'wb') as f:
         | 
| 133 | 
            +
                                    pickle.dump(clip_logits, f)
         | 
| 134 | 
            +
                                print(f"clip_logits_result saved to {clip_logits_file}",flush=True)
         | 
| 135 | 
            +
                            else:
         | 
| 136 | 
            +
                                print(f"skip {clip_logits_file}",flush=True)
         | 
| 137 | 
            +
             | 
| 138 | 
            +
                        if args.mode == "clip_logit_update":
         | 
| 139 | 
            +
                            if os.path.exists(clip_logits_file):
         | 
| 140 | 
            +
                                with open(clip_logits_file, 'rb') as f:
         | 
| 141 | 
            +
                                    clip_logits = pickle.load(f)
         | 
| 142 | 
            +
                            else:
         | 
| 143 | 
            +
                                print(f"{clip_logits_file} not exist",flush=True)
         | 
| 144 | 
            +
                                continue
         | 
| 145 | 
            +
                            if clip_logits == "":
         | 
| 146 | 
            +
                                print(f"skip {clip_logits_file}",flush=True)
         | 
| 147 | 
            +
                                continue
         | 
| 148 | 
            +
                            ret = filter.clip_logit_by_feat(clip_logits["clip_features"])
         | 
| 149 | 
            +
                            # assert (clip_logits["clip_logits"] - ret["clip_logits"]).abs().max() < 0.01
         | 
| 150 | 
            +
                            clip_logits["clip_logits"] = ret["clip_logits"]
         | 
| 151 | 
            +
                            clip_logits["text"] = ret["text"]
         | 
| 152 | 
            +
                            with open(clip_logits_file, 'wb') as f:
         | 
| 153 | 
            +
                                pickle.dump(clip_logits, f)
         | 
| 154 | 
            +
             | 
| 155 | 
            +
             | 
| 156 | 
            +
                        if args.mode == "clip_filt":
         | 
| 157 | 
            +
                            # if os.path.exists(clip_filt_file):
         | 
| 158 | 
            +
                            #     with open(clip_filt_file, 'rb') as f:
         | 
| 159 | 
            +
                            #         ret = pickle.load(f)
         | 
| 160 | 
            +
                            # else:
         | 
| 161 | 
            +
             | 
| 162 | 
            +
                            if clip_logits is None:
         | 
| 163 | 
            +
                                try:
         | 
| 164 | 
            +
                                    with open(clip_logits_file, 'rb') as f:
         | 
| 165 | 
            +
                                        clip_logits = pickle.load(f)
         | 
| 166 | 
            +
                                except:
         | 
| 167 | 
            +
                                    print(f"Error in loading {clip_logits_file}",flush=True)
         | 
| 168 | 
            +
                                    error_files.append(clip_logits_file)
         | 
| 169 | 
            +
                                    continue
         | 
| 170 | 
            +
                                if clip_logits == "":
         | 
| 171 | 
            +
                                    print(f"skip {clip_logits_file}",flush=True)
         | 
| 172 | 
            +
                                    error_files.append(clip_logits_file)
         | 
| 173 | 
            +
                                    continue
         | 
| 174 | 
            +
                            clip_filt_result = filter.clip_filt(clip_logits)
         | 
| 175 | 
            +
                            with open(clip_filt_file, 'wb') as f:
         | 
| 176 | 
            +
                                pickle.dump(clip_filt_result, f)
         | 
| 177 | 
            +
                            print(f"clip_filt_result saved to {clip_filt_file}",flush=True)
         | 
| 178 | 
            +
             | 
| 179 | 
            +
                        if args.mode == "caption_filt":
         | 
| 180 | 
            +
                            if os.path.exists(caption_filt_file):
         | 
| 181 | 
            +
                                try:
         | 
| 182 | 
            +
                                    with open(caption_filt_file, 'rb') as f:
         | 
| 183 | 
            +
                                        ret = pickle.load(f)
         | 
| 184 | 
            +
                                except:
         | 
| 185 | 
            +
                                    continue
         | 
| 186 | 
            +
                                skip = True
         | 
| 187 | 
            +
                                if args.check and ret=="":
         | 
| 188 | 
            +
                                    skip = False
         | 
| 189 | 
            +
                                    # os.remove(caption_filt_file)
         | 
| 190 | 
            +
                                    print(f"empty {caption_filt_file}",flush=True)
         | 
| 191 | 
            +
                                    # skip = True
         | 
| 192 | 
            +
                            else:
         | 
| 193 | 
            +
                                skip = False
         | 
| 194 | 
            +
                            if not skip:
         | 
| 195 | 
            +
                                with open(caption_filt_file, 'wb') as f:
         | 
| 196 | 
            +
                                    pickle.dump("", f)
         | 
| 197 | 
            +
                                # try:
         | 
| 198 | 
            +
                                ret = filter.caption_filt(dataloader)
         | 
| 199 | 
            +
                                # except:
         | 
| 200 | 
            +
                                #     print(f"Error in filtering {file}",flush=True)
         | 
| 201 | 
            +
                                #     continue
         | 
| 202 | 
            +
                                with open(caption_filt_file, 'wb') as f:
         | 
| 203 | 
            +
                                    pickle.dump(ret, f)
         | 
| 204 | 
            +
                                print(f"caption_filt_result saved to {caption_filt_file}",flush=True)
         | 
| 205 | 
            +
                            else:
         | 
| 206 | 
            +
                                print(f"skip {caption_filt_file}",flush=True)
         | 
| 207 | 
            +
             | 
| 208 | 
            +
                        if args.mode == "caption_flit_append":
         | 
| 209 | 
            +
                            if not os.path.exists(caption_filt_file):
         | 
| 210 | 
            +
                                print(f"{caption_filt_file} not exist",flush=True)
         | 
| 211 | 
            +
                                continue
         | 
| 212 | 
            +
                            with open(caption_filt_file, 'rb') as f:
         | 
| 213 | 
            +
                                old_caption_filt_result = pickle.load(f)
         | 
| 214 | 
            +
                            skip = True
         | 
| 215 | 
            +
                            for i in filter.caption_filter.filter_prompts:
         | 
| 216 | 
            +
                                if i not in old_caption_filt_result["filter_prompts"]:
         | 
| 217 | 
            +
                                    skip = False
         | 
| 218 | 
            +
                                    break
         | 
| 219 | 
            +
                            if skip:
         | 
| 220 | 
            +
                                print(f"skip {caption_filt_file}",flush=True)
         | 
| 221 | 
            +
                                continue
         | 
| 222 | 
            +
                            old_remain_ids = old_caption_filt_result["remain_ids"]
         | 
| 223 | 
            +
                            new_dataset = SamDataset(image_folder_path, caption_folder_path, id_file=old_remain_ids, id_dict_file=id_dict_file)
         | 
| 224 | 
            +
                            new_dataloader = torch.utils.data.DataLoader(new_dataset, batch_size=64, shuffle=False, num_workers=8, collate_fn=collate_fn)
         | 
| 225 | 
            +
                            ret = filter.caption_filt(new_dataloader)
         | 
| 226 | 
            +
                            old_caption_filt_result["remain_ids"] = ret["remain_ids"]
         | 
| 227 | 
            +
                            old_caption_filt_result["filtered_ids"].extend(ret["filtered_ids"])
         | 
| 228 | 
            +
                            new_filter_count = ret["filter_count"].copy()
         | 
| 229 | 
            +
                            for i in range(len(old_caption_filt_result["filter_count"])):
         | 
| 230 | 
            +
                                new_filter_count[i] += old_caption_filt_result["filter_count"][i]
         | 
| 231 | 
            +
             | 
| 232 | 
            +
                            old_caption_filt_result["filter_count"] = new_filter_count
         | 
| 233 | 
            +
                            old_caption_filt_result["filter_prompts"] = ret["filter_prompts"]
         | 
| 234 | 
            +
                            with open(caption_filt_file, 'wb') as f:
         | 
| 235 | 
            +
                                pickle.dump(old_caption_filt_result, f)
         | 
| 236 | 
            +
             | 
| 237 | 
            +
             | 
| 238 | 
            +
             | 
| 239 | 
            +
                        if args.mode == "gather_result":
         | 
| 240 | 
            +
                            with open(clip_filt_file, 'rb') as f:
         | 
| 241 | 
            +
                                clip_filt_result = pickle.load(f)
         | 
| 242 | 
            +
                            with open(caption_filt_file, 'rb') as f:
         | 
| 243 | 
            +
                                caption_filt_result = pickle.load(f)
         | 
| 244 | 
            +
                            caption_filtered_ids = [i[0] for i in caption_filt_result["filtered_ids"]]
         | 
| 245 | 
            +
                            all_filtered_id_num += len(set(clip_filt_result["filtered_ids"]) | set(caption_filtered_ids) )
         | 
| 246 | 
            +
             | 
| 247 | 
            +
                            remain_feat_num += len(clip_filt_result["remain_ids"])
         | 
| 248 | 
            +
                            remain_caption_num += len(caption_filt_result["remain_ids"])
         | 
| 249 | 
            +
                            filter_feat_num += len(clip_filt_result["filtered_ids"])
         | 
| 250 | 
            +
                            filter_caption_num += len(caption_filtered_ids)
         | 
| 251 | 
            +
             | 
| 252 | 
            +
                            remain_ids = set(clip_filt_result["remain_ids"]) & set(caption_filt_result["remain_ids"])
         | 
| 253 | 
            +
                            remain_ids = list(remain_ids)
         | 
| 254 | 
            +
                            remain_ids.sort()
         | 
| 255 | 
            +
                            # with open(os.path.join(save_dir, "remain_ids.pickle"), 'wb') as f:
         | 
| 256 | 
            +
                            #     pickle.dump(remain_ids, f)
         | 
| 257 | 
            +
                            # print(f"remain_ids saved to {save_dir}/remain_ids.pickle",flush=True)
         | 
| 258 | 
            +
                            all_remain_ids.extend(remain_ids)
         | 
| 259 | 
            +
                            if file.replace("_id_dict.pickle","") in val_set:
         | 
| 260 | 
            +
                                all_remain_ids_val.extend(remain_ids)
         | 
| 261 | 
            +
                            else:
         | 
| 262 | 
            +
                                all_remain_ids_train.extend(remain_ids)
         | 
| 263 | 
            +
                if args.mode == "gather_result":
         | 
| 264 | 
            +
                    print(f"filtered ids: {all_filtered_id_num}",flush=True)
         | 
| 265 | 
            +
                    print(f"remain feat num: {remain_feat_num}",flush=True)
         | 
| 266 | 
            +
                    print(f"remain caption num: {remain_caption_num}",flush=True)
         | 
| 267 | 
            +
                    print(f"filter feat num: {filter_feat_num}",flush=True)
         | 
| 268 | 
            +
                    print(f"filter caption num: {filter_caption_num}",flush=True)
         | 
| 269 | 
            +
                    all_remain_ids.sort()
         | 
| 270 | 
            +
                    with open(os.path.join(filt_dir, "all_remain_ids.pickle"), 'wb') as f:
         | 
| 271 | 
            +
                        pickle.dump(all_remain_ids, f)
         | 
| 272 | 
            +
                    with open(os.path.join(filt_dir, "all_remain_ids_train.pickle"), 'wb') as f:
         | 
| 273 | 
            +
                        pickle.dump(all_remain_ids_train, f)
         | 
| 274 | 
            +
                    with open(os.path.join(filt_dir, "all_remain_ids_val.pickle"), 'wb') as f:
         | 
| 275 | 
            +
                        pickle.dump(all_remain_ids_val, f)
         | 
| 276 | 
            +
             | 
| 277 | 
            +
                    print(f"all_remain_ids saved to {filt_dir}/all_remain_ids.pickle",flush=True)
         | 
| 278 | 
            +
                    print(f"all_remain_ids_train saved to {filt_dir}/all_remain_ids_train.pickle",flush=True)
         | 
| 279 | 
            +
                    print(f"all_remain_ids_val saved to {filt_dir}/all_remain_ids_val.pickle",flush=True)
         | 
| 280 | 
            +
             | 
| 281 | 
            +
                print("finished",flush=True)
         | 
| 282 | 
            +
                for file in error_files:
         | 
| 283 | 
            +
                    # os.remove(file)
         | 
| 284 | 
            +
                    print(file,flush=True)
         | 
| 285 | 
            +
             | 
| 286 | 
            +
            if __name__ == "__main__":
         | 
| 287 | 
            +
                args = parse_args()
         | 
| 288 | 
            +
             | 
| 289 | 
            +
                log_file = "sam_filt"
         | 
| 290 | 
            +
                idx=0
         | 
| 291 | 
            +
                hostname = socket.gethostname()
         | 
| 292 | 
            +
                now_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
         | 
| 293 | 
            +
                while os.path.exists(f"{log_file}_{hostname}_check{args.check}_{now_time}_{idx}.log"):
         | 
| 294 | 
            +
                    idx+=1
         | 
| 295 | 
            +
             | 
| 296 | 
            +
                main(args)
         | 
| 297 | 
            +
                # clip_logits_analysis()
         | 
| 298 | 
            +
             | 
| 299 | 
            +
             | 
    	
        custom_datasets/imagepair.py
    ADDED
    
    | @@ -0,0 +1,240 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Authors: Hui Ren (rhfeiyang.github.io)
         | 
| 2 | 
            +
            import random
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import torch.utils.data as data
         | 
| 5 | 
            +
            from PIL import Image
         | 
| 6 | 
            +
            import os
         | 
| 7 | 
            +
            import torch
         | 
| 8 | 
            +
            # from tqdm import tqdm
         | 
| 9 | 
            +
            class ImageSet(data.Dataset):
         | 
| 10 | 
            +
                def __init__(self, folder , transform=None, keep_in_mem=True, caption=None):
         | 
| 11 | 
            +
                    self.path = folder
         | 
| 12 | 
            +
                    self.transform = transform
         | 
| 13 | 
            +
                    self.caption_path = None
         | 
| 14 | 
            +
                    self.images = []
         | 
| 15 | 
            +
                    self.captions = []
         | 
| 16 | 
            +
                    self.keep_in_mem = keep_in_mem
         | 
| 17 | 
            +
             | 
| 18 | 
            +
                    if not isinstance(folder, list):
         | 
| 19 | 
            +
                        self.image_files = [file for file in os.listdir(folder) if file.endswith((".png",".jpg"))]
         | 
| 20 | 
            +
                        self.image_files.sort()
         | 
| 21 | 
            +
                    else:
         | 
| 22 | 
            +
                        self.images = folder
         | 
| 23 | 
            +
             | 
| 24 | 
            +
                    if not isinstance(caption, list):
         | 
| 25 | 
            +
                        if caption not in [None, "", "None"]:
         | 
| 26 | 
            +
                            self.caption_path = caption
         | 
| 27 | 
            +
                            self.caption_files = [os.path.join(caption, file.replace(".png", ".txt").replace(".jpg", ".txt")) for file in self.image_files]
         | 
| 28 | 
            +
                            self.caption_files.sort()
         | 
| 29 | 
            +
                    else:
         | 
| 30 | 
            +
                        self.caption_path = True
         | 
| 31 | 
            +
                        self.captions = caption
         | 
| 32 | 
            +
                    # get all the image files png/jpg
         | 
| 33 | 
            +
             | 
| 34 | 
            +
             | 
| 35 | 
            +
                    if keep_in_mem:
         | 
| 36 | 
            +
                        if len(self.images) == 0:
         | 
| 37 | 
            +
                            for file in self.image_files:
         | 
| 38 | 
            +
                                img = self.load_image(os.path.join(self.path, file))
         | 
| 39 | 
            +
                                self.images.append(img)
         | 
| 40 | 
            +
                        if len(self.captions) == 0:
         | 
| 41 | 
            +
                            if self.caption_path is not None:
         | 
| 42 | 
            +
                                self.captions = []
         | 
| 43 | 
            +
                                for file in self.caption_files:
         | 
| 44 | 
            +
                                    caption = self.load_caption(file)
         | 
| 45 | 
            +
                                    self.captions.append(caption)
         | 
| 46 | 
            +
                    else:
         | 
| 47 | 
            +
                        self.images = None
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                def limit_num(self, n):
         | 
| 50 | 
            +
                    raise NotImplementedError
         | 
| 51 | 
            +
                    assert n <= len(self), f"n should be less than the length of the dataset {len(self)}"
         | 
| 52 | 
            +
                    self.image_files = self.image_files[:n]
         | 
| 53 | 
            +
                    self.caption_files = self.caption_files[:n]
         | 
| 54 | 
            +
                    if self.keep_in_mem:
         | 
| 55 | 
            +
                        self.images = self.images[:n]
         | 
| 56 | 
            +
                        self.captions = self.captions[:n]
         | 
| 57 | 
            +
                    print(f"Dataset limited to {n}")
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                def __len__(self):
         | 
| 60 | 
            +
                    if len(self.images) != 0:
         | 
| 61 | 
            +
                        return len(self.images)
         | 
| 62 | 
            +
                    else:
         | 
| 63 | 
            +
                        return len(self.image_files)
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                def load_image(self, path):
         | 
| 66 | 
            +
                    with open(path, 'rb') as f:
         | 
| 67 | 
            +
                        img = Image.open(f).convert('RGB')
         | 
| 68 | 
            +
                    return img
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                def load_caption(self, path):
         | 
| 71 | 
            +
                    with open(path, 'r') as f:
         | 
| 72 | 
            +
                        caption = f.readlines()
         | 
| 73 | 
            +
                    caption = [line.strip() for line in caption if len(line.strip()) > 0]
         | 
| 74 | 
            +
                    return caption
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                def __getitem__(self, index):
         | 
| 77 | 
            +
                    if len(self.images) != 0:
         | 
| 78 | 
            +
                        img = self.images[index]
         | 
| 79 | 
            +
                    else:
         | 
| 80 | 
            +
                        img = self.load_image(os.path.join(self.path, self.image_files[index]))
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                    # if self.transform is not None:
         | 
| 83 | 
            +
                    #     img = self.transform(img)
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                    if self.caption_path is not None or len(self.captions) != 0:
         | 
| 86 | 
            +
                        if len(self.captions) != 0:
         | 
| 87 | 
            +
                            caption = self.captions[index]
         | 
| 88 | 
            +
                        else:
         | 
| 89 | 
            +
                            caption = self.load_caption(self.caption_files[index])
         | 
| 90 | 
            +
                        ret= {"image": img, "caption": caption, "id": index}
         | 
| 91 | 
            +
                    else:
         | 
| 92 | 
            +
                        ret= {"image": img, "id": index}
         | 
| 93 | 
            +
                    if self.transform is not None:
         | 
| 94 | 
            +
                        ret = self.transform(ret)
         | 
| 95 | 
            +
                    return ret
         | 
| 96 | 
            +
             | 
| 97 | 
            +
                def subsample(self, n: int = 10):
         | 
| 98 | 
            +
                    if n is None or n == -1:
         | 
| 99 | 
            +
                        return self
         | 
| 100 | 
            +
                    ori_len = len(self)
         | 
| 101 | 
            +
                    assert n <= ori_len
         | 
| 102 | 
            +
                    # equal interval subsample
         | 
| 103 | 
            +
                    ids = self.image_files[::ori_len // n][:n]
         | 
| 104 | 
            +
                    self.image_files = ids
         | 
| 105 | 
            +
                    if self.keep_in_mem:
         | 
| 106 | 
            +
                        self.images = self.images[::ori_len // n][:n]
         | 
| 107 | 
            +
                    print(f"Dataset subsampled from {ori_len} to {len(self)}")
         | 
| 108 | 
            +
                    return self
         | 
| 109 | 
            +
             | 
| 110 | 
            +
                def with_transform(self, transform):
         | 
| 111 | 
            +
                    self.transform = transform
         | 
| 112 | 
            +
                    return self
         | 
| 113 | 
            +
                @staticmethod
         | 
| 114 | 
            +
                def collate_fn(examples):
         | 
| 115 | 
            +
                    images = [example["image"] for example in examples]
         | 
| 116 | 
            +
                    ids = [example["id"] for example in examples]
         | 
| 117 | 
            +
                    if "caption" in examples[0]:
         | 
| 118 | 
            +
                        captions = [random.choice(example["caption"]) for example in examples]
         | 
| 119 | 
            +
                        return {"images": images, "captions": captions, "id": ids}
         | 
| 120 | 
            +
                    else:
         | 
| 121 | 
            +
                        return {"images": images, "id": ids}
         | 
| 122 | 
            +
             | 
| 123 | 
            +
             | 
| 124 | 
            +
            class ImagePair(ImageSet):
         | 
| 125 | 
            +
                def __init__(self, folder1, folder2, transform=None, keep_in_mem=True):
         | 
| 126 | 
            +
                    self.path1 = folder1
         | 
| 127 | 
            +
                    self.path2 = folder2
         | 
| 128 | 
            +
                    self.transform = transform
         | 
| 129 | 
            +
                    # get all the image files png/jpg
         | 
| 130 | 
            +
                    self.image_files = [file for file in os.listdir(folder1) if file.endswith(".png") or file.endswith(".jpg")]
         | 
| 131 | 
            +
                    self.image_files.sort()
         | 
| 132 | 
            +
                    self.keep_in_mem = keep_in_mem
         | 
| 133 | 
            +
                    if keep_in_mem:
         | 
| 134 | 
            +
                        self.images = []
         | 
| 135 | 
            +
                        for file in self.image_files:
         | 
| 136 | 
            +
                            img1 = self.load_image(os.path.join(self.path1, file))
         | 
| 137 | 
            +
                            img2 = self.load_image(os.path.join(self.path2, file))
         | 
| 138 | 
            +
                            self.images.append((img1, img2))
         | 
| 139 | 
            +
                    else:
         | 
| 140 | 
            +
                        self.images = None
         | 
| 141 | 
            +
             | 
| 142 | 
            +
                def __getitem__(self, index):
         | 
| 143 | 
            +
                    if self.keep_in_mem:
         | 
| 144 | 
            +
                        img1, img2 = self.images[index]
         | 
| 145 | 
            +
                    else:
         | 
| 146 | 
            +
                        img1 = self.load_image(os.path.join(self.path1, self.image_files[index]))
         | 
| 147 | 
            +
                        img2 = self.load_image(os.path.join(self.path2, self.image_files[index]))
         | 
| 148 | 
            +
             | 
| 149 | 
            +
                    if self.transform is not None:
         | 
| 150 | 
            +
                        img1 = self.transform(img1)
         | 
| 151 | 
            +
                        img2 = self.transform(img2)
         | 
| 152 | 
            +
                    return {"image1": img1, "image2": img2, "id": index}
         | 
| 153 | 
            +
             | 
| 154 | 
            +
             | 
| 155 | 
            +
             | 
| 156 | 
            +
                @staticmethod
         | 
| 157 | 
            +
                def collate_fn(examples):
         | 
| 158 | 
            +
                    images1 = [example["image1"] for example in examples]
         | 
| 159 | 
            +
                    images2 = [example["image2"] for example in examples]
         | 
| 160 | 
            +
                    # images1 = torch.stack(images1)
         | 
| 161 | 
            +
                    # images2 = torch.stack(images2)
         | 
| 162 | 
            +
                    ids = [example["id"] for example in examples]
         | 
| 163 | 
            +
                    return {"image1": images1, "image2": images2, "id": ids}
         | 
| 164 | 
            +
             | 
| 165 | 
            +
                def push_to_huggingface(self, hug_folder):
         | 
| 166 | 
            +
                    from datasets import Dataset
         | 
| 167 | 
            +
                    from datasets import Image as HugImage
         | 
| 168 | 
            +
                    photo_path = [os.path.join(self.path1, file) for file in self.image_files]
         | 
| 169 | 
            +
                    sketch_path = [os.path.join(self.path2, file) for file in self.image_files]
         | 
| 170 | 
            +
                    dataset = Dataset.from_dict({"photo": photo_path, "sketch": sketch_path, "file_name": self.image_files})
         | 
| 171 | 
            +
                    dataset = dataset.cast_column("photo", HugImage())
         | 
| 172 | 
            +
                    dataset = dataset.cast_column("sketch", HugImage())
         | 
| 173 | 
            +
                    dataset.push_to_hub(hug_folder, private=True)
         | 
| 174 | 
            +
             | 
| 175 | 
            +
            class ImageClass(ImageSet):
         | 
| 176 | 
            +
                def __init__(self, folders: list, transform=None, keep_in_mem=True):
         | 
| 177 | 
            +
                    self.paths = folders
         | 
| 178 | 
            +
                    self.transform = transform
         | 
| 179 | 
            +
                    # get all the image files png/jpg
         | 
| 180 | 
            +
                    self.image_files = []
         | 
| 181 | 
            +
                    self.keep_in_mem = keep_in_mem
         | 
| 182 | 
            +
                    for i, folder in enumerate(folders):
         | 
| 183 | 
            +
                        self.image_files+=[(os.path.join(folder, file), i) for file in os.listdir(folder) if file.endswith(".png") or file.endswith(".jpg")]
         | 
| 184 | 
            +
                    if keep_in_mem:
         | 
| 185 | 
            +
                        self.images = []
         | 
| 186 | 
            +
                        print("Loading images to memory")
         | 
| 187 | 
            +
                        for file in self.image_files:
         | 
| 188 | 
            +
                            img = self.load_image(file[0])
         | 
| 189 | 
            +
                            self.images.append((img, file[1]))
         | 
| 190 | 
            +
                        print("Loading images to memory done")
         | 
| 191 | 
            +
                    else:
         | 
| 192 | 
            +
                        self.images = None
         | 
| 193 | 
            +
             | 
| 194 | 
            +
                def __getitem__(self, index):
         | 
| 195 | 
            +
                    if self.keep_in_mem:
         | 
| 196 | 
            +
                        img, label = self.images[index]
         | 
| 197 | 
            +
                    else:
         | 
| 198 | 
            +
                        img_path, label = self.image_files[index]
         | 
| 199 | 
            +
                        img = self.load_image(img_path)
         | 
| 200 | 
            +
             | 
| 201 | 
            +
                    if self.transform is not None:
         | 
| 202 | 
            +
                        img = self.transform(img)
         | 
| 203 | 
            +
                    return {"image": img, "label": label, "id": index}
         | 
| 204 | 
            +
             | 
| 205 | 
            +
                @staticmethod
         | 
| 206 | 
            +
                def collate_fn(examples):
         | 
| 207 | 
            +
                    images = [example["image"] for example in examples]
         | 
| 208 | 
            +
                    labels = [example["label"] for example in examples]
         | 
| 209 | 
            +
                    ids = [example["id"] for example in examples]
         | 
| 210 | 
            +
                    return {"images": images, "labels":labels, "id": ids}
         | 
| 211 | 
            +
             | 
| 212 | 
            +
             | 
| 213 | 
            +
            if __name__ == "__main__":
         | 
| 214 | 
            +
                # dataset = ImagePair("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/clip_filtered_remain_50",
         | 
| 215 | 
            +
                #                     "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/sketch_50",keep_in_mem=False)
         | 
| 216 | 
            +
                # dataset.push_to_huggingface("rhfeiyang/photo-sketch-pair-50")
         | 
| 217 | 
            +
             | 
| 218 | 
            +
             | 
| 219 | 
            +
             | 
| 220 | 
            +
                dataset = ImagePair("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/clip_filtered_remain_500",
         | 
| 221 | 
            +
                                    "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/sketch_500",
         | 
| 222 | 
            +
                                    keep_in_mem=True)
         | 
| 223 | 
            +
                # dataset.push_to_huggingface("rhfeiyang/photo-sketch-pair-500")
         | 
| 224 | 
            +
                # ret = dataset[0]
         | 
| 225 | 
            +
                # print(len(dataset))
         | 
| 226 | 
            +
                import torch
         | 
| 227 | 
            +
                from torchvision import transforms
         | 
| 228 | 
            +
                train_transforms = transforms.Compose(
         | 
| 229 | 
            +
                    [
         | 
| 230 | 
            +
                        transforms.Resize(256, interpolation=transforms.InterpolationMode.BILINEAR),
         | 
| 231 | 
            +
                        transforms.CenterCrop(256),
         | 
| 232 | 
            +
                        transforms.RandomHorizontalFlip(),
         | 
| 233 | 
            +
                        transforms.ToTensor(),
         | 
| 234 | 
            +
                        transforms.Normalize([0.5], [0.5]),
         | 
| 235 | 
            +
                    ]
         | 
| 236 | 
            +
                )
         | 
| 237 | 
            +
                dataset = dataset.with_transform(train_transforms)
         | 
| 238 | 
            +
                dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4, collate_fn=ImagePair.collate_fn)
         | 
| 239 | 
            +
                ret = dataloader.__iter__().__next__()
         | 
| 240 | 
            +
                pass
         | 
    	
        custom_datasets/lhq.py
    ADDED
    
    | @@ -0,0 +1,127 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Authors: Hui Ren (rhfeiyang.github.io)
         | 
| 2 | 
            +
            import os
         | 
| 3 | 
            +
            import pickle
         | 
| 4 | 
            +
            import random
         | 
| 5 | 
            +
            import shutil
         | 
| 6 | 
            +
            from torch.utils.data import Dataset
         | 
| 7 | 
            +
            from torchvision import transforms
         | 
| 8 | 
            +
            from PIL import Image
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            class LhqDataset(Dataset):
         | 
| 11 | 
            +
                def __init__(self, image_folder_path:str, caption_folder_path:str, id_file:str = "clip_dissection/lhq/idx/subsample_100.pickle", transforms: transforms = None,
         | 
| 12 | 
            +
                             get_img=True,
         | 
| 13 | 
            +
                             get_cap=True,):
         | 
| 14 | 
            +
             | 
| 15 | 
            +
                    if isinstance(id_file, list):
         | 
| 16 | 
            +
                        self.ids = id_file
         | 
| 17 | 
            +
                    elif isinstance(id_file, str):
         | 
| 18 | 
            +
                        with open(id_file, 'rb') as f:
         | 
| 19 | 
            +
                            print(f"Loading ids from {id_file}", flush=True)
         | 
| 20 | 
            +
                            self.ids = pickle.load(f)
         | 
| 21 | 
            +
                            print(f"Loaded ids from {id_file}", flush=True)
         | 
| 22 | 
            +
                    self.image_folder_path = image_folder_path
         | 
| 23 | 
            +
                    self.caption_folder_path = caption_folder_path
         | 
| 24 | 
            +
                    self.transforms = transforms
         | 
| 25 | 
            +
                    self.column_names = ["image", "text"]
         | 
| 26 | 
            +
                    self.get_img = get_img
         | 
| 27 | 
            +
                    self.get_cap = get_cap
         | 
| 28 | 
            +
             | 
| 29 | 
            +
                def __len__(self):
         | 
| 30 | 
            +
                    return len(self.ids)
         | 
| 31 | 
            +
             | 
| 32 | 
            +
                def __getitem__(self, index: int):
         | 
| 33 | 
            +
                    id = self.ids[index]
         | 
| 34 | 
            +
                    ret={"id":id}
         | 
| 35 | 
            +
                    if self.get_img:
         | 
| 36 | 
            +
                        image = self._load_image(id)
         | 
| 37 | 
            +
                        ret["image"]=image
         | 
| 38 | 
            +
                    if self.get_cap:
         | 
| 39 | 
            +
                        target = self._load_caption(id)
         | 
| 40 | 
            +
                        ret["caption"]=[target]
         | 
| 41 | 
            +
                    if self.transforms is not None:
         | 
| 42 | 
            +
                        ret = self.transforms(ret)
         | 
| 43 | 
            +
                    return ret
         | 
| 44 | 
            +
             | 
| 45 | 
            +
                def _load_image(self, id: int):
         | 
| 46 | 
            +
                    image_path = f"{self.image_folder_path}/{id}.jpg"
         | 
| 47 | 
            +
                    with open(image_path, 'rb') as f:
         | 
| 48 | 
            +
                        img = Image.open(f).convert("RGB")
         | 
| 49 | 
            +
                    return img
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                def _load_caption(self, id: int):
         | 
| 52 | 
            +
                    caption_path = f"{self.caption_folder_path}/{id}.txt"
         | 
| 53 | 
            +
                    with open(caption_path, 'r') as f:
         | 
| 54 | 
            +
                        caption_file = f.read()
         | 
| 55 | 
            +
                    caption = []
         | 
| 56 | 
            +
                    for line in caption_file.split("\n"):
         | 
| 57 | 
            +
                        line = line.strip()
         | 
| 58 | 
            +
                        if len(line) > 0:
         | 
| 59 | 
            +
                            caption.append(line)
         | 
| 60 | 
            +
                    return caption
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                def subsample(self, n: int = 10000):
         | 
| 63 | 
            +
                    if n is None or n == -1:
         | 
| 64 | 
            +
                        return self
         | 
| 65 | 
            +
                    ori_len = len(self)
         | 
| 66 | 
            +
                    assert n <= ori_len
         | 
| 67 | 
            +
                    # equal interval subsample
         | 
| 68 | 
            +
                    ids = self.ids[::ori_len // n][:n]
         | 
| 69 | 
            +
                    self.ids = ids
         | 
| 70 | 
            +
                    print(f"LHQ dataset subsampled from {ori_len} to {len(self)}")
         | 
| 71 | 
            +
                    return self
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                def with_transform(self, transform):
         | 
| 74 | 
            +
                    self.transforms = transform
         | 
| 75 | 
            +
                    return self
         | 
| 76 | 
            +
             | 
| 77 | 
            +
             | 
| 78 | 
            +
            def generate_idx(data_folder = "/data/vision/torralba/clip_dissection/huiren/lhq/lhq_1024_jpg/lhq_1024_jpg/", save_path = "/data/vision/torralba/clip_dissection/huiren/lhq/idx/all_ids.pickle"):
         | 
| 79 | 
            +
                all_ids = os.listdir(data_folder)
         | 
| 80 | 
            +
                all_ids = [i.split(".")[0] for i in all_ids if i.endswith(".jpg") or i.endswith(".png")]
         | 
| 81 | 
            +
                os.makedirs(os.path.dirname(save_path), exist_ok=True)
         | 
| 82 | 
            +
                pickle.dump(all_ids, open(f"{save_path}", "wb"))
         | 
| 83 | 
            +
                print("all_ids generated")
         | 
| 84 | 
            +
                return all_ids
         | 
| 85 | 
            +
             | 
| 86 | 
            +
            def random_sample(all_ids, sample_num = 110, save_root = "/data/vision/torralba/clip_dissection/huiren/lhq/subsample"):
         | 
| 87 | 
            +
                chosen_id = random.sample(all_ids, sample_num)
         | 
| 88 | 
            +
                save_dir = f"{save_root}/{sample_num}"
         | 
| 89 | 
            +
                os.makedirs(save_dir, exist_ok=True)
         | 
| 90 | 
            +
                for id in chosen_id:
         | 
| 91 | 
            +
                    img_path = f"/data/vision/torralba/clip_dissection/huiren/lhq/lhq_1024_jpg/lhq_1024_jpg/{id}.jpg"
         | 
| 92 | 
            +
                    shutil.copy(img_path, save_dir)
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                return chosen_id
         | 
| 95 | 
            +
             | 
| 96 | 
            +
            if __name__ == "__main__":
         | 
| 97 | 
            +
                # all_ids = generate_idx()
         | 
| 98 | 
            +
                # with open("/data/vision/torralba/clip_dissection/huiren/lhq/idx/all_ids.pickle", "rb") as f:
         | 
| 99 | 
            +
                #     all_ids = pickle.load(f)
         | 
| 100 | 
            +
                # # random_sample(all_ids, 1)
         | 
| 101 | 
            +
                #
         | 
| 102 | 
            +
                # # generate_idx(data_folder="/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/lhq/subsample/100",
         | 
| 103 | 
            +
                # #              save_path="/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/lhq/idx/subsample_100.pickle")
         | 
| 104 | 
            +
                #
         | 
| 105 | 
            +
                # # lhq 500
         | 
| 106 | 
            +
                # with open("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/lhq/idx/subsample_100.pickle", "rb") as f:
         | 
| 107 | 
            +
                #     lhq_100_idx = pickle.load(f)
         | 
| 108 | 
            +
                #
         | 
| 109 | 
            +
                # extra_idx = set(all_ids) - set(lhq_100_idx)
         | 
| 110 | 
            +
                # add_idx = random.sample(extra_idx, 400)
         | 
| 111 | 
            +
                # lhq_500_idx = lhq_100_idx + add_idx
         | 
| 112 | 
            +
                # with open("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/lhq/idx/subsample_500.pickle", "wb") as f:
         | 
| 113 | 
            +
                #     pickle.dump(lhq_500_idx, f)
         | 
| 114 | 
            +
                # save_dir = "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/lhq/subsample/500"
         | 
| 115 | 
            +
                # os.makedirs(save_dir, exist_ok=True)
         | 
| 116 | 
            +
                # for id in lhq_500_idx:
         | 
| 117 | 
            +
                #     img_path = f"/data/vision/torralba/clip_dissection/huiren/lhq/lhq_1024_jpg/lhq_1024_jpg/{id}.jpg"
         | 
| 118 | 
            +
                #     # softlink
         | 
| 119 | 
            +
                #     os.symlink(img_path, os.path.join(save_dir, f"{id}.jpg"))
         | 
| 120 | 
            +
             | 
| 121 | 
            +
                # lhq9
         | 
| 122 | 
            +
                all_ids = generate_idx(data_folder="/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/lhq/subsample/9",
         | 
| 123 | 
            +
                                       save_path="/data/vision/torralba/clip_dissection/huiren/lhq/idx/subsample_9.pickle")
         | 
| 124 | 
            +
                print(all_ids)
         | 
| 125 | 
            +
             | 
| 126 | 
            +
             | 
| 127 | 
            +
                
         | 
    	
        custom_datasets/mypath.py
    ADDED
    
    | @@ -0,0 +1,29 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
             | 
| 3 | 
            +
             | 
| 4 | 
            +
            class MyPath(object):
         | 
| 5 | 
            +
                @staticmethod
         | 
| 6 | 
            +
                def db_root_dir(database=''):
         | 
| 7 | 
            +
                    coco_root = "/data/vision/torralba/datasets/coco_2017"
         | 
| 8 | 
            +
                    sam_caption_root = "/vision-nfs/torralba/datasets/vision/sam/captions"
         | 
| 9 | 
            +
             | 
| 10 | 
            +
                    root_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
         | 
| 11 | 
            +
                    map={
         | 
| 12 | 
            +
                        "coco_train": f"{coco_root}/train2017/",
         | 
| 13 | 
            +
                        "coco_caption_train": f"{coco_root}/annotations/captions_train2017.json",
         | 
| 14 | 
            +
                        "coco_val": f"{coco_root}/val2017/",
         | 
| 15 | 
            +
                        "coco_caption_val": f"{coco_root}/annotations/captions_val2017.json",
         | 
| 16 | 
            +
                        "sam_images": "/vision-nfs/torralba/datasets/vision/sam/images",
         | 
| 17 | 
            +
                        "sam_captions": sam_caption_root,
         | 
| 18 | 
            +
                        "sam_whole_filtered_ids_train": "data/filtered_sam/all_remain_ids_train.pickle",
         | 
| 19 | 
            +
                        "sam_whole_filtered_ids_val": "data/filtered_sam/all_remain_ids_val.pickle",
         | 
| 20 | 
            +
                        "sam_id_dict": "data/filtered_sam/all_id_dict.pickle",
         | 
| 21 | 
            +
             | 
| 22 | 
            +
                        "lhq_ids_sub500": "data/LHQ500_caption/idx/subsample_500.pickle",
         | 
| 23 | 
            +
                        "lhq_images": "data/LHQ500_caption/subsample_500",
         | 
| 24 | 
            +
                        "lhq_captions": "data/LHQ500_caption/captions",
         | 
| 25 | 
            +
                    }
         | 
| 26 | 
            +
                    ret = map.get(database, None)
         | 
| 27 | 
            +
                    if ret is None:
         | 
| 28 | 
            +
                        raise NotImplementedError
         | 
| 29 | 
            +
                    return ret
         | 
    	
        custom_datasets/sam.py
    ADDED
    
    | @@ -0,0 +1,160 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Authors: Hui Ren (rhfeiyang.github.io)
         | 
| 2 | 
            +
            import os.path
         | 
| 3 | 
            +
            import sys
         | 
| 4 | 
            +
            from typing import Any, Callable, List, Optional, Tuple
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import tqdm
         | 
| 7 | 
            +
            from PIL import Image
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            from torch.utils.data import Dataset
         | 
| 10 | 
            +
            import pickle
         | 
| 11 | 
            +
            from torchvision import transforms
         | 
| 12 | 
            +
            # import torch
         | 
| 13 | 
            +
            # import torchvision
         | 
| 14 | 
            +
            # import re
         | 
| 15 | 
            +
             | 
| 16 | 
            +
             | 
| 17 | 
            +
            class SamDataset(Dataset):
         | 
| 18 | 
            +
                def __init__(self, image_folder_path:str, caption_folder_path:str, id_file:str = "data/sam/clip_filtered_ids.pickle",id_dict_file:str =None , transforms: Optional[Callable] = None,
         | 
| 19 | 
            +
                             resolution=None,
         | 
| 20 | 
            +
                             get_img=True,
         | 
| 21 | 
            +
                             get_cap=True,):
         | 
| 22 | 
            +
                    if id_dict_file is not None:
         | 
| 23 | 
            +
                        with open(id_dict_file, 'rb') as f:
         | 
| 24 | 
            +
                            print(f"Loading id_dict from {id_dict_file}", flush=True)
         | 
| 25 | 
            +
                            self.id_dict = pickle.load(f)
         | 
| 26 | 
            +
                            print(f"Loaded id_dict from {id_dict_file}", flush=True)
         | 
| 27 | 
            +
                    else:
         | 
| 28 | 
            +
                        self.id_dict = None
         | 
| 29 | 
            +
                    if isinstance(id_file, list):
         | 
| 30 | 
            +
                        self.ids = id_file
         | 
| 31 | 
            +
                    elif isinstance(id_file, str):
         | 
| 32 | 
            +
                        with open(id_file, 'rb') as f:
         | 
| 33 | 
            +
                            print(f"Loading ids from {id_file}", flush=True)
         | 
| 34 | 
            +
                            self.ids = pickle.load(f)
         | 
| 35 | 
            +
                            print(f"Loaded ids from {id_file}", flush=True)
         | 
| 36 | 
            +
                    self.resolution = resolution
         | 
| 37 | 
            +
                    self.ori_image_folder_path = image_folder_path
         | 
| 38 | 
            +
                    if self.resolution is not None:
         | 
| 39 | 
            +
                        if os.path.exists("/var/jomat/datasets/"):
         | 
| 40 | 
            +
                            # self.image_folder_path = f"/var/jomat/datasets/SAM_{resolution}"
         | 
| 41 | 
            +
                            self.image_folder_path = f"{image_folder_path}_{resolution}"
         | 
| 42 | 
            +
                        else:
         | 
| 43 | 
            +
                            self.image_folder_path = f"{image_folder_path}_{resolution}"
         | 
| 44 | 
            +
                        os.makedirs(self.image_folder_path, exist_ok=True)
         | 
| 45 | 
            +
                    else:
         | 
| 46 | 
            +
                        self.image_folder_path = image_folder_path
         | 
| 47 | 
            +
                    self.caption_folder_path = caption_folder_path
         | 
| 48 | 
            +
                    self.transforms = transforms
         | 
| 49 | 
            +
                    self.column_names = ["image", "text"]
         | 
| 50 | 
            +
                    self.get_img = get_img
         | 
| 51 | 
            +
                    self.get_cap = get_cap
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                def __len__(self):
         | 
| 54 | 
            +
                    # return 100
         | 
| 55 | 
            +
                    return len(self.ids)
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                def __getitem__(self, index: int):
         | 
| 58 | 
            +
                    id = self.ids[index]
         | 
| 59 | 
            +
                    ret={"id":id}
         | 
| 60 | 
            +
                    try:
         | 
| 61 | 
            +
                        # if index == 1:
         | 
| 62 | 
            +
                        #     raise Exception("test")
         | 
| 63 | 
            +
                        if self.get_img:
         | 
| 64 | 
            +
                            image = self._load_image(id)
         | 
| 65 | 
            +
                            ret["image"]=image
         | 
| 66 | 
            +
                        if self.get_cap:
         | 
| 67 | 
            +
                            target = self._load_caption(id)
         | 
| 68 | 
            +
                            ret["text"] = [target]
         | 
| 69 | 
            +
                        if self.transforms is not None:
         | 
| 70 | 
            +
                            ret = self.transforms(ret)
         | 
| 71 | 
            +
                        return ret
         | 
| 72 | 
            +
                    except Exception as e:
         | 
| 73 | 
            +
                        raise e
         | 
| 74 | 
            +
                        print(f"Error loading image and caption for id {id}, error: {e}, redirecting to index 0", flush=True)
         | 
| 75 | 
            +
                        ret = self[0]
         | 
| 76 | 
            +
                        return ret
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                def define_resolution(self, resolution: int):
         | 
| 79 | 
            +
                    self.resolution = resolution
         | 
| 80 | 
            +
                    if os.path.exists("/var/jomat/datasets/"):
         | 
| 81 | 
            +
                        self.image_folder_path = f"/var/jomat/datasets/SAM_{resolution}"
         | 
| 82 | 
            +
                        # self.image_folder_path = f"{self.ori_image_folder_path}_{resolution}"
         | 
| 83 | 
            +
                    else:
         | 
| 84 | 
            +
                        self.image_folder_path = f"{self.ori_image_folder_path}_{resolution}"
         | 
| 85 | 
            +
                    print(f"SamDataset resolution defined to {resolution}, new image folder path: {self.image_folder_path}")
         | 
| 86 | 
            +
                def _load_image(self, id: int) -> Image.Image:
         | 
| 87 | 
            +
                    if self.id_dict is not None:
         | 
| 88 | 
            +
                        subfolder = self.id_dict[id]
         | 
| 89 | 
            +
                        image_path = f"{self.image_folder_path}/{subfolder}/sa_{id}.jpg"
         | 
| 90 | 
            +
                    else:
         | 
| 91 | 
            +
                        image_path = f"{self.image_folder_path}/sa_{id}.jpg"
         | 
| 92 | 
            +
             | 
| 93 | 
            +
                    try:
         | 
| 94 | 
            +
                        with open(image_path, 'rb') as f:
         | 
| 95 | 
            +
                            img = Image.open(f).convert("RGB")
         | 
| 96 | 
            +
                        # return img
         | 
| 97 | 
            +
                    except:
         | 
| 98 | 
            +
                        # load original image
         | 
| 99 | 
            +
                        if self.id_dict is not None:
         | 
| 100 | 
            +
                            subfolder = self.id_dict[id]
         | 
| 101 | 
            +
                            ori_image_path = f"{self.ori_image_folder_path}/{subfolder}/sa_{id}.jpg"
         | 
| 102 | 
            +
                        else:
         | 
| 103 | 
            +
                            ori_image_path = f"{self.ori_image_folder_path}/sa_{id}.jpg"
         | 
| 104 | 
            +
                        assert os.path.exists(ori_image_path)
         | 
| 105 | 
            +
                        with open(ori_image_path, 'rb') as f:
         | 
| 106 | 
            +
                            img = Image.open(f).convert("RGB")
         | 
| 107 | 
            +
                        # resize image keep aspect ratio
         | 
| 108 | 
            +
                        if self.resolution is not None:
         | 
| 109 | 
            +
                            img = transforms.Resize(self.resolution, interpolation=transforms.InterpolationMode.BICUBIC)(img)
         | 
| 110 | 
            +
                        # write image
         | 
| 111 | 
            +
                        os.makedirs(os.path.dirname(image_path), exist_ok=True)
         | 
| 112 | 
            +
                        img.save(image_path)
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                    return img
         | 
| 115 | 
            +
             | 
| 116 | 
            +
                
         | 
| 117 | 
            +
                def _load_caption(self, id: int):
         | 
| 118 | 
            +
                    caption_path = f"{self.caption_folder_path}/sa_{id}.txt"
         | 
| 119 | 
            +
                    if not os.path.exists(caption_path):
         | 
| 120 | 
            +
                        return None
         | 
| 121 | 
            +
                    try:
         | 
| 122 | 
            +
                        with open(caption_path, 'r', encoding="utf-8") as f:
         | 
| 123 | 
            +
                            content = f.read()
         | 
| 124 | 
            +
                    except Exception as e:
         | 
| 125 | 
            +
                        raise e
         | 
| 126 | 
            +
                        print(f"Error reading caption file {caption_path}, error: {e}")
         | 
| 127 | 
            +
                        return None
         | 
| 128 | 
            +
                    sentences = content.split('.')
         | 
| 129 | 
            +
                    # remove empty sentences and sentences with "black and white"(too many false prediction)
         | 
| 130 | 
            +
                    sentences = [sentence.strip() for sentence in sentences if sentence.strip() and "black and white" not in sentence]
         | 
| 131 | 
            +
                    # join sentence
         | 
| 132 | 
            +
                    sentences = ". ".join(sentences)
         | 
| 133 | 
            +
                    if len(sentences) > 0 and sentences[-1] != '.':
         | 
| 134 | 
            +
                        sentences += '.'
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                    return sentences
         | 
| 137 | 
            +
                
         | 
| 138 | 
            +
                def with_transform(self, transform):
         | 
| 139 | 
            +
                    self.transforms = transform
         | 
| 140 | 
            +
                    return self
         | 
| 141 | 
            +
             | 
| 142 | 
            +
                def subsample(self, n: int = 10000):
         | 
| 143 | 
            +
                    if n is None or n == -1:
         | 
| 144 | 
            +
                        return self
         | 
| 145 | 
            +
                    ori_len = len(self)
         | 
| 146 | 
            +
                    assert n <= ori_len
         | 
| 147 | 
            +
                    # equal interval subsample
         | 
| 148 | 
            +
                    ids = self.ids[::ori_len // n][:n]
         | 
| 149 | 
            +
                    self.ids = ids
         | 
| 150 | 
            +
                    print(f"SAM dataset subsampled from {ori_len} to {len(self)}")
         | 
| 151 | 
            +
                    return self
         | 
| 152 | 
            +
             | 
| 153 | 
            +
             | 
| 154 | 
            +
            if __name__ == "__main__":
         | 
| 155 | 
            +
                # sam_filt(caption_filt=False, clip_filt=False, clip_logit=True)
         | 
| 156 | 
            +
                from custom_datasets.sam_caption.mypath import MyPath
         | 
| 157 | 
            +
                dataset = SamDataset(image_folder_path=MyPath.db_root_dir("sam_images"), caption_folder_path=MyPath.db_root_dir("sam_captions"), id_file=MyPath.db_root_dir("sam_whole_filtered_ids_train"), id_dict_file=MyPath.db_root_dir("sam_id_dict"))
         | 
| 158 | 
            +
                dataset.get_img = False
         | 
| 159 | 
            +
                for i in tqdm.tqdm(dataset):
         | 
| 160 | 
            +
                    a=i['text']
         | 
    	
        data/Art_adapters/albert-gleizes_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:1802d12e4d9526eedb89d99f69051849f14774da3c73ebc9b1393c2b13f17022
         | 
| 3 | 
            +
            size 2187129
         | 
    	
        data/Art_adapters/andre-derain_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:4c39b39f32ff88dfed978ccc651715ade9edfd901d529adbeb5eedb715b8e159
         | 
| 3 | 
            +
            size 2187129
         | 
    	
        data/Art_adapters/andy_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:fd7764b19a2b4513b3c22f1607d72daa63c4ace97ea803e29e2bcf3f13bab2e8
         | 
| 3 | 
            +
            size 2187129
         | 
    	
        data/Art_adapters/camille-corot_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:426c2e4a3bfc26f7fdcc3e82989d717fa5fc6e732cd9df9f8bb293ab72cacfa5
         | 
| 3 | 
            +
            size 2187129
         | 
    	
        data/Art_adapters/gerhard-richter_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:8be8ef590baceb2bdfac8b25976df88fa7baa1a9c718ed16aa4fa8fa247bb421
         | 
| 3 | 
            +
            size 2187129
         | 
    	
        data/Art_adapters/henri-matisse_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:212f0f16ae84c0bae96e213a0b0d5f4309209b332d48cbaa1748b5cdcfb3238a
         | 
| 3 | 
            +
            size 2187129
         | 
    	
        data/Art_adapters/jackson-pollock_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:5cff54e3e7c544577dbc39d7015a89c4786cd012cf944d0b9db334c1a1d7e30b
         | 
| 3 | 
            +
            size 2187129
         | 
    	
        data/Art_adapters/joan-miro_subset2/adapter_alpha1.0_rank1_all_up_1000steps.pt
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:c26bdb5bfba85b4eb00631eda149912ba557935773842f95c0596999f799a2b4
         | 
| 3 | 
            +
            size 2187129
         | 
    	
        data/Art_adapters/kandinsky_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:24b33205841d9b09c0076b4ba295be29d94677e69b7269465897bbf059a40454
         | 
| 3 | 
            +
            size 2187129
         | 
    	
        data/Art_adapters/katsushika-hokusai_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:b34b75325c3fd0353b55f390027a32a98f771df7d2fb21dbd8bce81a12ba59e9
         | 
| 3 | 
            +
            size 2187129
         | 
    	
        data/Art_adapters/klimt_subset3/adapter_alpha1.0_rank1_all_up_1000steps.pt
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:7457f14af7c77f98675063582b35317963d46e942459575d38b5996ed190c58f
         | 
| 3 | 
            +
            size 2187129
         | 
    	
        data/Art_adapters/m.c.-escher_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:c6df86764f4d4ceec0bd6124a74a51c36665c8491511a5488737b9a64300b97b
         | 
| 3 | 
            +
            size 2187129
         | 
    	
        data/Art_adapters/monet_subset2/adapter_alpha1.0_rank1_all_up_1000steps.pt
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:6a9ba0305edca3286258a06023b97914b850fbc8b4f5a14769537f9a01ef33f1
         | 
| 3 | 
            +
            size 2187129
         | 
    	
        data/Art_adapters/picasso_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:8ce7899c19b32dacd2dc46090fd3429495a2230c173bcd96149236d27b5151fd
         | 
| 3 | 
            +
            size 2187129
         | 
    	
        data/Art_adapters/roy-lichtenstein_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:5ac428a5d0fb136b79eec2349fbcbd99dfac2315c0a7f54d7985299b60b6f66f
         | 
| 3 | 
            +
            size 2187129
         | 
    	
        data/Art_adapters/van_gogh_subset1/adapter_alpha1.0_rank1_all_up_1000steps.pt
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:3ca866dd868fb89a1180bb140dfaf1e48701993c8fa173d70c56c60c9af8d8fb
         | 
| 3 | 
            +
            size 2187129
         | 
    	
        data/Art_adapters/walter-battiss_subset2/adapter_alpha1.0_rank1_all_up_1000steps.pt
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            version https://git-lfs.github.com/spec/v1
         | 
| 2 | 
            +
            oid sha256:41cad39d7b6e1873cfef85be478851820f5dc80cd7ce11afe2bfa3584662e3ac
         | 
| 3 | 
            +
            size 2187129
         | 
    	
        data/unsafe.png
    ADDED
    
    |   | 
    	
        hf_demo.py
    ADDED
    
    | @@ -0,0 +1,147 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Authors: Hui Ren (rhfeiyang.github.io)
         | 
| 2 | 
            +
            import os
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            import gradio as gr
         | 
| 5 | 
            +
            from diffusers import DiffusionPipeline
         | 
| 6 | 
            +
            import matplotlib.pyplot as plt
         | 
| 7 | 
            +
            import torch
         | 
| 8 | 
            +
            from PIL import Image
         | 
| 9 | 
            +
             | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
            device = "cuda" if torch.cuda.is_available() else "cpu"
         | 
| 13 | 
            +
            pipe = DiffusionPipeline.from_pretrained("rhfeiyang/art-free-diffusion-v1",).to(device)
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            from inference import get_lora_network, inference, get_validation_dataloader
         | 
| 16 | 
            +
            lora_map = {
         | 
| 17 | 
            +
                "None": "None",
         | 
| 18 | 
            +
                "Andre Derain": "andre-derain_subset1",
         | 
| 19 | 
            +
                "Vincent van Gogh": "van_gogh_subset1",
         | 
| 20 | 
            +
                "Andy Warhol": "andy_subset1",
         | 
| 21 | 
            +
                "Walter Battiss": "walter-battiss_subset2",
         | 
| 22 | 
            +
                "Camille Corot": "camille-corot_subset1",
         | 
| 23 | 
            +
                "Claude Monet": "monet_subset2",
         | 
| 24 | 
            +
                "Pablo Picasso": "picasso_subset1",
         | 
| 25 | 
            +
                "Jackson Pollock": "jackson-pollock_subset1",
         | 
| 26 | 
            +
                "Gerhard Richter": "gerhard-richter_subset1",
         | 
| 27 | 
            +
                "M.C. Escher": "m.c.-escher_subset1",
         | 
| 28 | 
            +
                "Albert Gleizes": "albert-gleizes_subset1",
         | 
| 29 | 
            +
                "Hokusai": "katsushika-hokusai_subset1",
         | 
| 30 | 
            +
                "Wassily Kandinsky": "kandinsky_subset1",
         | 
| 31 | 
            +
                "Gustav Klimt": "klimt_subset3",
         | 
| 32 | 
            +
                "Roy Lichtenstein": "roy-lichtenstein_subset1",
         | 
| 33 | 
            +
                "Henri Matisse": "henri-matisse_subset1",
         | 
| 34 | 
            +
                "Joan Miro": "joan-miro_subset2",
         | 
| 35 | 
            +
            }
         | 
| 36 | 
            +
             | 
| 37 | 
            +
            def demo_inference_gen(adapter_choice:str, prompt:str, samples:int=1,seed:int=0, steps=50, guidance_scale=7.5):
         | 
| 38 | 
            +
                adapter_path = lora_map[adapter_choice]
         | 
| 39 | 
            +
                if adapter_path not in [None, "None"]:
         | 
| 40 | 
            +
                    adapter_path = f"data/Art_adapters/{adapter_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt"
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                prompts = [prompt]*samples
         | 
| 43 | 
            +
                infer_loader = get_validation_dataloader(prompts)
         | 
| 44 | 
            +
                network = get_lora_network(pipe.unet, adapter_path)["network"]
         | 
| 45 | 
            +
                pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
         | 
| 46 | 
            +
                                        height=512, width=512, scales=[1.0],
         | 
| 47 | 
            +
                                        save_dir=None, seed=seed,steps=steps, guidance_scale=guidance_scale,
         | 
| 48 | 
            +
                                        start_noise=-1, show=False, style_prompt="sks art", no_load=True,
         | 
| 49 | 
            +
                                        from_scratch=True)[0][1.0]
         | 
| 50 | 
            +
                return pred_images
         | 
| 51 | 
            +
             | 
| 52 | 
            +
            def demo_inference_stylization(adapter_path:str, prompts:list, image:list, start_noise=800,seed:int=0):
         | 
| 53 | 
            +
                infer_loader = get_validation_dataloader(prompts, image)
         | 
| 54 | 
            +
                network = get_lora_network(pipe.unet, adapter_path,"all_up")["network"]
         | 
| 55 | 
            +
                pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,
         | 
| 56 | 
            +
                                        height=512, width=512, scales=[0.,1.],
         | 
| 57 | 
            +
                                        save_dir=None, seed=seed,steps=20, guidance_scale=7.5,
         | 
| 58 | 
            +
                                        start_noise=start_noise, show=True, style_prompt="sks art", no_load=True,
         | 
| 59 | 
            +
                                        from_scratch=False)
         | 
| 60 | 
            +
                return pred_images
         | 
| 61 | 
            +
             | 
| 62 | 
            +
            # def infer(prompt, samples, steps, scale, seed):
         | 
| 63 | 
            +
            #     generator = torch.Generator(device=device).manual_seed(seed)
         | 
| 64 | 
            +
            #     images_list = pipe(  # type: ignore
         | 
| 65 | 
            +
            #         [prompt] * samples,
         | 
| 66 | 
            +
            #         num_inference_steps=steps,
         | 
| 67 | 
            +
            #         guidance_scale=scale,
         | 
| 68 | 
            +
            #         generator=generator,
         | 
| 69 | 
            +
            #     )
         | 
| 70 | 
            +
            #     images = []
         | 
| 71 | 
            +
            #     safe_image = Image.open(r"data/unsafe.png")
         | 
| 72 | 
            +
            #     print(images_list)
         | 
| 73 | 
            +
            #     for i, image in enumerate(images_list["images"]):  # type: ignore
         | 
| 74 | 
            +
            #         if images_list["nsfw_content_detected"][i]:  # type: ignore
         | 
| 75 | 
            +
            #             images.append(safe_image)
         | 
| 76 | 
            +
            #         else:
         | 
| 77 | 
            +
            #             images.append(image)
         | 
| 78 | 
            +
            #     return images
         | 
| 79 | 
            +
             | 
| 80 | 
            +
             | 
| 81 | 
            +
             | 
| 82 | 
            +
             | 
| 83 | 
            +
            block = gr.Blocks()
         | 
| 84 | 
            +
            # Direct infer
         | 
| 85 | 
            +
            with block:
         | 
| 86 | 
            +
                with gr.Group():
         | 
| 87 | 
            +
                    with gr.Row():
         | 
| 88 | 
            +
                        text = gr.Textbox(
         | 
| 89 | 
            +
                            label="Enter your prompt",
         | 
| 90 | 
            +
                            max_lines=2,
         | 
| 91 | 
            +
                            placeholder="Enter your prompt",
         | 
| 92 | 
            +
                            container=False,
         | 
| 93 | 
            +
                            value="Park with cherry blossom trees, picnicker’s and a clear blue pond.",
         | 
| 94 | 
            +
                        )
         | 
| 95 | 
            +
             | 
| 96 | 
            +
             | 
| 97 | 
            +
             | 
| 98 | 
            +
                        btn = gr.Button("Run", scale=0)
         | 
| 99 | 
            +
                    gallery = gr.Gallery(
         | 
| 100 | 
            +
                        label="Generated images",
         | 
| 101 | 
            +
                        show_label=False,
         | 
| 102 | 
            +
                        elem_id="gallery",
         | 
| 103 | 
            +
                        columns=[2],
         | 
| 104 | 
            +
                    )
         | 
| 105 | 
            +
             | 
| 106 | 
            +
                    advanced_button = gr.Button("Advanced options", elem_id="advanced-btn")
         | 
| 107 | 
            +
             | 
| 108 | 
            +
                    with gr.Row(elem_id="advanced-options"):
         | 
| 109 | 
            +
                        adapter_choice = gr.Dropdown(
         | 
| 110 | 
            +
                            label="Choose adapter",
         | 
| 111 | 
            +
                            choices=["None", "Andre Derain","Vincent van Gogh","Andy Warhol", "Walter Battiss",
         | 
| 112 | 
            +
                                     "Camille Corot", "Claude Monet", "Pablo Picasso",
         | 
| 113 | 
            +
                                     "Jackson Pollock", "Gerhard Richter", "M.C. Escher",
         | 
| 114 | 
            +
                                     "Albert Gleizes", "Hokusai", "Wassily Kandinsky", "Gustav Klimt", "Roy Lichtenstein",
         | 
| 115 | 
            +
                                     "Henri Matisse", "Joan Miro"
         | 
| 116 | 
            +
                                     ],
         | 
| 117 | 
            +
                            value="None"
         | 
| 118 | 
            +
                        )
         | 
| 119 | 
            +
                        # print(adapter_choice[0])
         | 
| 120 | 
            +
                        # lora_path = lora_map[adapter_choice.value]
         | 
| 121 | 
            +
                        # if lora_path is not None:
         | 
| 122 | 
            +
                        #     lora_path = f"data/Art_adapters/{lora_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt"
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                        samples = gr.Slider(label="Images", minimum=1, maximum=4, value=1, step=1)
         | 
| 125 | 
            +
                        steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=20, step=1)
         | 
| 126 | 
            +
                        scale = gr.Slider(
         | 
| 127 | 
            +
                            label="Guidance Scale", minimum=0, maximum=50, value=7.5, step=0.1
         | 
| 128 | 
            +
                        )
         | 
| 129 | 
            +
                        print(scale)
         | 
| 130 | 
            +
                        seed = gr.Slider(
         | 
| 131 | 
            +
                            label="Seed",
         | 
| 132 | 
            +
                            minimum=0,
         | 
| 133 | 
            +
                            maximum=2147483647,
         | 
| 134 | 
            +
                            step=1,
         | 
| 135 | 
            +
                            randomize=True,
         | 
| 136 | 
            +
                        )
         | 
| 137 | 
            +
             | 
| 138 | 
            +
                    gr.on([text.submit, btn.click], demo_inference_gen, inputs=[adapter_choice, text, samples, seed, steps, scale], outputs=gallery)
         | 
| 139 | 
            +
                    advanced_button.click(
         | 
| 140 | 
            +
                        None,
         | 
| 141 | 
            +
                        [],
         | 
| 142 | 
            +
                        text,
         | 
| 143 | 
            +
                    )
         | 
| 144 | 
            +
             | 
| 145 | 
            +
             | 
| 146 | 
            +
             | 
| 147 | 
            +
            block.launch()
         | 
    	
        hf_demo_test.ipynb
    ADDED
    
    | @@ -0,0 +1,336 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            {
         | 
| 2 | 
            +
             "cells": [
         | 
| 3 | 
            +
              {
         | 
| 4 | 
            +
               "cell_type": "code",
         | 
| 5 | 
            +
               "execution_count": 1,
         | 
| 6 | 
            +
               "id": "initial_id",
         | 
| 7 | 
            +
               "metadata": {
         | 
| 8 | 
            +
                "ExecuteTime": {
         | 
| 9 | 
            +
                 "end_time": "2024-12-09T09:44:30.641366Z",
         | 
| 10 | 
            +
                 "start_time": "2024-12-09T09:44:11.789050Z"
         | 
| 11 | 
            +
                }
         | 
| 12 | 
            +
               },
         | 
| 13 | 
            +
               "outputs": [],
         | 
| 14 | 
            +
               "source": [
         | 
| 15 | 
            +
                "import os\n",
         | 
| 16 | 
            +
                "\n",
         | 
| 17 | 
            +
                "import gradio as gr\n",
         | 
| 18 | 
            +
                "from diffusers import DiffusionPipeline\n",
         | 
| 19 | 
            +
                "import matplotlib.pyplot as plt\n",
         | 
| 20 | 
            +
                "import torch\n",
         | 
| 21 | 
            +
                "from PIL import Image\n"
         | 
| 22 | 
            +
               ]
         | 
| 23 | 
            +
              },
         | 
| 24 | 
            +
              {
         | 
| 25 | 
            +
               "cell_type": "code",
         | 
| 26 | 
            +
               "execution_count": 2,
         | 
| 27 | 
            +
               "id": "ddf33e0d3abacc2c",
         | 
| 28 | 
            +
               "metadata": {},
         | 
| 29 | 
            +
               "outputs": [],
         | 
| 30 | 
            +
               "source": [
         | 
| 31 | 
            +
                "import sys\n",
         | 
| 32 | 
            +
                "#append current path\n",
         | 
| 33 | 
            +
                "sys.path.extend(\"/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/release/hf_demo\")"
         | 
| 34 | 
            +
               ]
         | 
| 35 | 
            +
              },
         | 
| 36 | 
            +
              {
         | 
| 37 | 
            +
               "cell_type": "code",
         | 
| 38 | 
            +
               "execution_count": 3,
         | 
| 39 | 
            +
               "id": "643e49fd601daf8f",
         | 
| 40 | 
            +
               "metadata": {
         | 
| 41 | 
            +
                "ExecuteTime": {
         | 
| 42 | 
            +
                 "end_time": "2024-12-09T09:44:35.790962Z",
         | 
| 43 | 
            +
                 "start_time": "2024-12-09T09:44:35.779496Z"
         | 
| 44 | 
            +
                }
         | 
| 45 | 
            +
               },
         | 
| 46 | 
            +
               "outputs": [],
         | 
| 47 | 
            +
               "source": [
         | 
| 48 | 
            +
                "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\""
         | 
| 49 | 
            +
               ]
         | 
| 50 | 
            +
              },
         | 
| 51 | 
            +
              {
         | 
| 52 | 
            +
               "cell_type": "code",
         | 
| 53 | 
            +
               "execution_count": 4,
         | 
| 54 | 
            +
               "id": "e03aae2a4e5676dd",
         | 
| 55 | 
            +
               "metadata": {
         | 
| 56 | 
            +
                "ExecuteTime": {
         | 
| 57 | 
            +
                 "end_time": "2024-12-09T09:44:44.157412Z",
         | 
| 58 | 
            +
                 "start_time": "2024-12-09T09:44:37.138452Z"
         | 
| 59 | 
            +
                }
         | 
| 60 | 
            +
               },
         | 
| 61 | 
            +
               "outputs": [
         | 
| 62 | 
            +
                {
         | 
| 63 | 
            +
                 "name": "stderr",
         | 
| 64 | 
            +
                 "output_type": "stream",
         | 
| 65 | 
            +
                 "text": [
         | 
| 66 | 
            +
                  "/data/vision/torralba/selfmanaged/torralba/scratch/jomat/sam_dataset/miniforge3/envs/diffusion/lib/python3.9/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
         | 
| 67 | 
            +
                  "  warnings.warn(\n"
         | 
| 68 | 
            +
                 ]
         | 
| 69 | 
            +
                },
         | 
| 70 | 
            +
                {
         | 
| 71 | 
            +
                 "data": {
         | 
| 72 | 
            +
                  "application/vnd.jupyter.widget-view+json": {
         | 
| 73 | 
            +
                   "model_id": "9df8347307674ba8afb0250e23109aa1",
         | 
| 74 | 
            +
                   "version_major": 2,
         | 
| 75 | 
            +
                   "version_minor": 0
         | 
| 76 | 
            +
                  },
         | 
| 77 | 
            +
                  "text/plain": [
         | 
| 78 | 
            +
                   "Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]"
         | 
| 79 | 
            +
                  ]
         | 
| 80 | 
            +
                 },
         | 
| 81 | 
            +
                 "metadata": {},
         | 
| 82 | 
            +
                 "output_type": "display_data"
         | 
| 83 | 
            +
                }
         | 
| 84 | 
            +
               ],
         | 
| 85 | 
            +
               "source": [
         | 
| 86 | 
            +
                "pipe = DiffusionPipeline.from_pretrained(\"rhfeiyang/art-free-diffusion-v1\",).to(\"cuda\")\n",
         | 
| 87 | 
            +
                "device = \"cuda\""
         | 
| 88 | 
            +
               ]
         | 
| 89 | 
            +
              },
         | 
| 90 | 
            +
              {
         | 
| 91 | 
            +
               "cell_type": "code",
         | 
| 92 | 
            +
               "execution_count": 5,
         | 
| 93 | 
            +
               "id": "83916bc68ff5d914",
         | 
| 94 | 
            +
               "metadata": {
         | 
| 95 | 
            +
                "ExecuteTime": {
         | 
| 96 | 
            +
                 "end_time": "2024-12-09T09:44:52.694399Z",
         | 
| 97 | 
            +
                 "start_time": "2024-12-09T09:44:44.210695Z"
         | 
| 98 | 
            +
                }
         | 
| 99 | 
            +
               },
         | 
| 100 | 
            +
               "outputs": [],
         | 
| 101 | 
            +
               "source": [
         | 
| 102 | 
            +
                "from inference import get_lora_network, inference, get_validation_dataloader\n",
         | 
| 103 | 
            +
                "lora_map = {\n",
         | 
| 104 | 
            +
                "    \"None\": \"None\",\n",
         | 
| 105 | 
            +
                "    \"Andre Derain\": \"andre-derain_subset1\",\n",
         | 
| 106 | 
            +
                "    \"Vincent van Gogh\": \"van_gogh_subset1\",\n",
         | 
| 107 | 
            +
                "    \"Andy Warhol\": \"andy_subset1\",\n",
         | 
| 108 | 
            +
                "    \"Walter Battiss\": \"walter-battiss_subset2\",\n",
         | 
| 109 | 
            +
                "    \"Camille Corot\": \"camille-corot_subset1\",\n",
         | 
| 110 | 
            +
                "    \"Claude Monet\": \"monet_subset2\",\n",
         | 
| 111 | 
            +
                "    \"Pablo Picasso\": \"picasso_subset1\",\n",
         | 
| 112 | 
            +
                "    \"Jackson Pollock\": \"jackson-pollock_subset1\",\n",
         | 
| 113 | 
            +
                "    \"Gerhard Richter\": \"gerhard-richter_subset1\",\n",
         | 
| 114 | 
            +
                "    \"M.C. Escher\": \"m.c.-escher_subset1\",\n",
         | 
| 115 | 
            +
                "    \"Albert Gleizes\": \"albert-gleizes_subset1\",\n",
         | 
| 116 | 
            +
                "    \"Hokusai\": \"katsushika-hokusai_subset1\",\n",
         | 
| 117 | 
            +
                "    \"Wassily Kandinsky\": \"kandinsky_subset1\",\n",
         | 
| 118 | 
            +
                "    \"Gustav Klimt\": \"klimt_subset3\",\n",
         | 
| 119 | 
            +
                "    \"Roy Lichtenstein\": \"roy-lichtenstein_subset1\",\n",
         | 
| 120 | 
            +
                "    \"Henri Matisse\": \"henri-matisse_subset1\",\n",
         | 
| 121 | 
            +
                "    \"Joan Miro\": \"joan-miro_subset2\",\n",
         | 
| 122 | 
            +
                "}\n",
         | 
| 123 | 
            +
                "\n",
         | 
| 124 | 
            +
                "def demo_inference_gen(adapter_choice:str, prompt:str, samples:int=1,seed:int=0, steps=50, guidance_scale=7.5):\n",
         | 
| 125 | 
            +
                "    adapter_path = lora_map[adapter_choice]\n",
         | 
| 126 | 
            +
                "    if adapter_path not in [None, \"None\"]:\n",
         | 
| 127 | 
            +
                "        adapter_path = f\"data/Art_adapters/{adapter_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt\"\n",
         | 
| 128 | 
            +
                "\n",
         | 
| 129 | 
            +
                "    prompts = [prompt]*samples\n",
         | 
| 130 | 
            +
                "    infer_loader = get_validation_dataloader(prompts)\n",
         | 
| 131 | 
            +
                "    network = get_lora_network(pipe.unet, adapter_path)[\"network\"]\n",
         | 
| 132 | 
            +
                "    pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,\n",
         | 
| 133 | 
            +
                "                            height=512, width=512, scales=[1.0],\n",
         | 
| 134 | 
            +
                "                            save_dir=None, seed=seed,steps=steps, guidance_scale=guidance_scale,\n",
         | 
| 135 | 
            +
                "                            start_noise=-1, show=False, style_prompt=\"sks art\", no_load=True,\n",
         | 
| 136 | 
            +
                "                            from_scratch=True)[0][1.0]\n",
         | 
| 137 | 
            +
                "    return pred_images\n",
         | 
| 138 | 
            +
                "\n",
         | 
| 139 | 
            +
                "def demo_inference_stylization(adapter_path:str, prompts:list, image:list, start_noise=800,seed:int=0):\n",
         | 
| 140 | 
            +
                "    infer_loader = get_validation_dataloader(prompts, image)\n",
         | 
| 141 | 
            +
                "    network = get_lora_network(pipe.unet, adapter_path,\"all_up\")[\"network\"]\n",
         | 
| 142 | 
            +
                "    pred_images = inference(network, pipe.tokenizer, pipe.text_encoder, pipe.vae, pipe.unet, pipe.scheduler, infer_loader,\n",
         | 
| 143 | 
            +
                "                            height=512, width=512, scales=[0.,1.],\n",
         | 
| 144 | 
            +
                "                            save_dir=None, seed=seed,steps=20, guidance_scale=7.5,\n",
         | 
| 145 | 
            +
                "                            start_noise=start_noise, show=True, style_prompt=\"sks art\", no_load=True,\n",
         | 
| 146 | 
            +
                "                            from_scratch=False)\n",
         | 
| 147 | 
            +
                "    return pred_images\n",
         | 
| 148 | 
            +
                "\n",
         | 
| 149 | 
            +
                "# def infer(prompt, samples, steps, scale, seed):\n",
         | 
| 150 | 
            +
                "#     generator = torch.Generator(device=device).manual_seed(seed)\n",
         | 
| 151 | 
            +
                "#     images_list = pipe(  # type: ignore\n",
         | 
| 152 | 
            +
                "#         [prompt] * samples,\n",
         | 
| 153 | 
            +
                "#         num_inference_steps=steps,\n",
         | 
| 154 | 
            +
                "#         guidance_scale=scale,\n",
         | 
| 155 | 
            +
                "#         generator=generator,\n",
         | 
| 156 | 
            +
                "#     )\n",
         | 
| 157 | 
            +
                "#     images = []\n",
         | 
| 158 | 
            +
                "#     safe_image = Image.open(r\"data/unsafe.png\")\n",
         | 
| 159 | 
            +
                "#     print(images_list)\n",
         | 
| 160 | 
            +
                "#     for i, image in enumerate(images_list[\"images\"]):  # type: ignore\n",
         | 
| 161 | 
            +
                "#         if images_list[\"nsfw_content_detected\"][i]:  # type: ignore\n",
         | 
| 162 | 
            +
                "#             images.append(safe_image)\n",
         | 
| 163 | 
            +
                "#         else:\n",
         | 
| 164 | 
            +
                "#             images.append(image)\n",
         | 
| 165 | 
            +
                "#     return images\n"
         | 
| 166 | 
            +
               ]
         | 
| 167 | 
            +
              },
         | 
| 168 | 
            +
              {
         | 
| 169 | 
            +
               "cell_type": "code",
         | 
| 170 | 
            +
               "execution_count": 6,
         | 
| 171 | 
            +
               "id": "aa33e9d104023847",
         | 
| 172 | 
            +
               "metadata": {
         | 
| 173 | 
            +
                "ExecuteTime": {
         | 
| 174 | 
            +
                 "end_time": "2024-12-09T12:09:39.339583Z",
         | 
| 175 | 
            +
                 "start_time": "2024-12-09T12:09:38.953936Z"
         | 
| 176 | 
            +
                }
         | 
| 177 | 
            +
               },
         | 
| 178 | 
            +
               "outputs": [
         | 
| 179 | 
            +
                {
         | 
| 180 | 
            +
                 "name": "stdout",
         | 
| 181 | 
            +
                 "output_type": "stream",
         | 
| 182 | 
            +
                 "text": [
         | 
| 183 | 
            +
                  "<gradio.components.slider.Slider object at 0x7fa12d3a5280>\n",
         | 
| 184 | 
            +
                  "Running on local URL:  http://127.0.0.1:7876\n",
         | 
| 185 | 
            +
                  "Running on public URL: https://be7cce8fec75395c82.gradio.live\n",
         | 
| 186 | 
            +
                  "\n",
         | 
| 187 | 
            +
                  "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n"
         | 
| 188 | 
            +
                 ]
         | 
| 189 | 
            +
                },
         | 
| 190 | 
            +
                {
         | 
| 191 | 
            +
                 "data": {
         | 
| 192 | 
            +
                  "text/html": [
         | 
| 193 | 
            +
                   "<div><iframe src=\"https://be7cce8fec75395c82.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
         | 
| 194 | 
            +
                  ],
         | 
| 195 | 
            +
                  "text/plain": [
         | 
| 196 | 
            +
                   "<IPython.core.display.HTML object>"
         | 
| 197 | 
            +
                  ]
         | 
| 198 | 
            +
                 },
         | 
| 199 | 
            +
                 "metadata": {},
         | 
| 200 | 
            +
                 "output_type": "display_data"
         | 
| 201 | 
            +
                },
         | 
| 202 | 
            +
                {
         | 
| 203 | 
            +
                 "data": {
         | 
| 204 | 
            +
                  "text/plain": []
         | 
| 205 | 
            +
                 },
         | 
| 206 | 
            +
                 "execution_count": 6,
         | 
| 207 | 
            +
                 "metadata": {},
         | 
| 208 | 
            +
                 "output_type": "execute_result"
         | 
| 209 | 
            +
                },
         | 
| 210 | 
            +
                {
         | 
| 211 | 
            +
                 "name": "stdout",
         | 
| 212 | 
            +
                 "output_type": "stream",
         | 
| 213 | 
            +
                 "text": [
         | 
| 214 | 
            +
                  "Train method: None\n",
         | 
| 215 | 
            +
                  "Rank: 1, Alpha: 1\n",
         | 
| 216 | 
            +
                  "create LoRA for U-Net: 0 modules.\n",
         | 
| 217 | 
            +
                  "save dir: None\n",
         | 
| 218 | 
            +
                  "['Park with cherry blossom trees, picnicker’s and a clear blue pond in the style of sks art'], seed=949192390\n"
         | 
| 219 | 
            +
                 ]
         | 
| 220 | 
            +
                },
         | 
| 221 | 
            +
                {
         | 
| 222 | 
            +
                 "name": "stderr",
         | 
| 223 | 
            +
                 "output_type": "stream",
         | 
| 224 | 
            +
                 "text": [
         | 
| 225 | 
            +
                  "/data/vision/torralba/selfmanaged/torralba/scratch/jomat/sam_dataset/miniforge3/envs/diffusion/lib/python3.9/site-packages/torch/nn/modules/conv.py:456: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at /opt/conda/conda-bld/pytorch_1712608883701/work/aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)\n",
         | 
| 226 | 
            +
                  "  return F.conv2d(input, weight, bias, self.stride,\n",
         | 
| 227 | 
            +
                  "\n",
         | 
| 228 | 
            +
                  "00%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 21/21 [00:03<00:00,  6.90it/s]"
         | 
| 229 | 
            +
                 ]
         | 
| 230 | 
            +
                },
         | 
| 231 | 
            +
                {
         | 
| 232 | 
            +
                 "name": "stdout",
         | 
| 233 | 
            +
                 "output_type": "stream",
         | 
| 234 | 
            +
                 "text": [
         | 
| 235 | 
            +
                  "Time taken for one batch, Art Adapter scale=1.0: 3.2747044563293457\n"
         | 
| 236 | 
            +
                 ]
         | 
| 237 | 
            +
                }
         | 
| 238 | 
            +
               ],
         | 
| 239 | 
            +
               "source": [
         | 
| 240 | 
            +
                "block = gr.Blocks()\n",
         | 
| 241 | 
            +
                "# Direct infer\n",
         | 
| 242 | 
            +
                "with block:\n",
         | 
| 243 | 
            +
                "    with gr.Group():\n",
         | 
| 244 | 
            +
                "        with gr.Row():\n",
         | 
| 245 | 
            +
                "            text = gr.Textbox(\n",
         | 
| 246 | 
            +
                "                label=\"Enter your prompt\",\n",
         | 
| 247 | 
            +
                "                max_lines=2,\n",
         | 
| 248 | 
            +
                "                placeholder=\"Enter your prompt\",\n",
         | 
| 249 | 
            +
                "                container=False,\n",
         | 
| 250 | 
            +
                "                value=\"Park with cherry blossom trees, picnicker’s and a clear blue pond.\",\n",
         | 
| 251 | 
            +
                "            )\n",
         | 
| 252 | 
            +
                "            \n",
         | 
| 253 | 
            +
                "\n",
         | 
| 254 | 
            +
                "            \n",
         | 
| 255 | 
            +
                "            btn = gr.Button(\"Run\", scale=0)\n",
         | 
| 256 | 
            +
                "        gallery = gr.Gallery(\n",
         | 
| 257 | 
            +
                "            label=\"Generated images\",\n",
         | 
| 258 | 
            +
                "            show_label=False,\n",
         | 
| 259 | 
            +
                "            elem_id=\"gallery\",\n",
         | 
| 260 | 
            +
                "            columns=[2],\n",
         | 
| 261 | 
            +
                "        )\n",
         | 
| 262 | 
            +
                "\n",
         | 
| 263 | 
            +
                "        advanced_button = gr.Button(\"Advanced options\", elem_id=\"advanced-btn\")\n",
         | 
| 264 | 
            +
                "\n",
         | 
| 265 | 
            +
                "        with gr.Row(elem_id=\"advanced-options\"):\n",
         | 
| 266 | 
            +
                "            adapter_choice = gr.Dropdown(\n",
         | 
| 267 | 
            +
                "                label=\"Choose adapter\",\n",
         | 
| 268 | 
            +
                "                choices=[\"None\", \"Andre Derain\",\"Vincent van Gogh\",\"Andy Warhol\", \"Walter Battiss\",\n",
         | 
| 269 | 
            +
                "                         \"Camille Corot\", \"Claude Monet\", \"Pablo Picasso\",\n",
         | 
| 270 | 
            +
                "                         \"Jackson Pollock\", \"Gerhard Richter\", \"M.C. Escher\",\n",
         | 
| 271 | 
            +
                "                         \"Albert Gleizes\", \"Hokusai\", \"Wassily Kandinsky\", \"Gustav Klimt\", \"Roy Lichtenstein\",\n",
         | 
| 272 | 
            +
                "                         \"Henri Matisse\", \"Joan Miro\"\n",
         | 
| 273 | 
            +
                "                         ],\n",
         | 
| 274 | 
            +
                "                value=\"None\"\n",
         | 
| 275 | 
            +
                "            )\n",
         | 
| 276 | 
            +
                "            # print(adapter_choice[0])\n",
         | 
| 277 | 
            +
                "            # lora_path = lora_map[adapter_choice.value]\n",
         | 
| 278 | 
            +
                "            # if lora_path is not None:\n",
         | 
| 279 | 
            +
                "            #     lora_path = f\"data/Art_adapters/{lora_path}/adapter_alpha1.0_rank1_all_up_1000steps.pt\"\n",
         | 
| 280 | 
            +
                "\n",
         | 
| 281 | 
            +
                "            samples = gr.Slider(label=\"Images\", minimum=1, maximum=4, value=1, step=1)\n",
         | 
| 282 | 
            +
                "            steps = gr.Slider(label=\"Steps\", minimum=1, maximum=50, value=20, step=1)\n",
         | 
| 283 | 
            +
                "            scale = gr.Slider(\n",
         | 
| 284 | 
            +
                "                label=\"Guidance Scale\", minimum=0, maximum=50, value=7.5, step=0.1\n",
         | 
| 285 | 
            +
                "            )\n",
         | 
| 286 | 
            +
                "            print(scale)\n",
         | 
| 287 | 
            +
                "            seed = gr.Slider(\n",
         | 
| 288 | 
            +
                "                label=\"Seed\",\n",
         | 
| 289 | 
            +
                "                minimum=0,\n",
         | 
| 290 | 
            +
                "                maximum=2147483647,\n",
         | 
| 291 | 
            +
                "                step=1,\n",
         | 
| 292 | 
            +
                "                randomize=True,\n",
         | 
| 293 | 
            +
                "            )\n",
         | 
| 294 | 
            +
                "\n",
         | 
| 295 | 
            +
                "        gr.on([text.submit, btn.click], demo_inference_gen, inputs=[adapter_choice, text, samples, seed, steps, scale], outputs=gallery)\n",
         | 
| 296 | 
            +
                "        advanced_button.click(\n",
         | 
| 297 | 
            +
                "            None,\n",
         | 
| 298 | 
            +
                "            [],\n",
         | 
| 299 | 
            +
                "            text,\n",
         | 
| 300 | 
            +
                "        )\n",
         | 
| 301 | 
            +
                "\n",
         | 
| 302 | 
            +
                "\n",
         | 
| 303 | 
            +
                "block.launch(share=True)"
         | 
| 304 | 
            +
               ]
         | 
| 305 | 
            +
              },
         | 
| 306 | 
            +
              {
         | 
| 307 | 
            +
               "cell_type": "code",
         | 
| 308 | 
            +
               "execution_count": null,
         | 
| 309 | 
            +
               "id": "3239c12167a5f2cd",
         | 
| 310 | 
            +
               "metadata": {},
         | 
| 311 | 
            +
               "outputs": [],
         | 
| 312 | 
            +
               "source": []
         | 
| 313 | 
            +
              }
         | 
| 314 | 
            +
             ],
         | 
| 315 | 
            +
             "metadata": {
         | 
| 316 | 
            +
              "kernelspec": {
         | 
| 317 | 
            +
               "display_name": "Python 3 (ipykernel)",
         | 
| 318 | 
            +
               "language": "python",
         | 
| 319 | 
            +
               "name": "python3"
         | 
| 320 | 
            +
              },
         | 
| 321 | 
            +
              "language_info": {
         | 
| 322 | 
            +
               "codemirror_mode": {
         | 
| 323 | 
            +
                "name": "ipython",
         | 
| 324 | 
            +
                "version": 3
         | 
| 325 | 
            +
               },
         | 
| 326 | 
            +
               "file_extension": ".py",
         | 
| 327 | 
            +
               "mimetype": "text/x-python",
         | 
| 328 | 
            +
               "name": "python",
         | 
| 329 | 
            +
               "nbconvert_exporter": "python",
         | 
| 330 | 
            +
               "pygments_lexer": "ipython3",
         | 
| 331 | 
            +
               "version": "3.9.18"
         | 
| 332 | 
            +
              }
         | 
| 333 | 
            +
             },
         | 
| 334 | 
            +
             "nbformat": 4,
         | 
| 335 | 
            +
             "nbformat_minor": 5
         | 
| 336 | 
            +
            }
         | 
    	
        inference.py
    ADDED
    
    | @@ -0,0 +1,657 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Authors: Hui Ren (rhfeiyang.github.io)
         | 
| 2 | 
            +
            import torch
         | 
| 3 | 
            +
            from PIL import Image
         | 
| 4 | 
            +
            import argparse
         | 
| 5 | 
            +
            import os, json, random
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import matplotlib.pyplot as plt
         | 
| 8 | 
            +
            import glob, re
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            from tqdm import tqdm
         | 
| 11 | 
            +
            import numpy as np
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            import sys
         | 
| 14 | 
            +
            import gc
         | 
| 15 | 
            +
            from transformers import CLIPTextModel, CLIPTokenizer, BertModel, BertTokenizer
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            # import train_util
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            from utils.train_util import get_noisy_image, encode_prompts
         | 
| 20 | 
            +
             | 
| 21 | 
            +
            from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel, LMSDiscreteScheduler, DDIMScheduler, PNDMScheduler
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            from typing import Any, Dict, List, Optional, Tuple, Union
         | 
| 24 | 
            +
            from utils.lora import LoRANetwork, DEFAULT_TARGET_REPLACE, UNET_TARGET_REPLACE_MODULE_CONV
         | 
| 25 | 
            +
            import argparse
         | 
| 26 | 
            +
            # from diffusers.training_utils import EMAModel
         | 
| 27 | 
            +
            import shutil
         | 
| 28 | 
            +
            import yaml
         | 
| 29 | 
            +
            from easydict import EasyDict
         | 
| 30 | 
            +
            from utils.metrics import StyleContentMetric
         | 
| 31 | 
            +
            from torchvision import transforms
         | 
| 32 | 
            +
             | 
| 33 | 
            +
            from custom_datasets.coco import CustomCocoCaptions
         | 
| 34 | 
            +
            from custom_datasets.imagepair import ImageSet
         | 
| 35 | 
            +
            from custom_datasets import get_dataset
         | 
| 36 | 
            +
            # from stable_diffusion.utils.modules import get_diffusion_modules
         | 
| 37 | 
            +
            # from diffusers import StableDiffusionImg2ImgPipeline
         | 
| 38 | 
            +
            from diffusers.utils.torch_utils import randn_tensor
         | 
| 39 | 
            +
            import pickle
         | 
| 40 | 
            +
            import time
         | 
| 41 | 
            +
            def flush():
         | 
| 42 | 
            +
                torch.cuda.empty_cache()
         | 
| 43 | 
            +
                gc.collect()
         | 
| 44 | 
            +
             | 
| 45 | 
            +
            def get_train_method(lora_weight):
         | 
| 46 | 
            +
                if lora_weight is None:
         | 
| 47 | 
            +
                    return 'None'
         | 
| 48 | 
            +
                if 'full' in lora_weight:
         | 
| 49 | 
            +
                    train_method = 'full'
         | 
| 50 | 
            +
                elif "down_1_up_2_attn" in lora_weight:
         | 
| 51 | 
            +
                    train_method = 'up_2_attn'
         | 
| 52 | 
            +
                    print(f"Using up_2_attn for {lora_weight}")
         | 
| 53 | 
            +
                elif "down_2_up_1_up_2_attn" in lora_weight:
         | 
| 54 | 
            +
                    train_method = 'down_2_up_2_attn'
         | 
| 55 | 
            +
                elif "down_2_up_2_attn" in lora_weight:
         | 
| 56 | 
            +
                    train_method = 'down_2_up_2_attn'
         | 
| 57 | 
            +
                elif "down_2_attn" in lora_weight:
         | 
| 58 | 
            +
                    train_method = 'down_2_attn'
         | 
| 59 | 
            +
                elif 'noxattn' in lora_weight:
         | 
| 60 | 
            +
                    train_method = 'noxattn'
         | 
| 61 | 
            +
                elif "xattn" in lora_weight:
         | 
| 62 | 
            +
                    train_method = 'xattn'
         | 
| 63 | 
            +
                elif  "attn" in lora_weight:
         | 
| 64 | 
            +
                    train_method = 'attn'
         | 
| 65 | 
            +
                elif "all_up" in lora_weight:
         | 
| 66 | 
            +
                    train_method = 'all_up'
         | 
| 67 | 
            +
                else:
         | 
| 68 | 
            +
                    train_method = 'None'
         | 
| 69 | 
            +
                return train_method
         | 
| 70 | 
            +
             | 
| 71 | 
            +
            def get_validation_dataloader(infer_prompts:list[str]=None, infer_images :list[str]=None,resolution=512, batch_size=10, num_workers=4, val_set="laion_pop500"):
         | 
| 72 | 
            +
                data_transforms = transforms.Compose(
         | 
| 73 | 
            +
                    [
         | 
| 74 | 
            +
                        transforms.Resize(resolution),
         | 
| 75 | 
            +
                        transforms.CenterCrop(resolution),
         | 
| 76 | 
            +
                    ]
         | 
| 77 | 
            +
                )
         | 
| 78 | 
            +
                def preprocess(example):
         | 
| 79 | 
            +
                    ret={}
         | 
| 80 | 
            +
                    ret["image"] = data_transforms(example["image"]) if "image" in example else None
         | 
| 81 | 
            +
                    if "caption" in example:
         | 
| 82 | 
            +
                        if isinstance(example["caption"][0], list):
         | 
| 83 | 
            +
                            ret["caption"] = example["caption"][0][0]
         | 
| 84 | 
            +
                        else:
         | 
| 85 | 
            +
                            ret["caption"] = example["caption"][0]
         | 
| 86 | 
            +
                    if "seed" in example:
         | 
| 87 | 
            +
                        ret["seed"] = example["seed"]
         | 
| 88 | 
            +
                    if "id" in example:
         | 
| 89 | 
            +
                        ret["id"] = example["id"]
         | 
| 90 | 
            +
                    if "path" in example:
         | 
| 91 | 
            +
                        ret["path"] = example["path"]
         | 
| 92 | 
            +
                    return ret
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                def collate_fn(examples):
         | 
| 95 | 
            +
                    out = {}
         | 
| 96 | 
            +
                    if "image" in examples[0]:
         | 
| 97 | 
            +
                        pixel_values = [example["image"] for example in examples]
         | 
| 98 | 
            +
                        out["pixel_values"] = pixel_values
         | 
| 99 | 
            +
                    # notice: only take the first prompt for each image
         | 
| 100 | 
            +
                    if "caption" in examples[0]:
         | 
| 101 | 
            +
                        prompts = [example["caption"] for example in examples]
         | 
| 102 | 
            +
                        out["prompts"] = prompts
         | 
| 103 | 
            +
                    if "seed" in examples[0]:
         | 
| 104 | 
            +
                        seeds = [example["seed"] for example in examples]
         | 
| 105 | 
            +
                        out["seed"] = seeds
         | 
| 106 | 
            +
                    if "path" in examples[0]:
         | 
| 107 | 
            +
                        paths = [example["path"] for example in examples]
         | 
| 108 | 
            +
                        out["path"] = paths
         | 
| 109 | 
            +
                    return out
         | 
| 110 | 
            +
                if infer_prompts is None:
         | 
| 111 | 
            +
                    if val_set == "lhq500":
         | 
| 112 | 
            +
                        dataset = get_dataset("lhq_sub500", get_val=False)["train"]
         | 
| 113 | 
            +
                    elif val_set == "custom_coco100":
         | 
| 114 | 
            +
                        dataset = get_dataset("custom_coco100", get_val=False)["train"]
         | 
| 115 | 
            +
                    elif val_set == "custom_coco500":
         | 
| 116 | 
            +
                        dataset = get_dataset("custom_coco500", get_val=False)["train"]
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                    elif os.path.isdir(val_set):
         | 
| 119 | 
            +
                        image_folder = os.path.join(val_set, "paintings")
         | 
| 120 | 
            +
                        caption_folder = os.path.join(val_set, "captions")
         | 
| 121 | 
            +
                        dataset = ImageSet(folder=image_folder, caption=caption_folder, keep_in_mem=True)
         | 
| 122 | 
            +
                    elif "custom_caption" in val_set:
         | 
| 123 | 
            +
                        from custom_datasets.custom_caption import Caption_set
         | 
| 124 | 
            +
                        name = val_set.replace("custom_caption_", "")
         | 
| 125 | 
            +
                        dataset = Caption_set(set_name = name)
         | 
| 126 | 
            +
                    elif val_set == "laion_pop500":
         | 
| 127 | 
            +
                        dataset = get_dataset("laion_pop500", get_val=False)["train"]
         | 
| 128 | 
            +
                    elif val_set == "laion_pop500_first_sentence":
         | 
| 129 | 
            +
                        dataset = get_dataset("laion_pop500_first_sentence", get_val=False)["train"]
         | 
| 130 | 
            +
                    else:
         | 
| 131 | 
            +
                        raise ValueError("Unknown dataset")
         | 
| 132 | 
            +
                    dataset.with_transform(preprocess)
         | 
| 133 | 
            +
                elif isinstance(infer_prompts, torch.utils.data.Dataset):
         | 
| 134 | 
            +
                    dataset = infer_prompts
         | 
| 135 | 
            +
                    try:
         | 
| 136 | 
            +
                        dataset.with_transform(preprocess)
         | 
| 137 | 
            +
                    except:
         | 
| 138 | 
            +
                        pass
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                else:
         | 
| 141 | 
            +
                    class Dataset(torch.utils.data.Dataset):
         | 
| 142 | 
            +
                        def __init__(self, prompts, images=None):
         | 
| 143 | 
            +
                            self.prompts = prompts
         | 
| 144 | 
            +
                            self.images = images
         | 
| 145 | 
            +
                            self.get_img = False
         | 
| 146 | 
            +
                            if images is not None:
         | 
| 147 | 
            +
                                assert len(prompts) == len(images)
         | 
| 148 | 
            +
                                self.get_img = True
         | 
| 149 | 
            +
                                if isinstance(images[0], str):
         | 
| 150 | 
            +
                                    self.images = [Image.open(image).convert("RGB") for image in images]
         | 
| 151 | 
            +
                            else:
         | 
| 152 | 
            +
                                self.images = [None] * len(prompts)
         | 
| 153 | 
            +
                        def __len__(self):
         | 
| 154 | 
            +
                            return len(self.prompts)
         | 
| 155 | 
            +
                        def __getitem__(self, idx):
         | 
| 156 | 
            +
                            img = self.images[idx]
         | 
| 157 | 
            +
                            if self.get_img and img is not None:
         | 
| 158 | 
            +
                                img = data_transforms(img)
         | 
| 159 | 
            +
                            return {"caption": self.prompts[idx], "image":img}
         | 
| 160 | 
            +
                    dataset = Dataset(infer_prompts, infer_images)
         | 
| 161 | 
            +
             | 
| 162 | 
            +
                dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn, drop_last=False,
         | 
| 163 | 
            +
                                                         num_workers=num_workers, pin_memory=True)
         | 
| 164 | 
            +
                return dataloader
         | 
| 165 | 
            +
             | 
| 166 | 
            +
            def get_lora_network(unet , lora_path, train_method="None", rank=1, alpha=1.0, device="cuda", weight_dtype=torch.float32):
         | 
| 167 | 
            +
                if train_method in [None, "None"]:
         | 
| 168 | 
            +
                    train_method = get_train_method(lora_path)
         | 
| 169 | 
            +
                    print(f"Train method: {train_method}")
         | 
| 170 | 
            +
             | 
| 171 | 
            +
                network_type = "c3lier"
         | 
| 172 | 
            +
                if train_method == 'xattn':
         | 
| 173 | 
            +
                    network_type = 'lierla'
         | 
| 174 | 
            +
             | 
| 175 | 
            +
                modules = DEFAULT_TARGET_REPLACE
         | 
| 176 | 
            +
                if network_type == "c3lier":
         | 
| 177 | 
            +
                    modules += UNET_TARGET_REPLACE_MODULE_CONV
         | 
| 178 | 
            +
             | 
| 179 | 
            +
                alpha = 1
         | 
| 180 | 
            +
                if "rank" in lora_path:
         | 
| 181 | 
            +
                    rank = int(re.search(r'rank(\d+)', lora_path).group(1))
         | 
| 182 | 
            +
                if 'alpha1' in lora_path:
         | 
| 183 | 
            +
                    alpha = 1.0
         | 
| 184 | 
            +
                print(f"Rank: {rank}, Alpha: {alpha}")
         | 
| 185 | 
            +
             | 
| 186 | 
            +
                network = LoRANetwork(
         | 
| 187 | 
            +
                    unet,
         | 
| 188 | 
            +
                    rank=rank,
         | 
| 189 | 
            +
                    multiplier=1.0,
         | 
| 190 | 
            +
                    alpha=alpha,
         | 
| 191 | 
            +
                    train_method=train_method,
         | 
| 192 | 
            +
                ).to(device, dtype=weight_dtype)
         | 
| 193 | 
            +
                if lora_path not in [None, "None"]:
         | 
| 194 | 
            +
                    lora_state_dict = torch.load(lora_path)
         | 
| 195 | 
            +
                    miss = network.load_state_dict(lora_state_dict, strict=False)
         | 
| 196 | 
            +
                    print(f"Missing: {miss}")
         | 
| 197 | 
            +
                ret = {"network": network, "train_method": train_method}
         | 
| 198 | 
            +
                return ret
         | 
| 199 | 
            +
             | 
| 200 | 
            +
            def get_model(pretrained_ckpt_path, unet_ckpt=None,revision=None, variant=None, lora_path=None, weight_dtype=torch.float32,
         | 
| 201 | 
            +
                          device="cuda"):
         | 
| 202 | 
            +
                modules = {}
         | 
| 203 | 
            +
                pipe = DiffusionPipeline.from_pretrained(pretrained_ckpt_path, revision=revision, variant=variant)
         | 
| 204 | 
            +
                if unet_ckpt is not None:
         | 
| 205 | 
            +
                    pipe.unet.from_pretrained(unet_ckpt, subfolder="unet_ema", revision=revision, variant=variant)
         | 
| 206 | 
            +
                unet = pipe.unet
         | 
| 207 | 
            +
                vae = pipe.vae
         | 
| 208 | 
            +
                text_encoder = pipe.text_encoder
         | 
| 209 | 
            +
                tokenizer = pipe.tokenizer
         | 
| 210 | 
            +
                modules["unet"] = unet
         | 
| 211 | 
            +
                modules["vae"] = vae
         | 
| 212 | 
            +
                modules["text_encoder"] = text_encoder
         | 
| 213 | 
            +
                modules["tokenizer"] = tokenizer
         | 
| 214 | 
            +
                # tokenizer = modules["tokenizer"]
         | 
| 215 | 
            +
             | 
| 216 | 
            +
                unet.enable_xformers_memory_efficient_attention()
         | 
| 217 | 
            +
                unet.to(device, dtype=weight_dtype)
         | 
| 218 | 
            +
                if weight_dtype != torch.bfloat16:
         | 
| 219 | 
            +
                    vae.to(device, dtype=torch.float32)
         | 
| 220 | 
            +
                else:
         | 
| 221 | 
            +
                    vae.to(device, dtype=weight_dtype)
         | 
| 222 | 
            +
                text_encoder.to(device, dtype=weight_dtype)
         | 
| 223 | 
            +
             | 
| 224 | 
            +
                if lora_path is not None:
         | 
| 225 | 
            +
                    network = get_lora_network(unet, lora_path, device=device, weight_dtype=weight_dtype)
         | 
| 226 | 
            +
                    modules["network"] = network
         | 
| 227 | 
            +
                return modules
         | 
| 228 | 
            +
             | 
| 229 | 
            +
             | 
| 230 | 
            +
             | 
| 231 | 
            +
            @torch.no_grad()
         | 
| 232 | 
            +
            def inference(network: LoRANetwork, tokenizer: CLIPTokenizer, text_encoder: CLIPTextModel, vae: AutoencoderKL, unet: UNet2DConditionModel, noise_scheduler: LMSDiscreteScheduler,
         | 
| 233 | 
            +
                          dataloader, height:int, width:int, scales:list = np.linspace(0,2,5),save_dir:str=None, seed:int = None,
         | 
| 234 | 
            +
                          weight_dtype: torch.dtype = torch.float32, device: torch.device="cuda", batch_size:int=1, steps:int=50, guidance_scale:float=7.5, start_noise:int=800,
         | 
| 235 | 
            +
                          uncond_prompt:str=None, uncond_embed=None, style_prompt = None, show:bool = False, no_load:bool=False, from_scratch=False):
         | 
| 236 | 
            +
                print(f"save dir: {save_dir}")
         | 
| 237 | 
            +
                if start_noise < 0:
         | 
| 238 | 
            +
                    assert from_scratch
         | 
| 239 | 
            +
                network = network.eval()
         | 
| 240 | 
            +
                unet = unet.eval()
         | 
| 241 | 
            +
                vae = vae.eval()
         | 
| 242 | 
            +
                do_convert = not from_scratch
         | 
| 243 | 
            +
             | 
| 244 | 
            +
                if not do_convert:
         | 
| 245 | 
            +
                    try:
         | 
| 246 | 
            +
                        dataloader.dataset.get_img = False
         | 
| 247 | 
            +
                    except:
         | 
| 248 | 
            +
                        pass
         | 
| 249 | 
            +
                    scales = list(scales)
         | 
| 250 | 
            +
                else:
         | 
| 251 | 
            +
                    scales = ["Real Image"] + list(scales)
         | 
| 252 | 
            +
             | 
| 253 | 
            +
                if not no_load and os.path.exists(os.path.join(save_dir, "infer_imgs.pickle")):
         | 
| 254 | 
            +
                    with open(os.path.join(save_dir, "infer_imgs.pickle"), 'rb') as f:
         | 
| 255 | 
            +
                        pred_images = pickle.load(f)
         | 
| 256 | 
            +
                    take=True
         | 
| 257 | 
            +
                    for key in scales:
         | 
| 258 | 
            +
                        if key not in pred_images:
         | 
| 259 | 
            +
                            take=False
         | 
| 260 | 
            +
                            break
         | 
| 261 | 
            +
                    if take:
         | 
| 262 | 
            +
                        print(f"Found existing inference results in {save_dir}", flush=True)
         | 
| 263 | 
            +
                        return pred_images
         | 
| 264 | 
            +
             | 
| 265 | 
            +
                max_length = tokenizer.model_max_length
         | 
| 266 | 
            +
             | 
| 267 | 
            +
                pred_images = {scale :[] for scale in scales}
         | 
| 268 | 
            +
                all_seeds = {scale:[] for scale in scales}
         | 
| 269 | 
            +
             | 
| 270 | 
            +
                prompts = []
         | 
| 271 | 
            +
                ori_prompts = []
         | 
| 272 | 
            +
                if save_dir is not None:
         | 
| 273 | 
            +
                    img_output_dir = os.path.join(save_dir, "outputs")
         | 
| 274 | 
            +
                    os.makedirs(img_output_dir, exist_ok=True)
         | 
| 275 | 
            +
             | 
| 276 | 
            +
                if uncond_embed is None:
         | 
| 277 | 
            +
                    if uncond_prompt is None:
         | 
| 278 | 
            +
                        uncond_input_text = [""]
         | 
| 279 | 
            +
                    else:
         | 
| 280 | 
            +
                        uncond_input_text = [uncond_prompt]
         | 
| 281 | 
            +
                    uncond_embed = encode_prompts(tokenizer = tokenizer, text_encoder = text_encoder, prompts = uncond_input_text)
         | 
| 282 | 
            +
             | 
| 283 | 
            +
             | 
| 284 | 
            +
                for batch in dataloader:
         | 
| 285 | 
            +
                    ori_prompt = batch["prompts"]
         | 
| 286 | 
            +
                    image = batch["pixel_values"] if do_convert else None
         | 
| 287 | 
            +
                    if do_convert:
         | 
| 288 | 
            +
                        pred_images["Real Image"] += image
         | 
| 289 | 
            +
                    if isinstance(ori_prompt, list):
         | 
| 290 | 
            +
                        if isinstance(text_encoder, CLIPTextModel):
         | 
| 291 | 
            +
                            # trunc prompts for clip encoder
         | 
| 292 | 
            +
                            ori_prompt = [p.split(".")[0]+"." for p in ori_prompt]
         | 
| 293 | 
            +
                        prompt = [f"{p.strip()[::-1].replace('.', '',1)[::-1]} in the style of {style_prompt}" for p in ori_prompt] if style_prompt is not None else ori_prompt
         | 
| 294 | 
            +
                    else:
         | 
| 295 | 
            +
                        if isinstance(text_encoder, CLIPTextModel):
         | 
| 296 | 
            +
                            ori_prompt = ori_prompt.split(".")[0]+"."
         | 
| 297 | 
            +
                        prompt = f"{prompt.strip()[::-1].replace('.', '',1)[::-1]} in the style of {style_prompt}" if style_prompt is not None else ori_prompt
         | 
| 298 | 
            +
             | 
| 299 | 
            +
                    bcz = len(prompt)
         | 
| 300 | 
            +
                    single_seed = seed
         | 
| 301 | 
            +
                    if dataloader.batch_size == 1 and seed is None:
         | 
| 302 | 
            +
                        if "seed" in batch:
         | 
| 303 | 
            +
                            single_seed = batch["seed"][0]
         | 
| 304 | 
            +
             | 
| 305 | 
            +
                    print(f"{prompt}, seed={single_seed}")
         | 
| 306 | 
            +
             | 
| 307 | 
            +
                    # text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt").to(device)
         | 
| 308 | 
            +
                    # original_embeddings = text_encoder(**text_input)[0]
         | 
| 309 | 
            +
             | 
| 310 | 
            +
                    prompts += prompt
         | 
| 311 | 
            +
                    ori_prompts += ori_prompt
         | 
| 312 | 
            +
                    # style_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt").to(device)
         | 
| 313 | 
            +
                    # # style_embeddings = text_encoder(**style_input)[0]
         | 
| 314 | 
            +
                    # style_embeddings = text_encoder(style_input.input_ids, return_dict=False)[0]
         | 
| 315 | 
            +
             | 
| 316 | 
            +
                    style_embeddings = encode_prompts(tokenizer = tokenizer, text_encoder = text_encoder, prompts = prompt)
         | 
| 317 | 
            +
                    original_embeddings = encode_prompts(tokenizer = tokenizer, text_encoder = text_encoder, prompts = ori_prompt)
         | 
| 318 | 
            +
                    if uncond_embed.shape[0] == 1 and bcz > 1:
         | 
| 319 | 
            +
                        uncond_embeddings = uncond_embed.repeat(bcz, 1, 1)
         | 
| 320 | 
            +
                    else:
         | 
| 321 | 
            +
                        uncond_embeddings = uncond_embed
         | 
| 322 | 
            +
                    style_text_embeddings = torch.cat([uncond_embeddings, style_embeddings])
         | 
| 323 | 
            +
                    original_embeddings = torch.cat([uncond_embeddings, original_embeddings])
         | 
| 324 | 
            +
             | 
| 325 | 
            +
                    generator = torch.manual_seed(single_seed) if single_seed is not None else None
         | 
| 326 | 
            +
                    noise_scheduler.set_timesteps(steps)
         | 
| 327 | 
            +
                    if do_convert:
         | 
| 328 | 
            +
                        noised_latent, _, _ = get_noisy_image(image, vae, generator, unet, noise_scheduler, total_timesteps=int((1000-start_noise)/1000 *steps))
         | 
| 329 | 
            +
                    else:
         | 
| 330 | 
            +
                        latent_shape =  (bcz, 4, height//8, width//8)
         | 
| 331 | 
            +
                        noised_latent = randn_tensor(latent_shape, generator=generator, device=vae.device)
         | 
| 332 | 
            +
                    noised_latent = noised_latent.to(unet.dtype)
         | 
| 333 | 
            +
                    noised_latent = noised_latent * noise_scheduler.init_noise_sigma
         | 
| 334 | 
            +
                    for scale in scales:
         | 
| 335 | 
            +
                        start_time = time.time()
         | 
| 336 | 
            +
                        if not isinstance(scale, float) and not isinstance(scale, int):
         | 
| 337 | 
            +
                            continue
         | 
| 338 | 
            +
             | 
| 339 | 
            +
                        latents = noised_latent.clone().to(weight_dtype).to(device)
         | 
| 340 | 
            +
                        noise_scheduler.set_timesteps(steps)
         | 
| 341 | 
            +
                        for t in tqdm(noise_scheduler.timesteps):
         | 
| 342 | 
            +
                            if do_convert and t>start_noise:
         | 
| 343 | 
            +
                                continue
         | 
| 344 | 
            +
                            else:
         | 
| 345 | 
            +
                                if t > start_noise and start_noise >= 0:
         | 
| 346 | 
            +
                                    current_scale = 0
         | 
| 347 | 
            +
                                else:
         | 
| 348 | 
            +
                                    current_scale = scale
         | 
| 349 | 
            +
                            network.set_lora_slider(scale=current_scale)
         | 
| 350 | 
            +
                            text_embedding = style_text_embeddings
         | 
| 351 | 
            +
                            # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
         | 
| 352 | 
            +
                            latent_model_input = torch.cat([latents] * 2)
         | 
| 353 | 
            +
             | 
| 354 | 
            +
                            latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
         | 
| 355 | 
            +
                            # predict the noise residual
         | 
| 356 | 
            +
                            with network:
         | 
| 357 | 
            +
                                noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embedding).sample
         | 
| 358 | 
            +
             | 
| 359 | 
            +
                            # perform guidance
         | 
| 360 | 
            +
                            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
         | 
| 361 | 
            +
                            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
         | 
| 362 | 
            +
             | 
| 363 | 
            +
                            # compute the previous noisy sample x_t -> x_t-1
         | 
| 364 | 
            +
                            if isinstance(noise_scheduler, DDPMScheduler):
         | 
| 365 | 
            +
                                latents = noise_scheduler.step(noise_pred, t, latents, generator=torch.manual_seed(single_seed+t) if single_seed is not None else None).prev_sample
         | 
| 366 | 
            +
                            else:
         | 
| 367 | 
            +
                                latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
         | 
| 368 | 
            +
             | 
| 369 | 
            +
                        # scale and decode the image latents with vae
         | 
| 370 | 
            +
                        latents = 1 / 0.18215 * latents.to(vae.dtype)
         | 
| 371 | 
            +
             | 
| 372 | 
            +
             | 
| 373 | 
            +
                        with torch.no_grad():
         | 
| 374 | 
            +
                            image = vae.decode(latents).sample
         | 
| 375 | 
            +
                        image = (image / 2 + 0.5).clamp(0, 1)
         | 
| 376 | 
            +
                        image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
         | 
| 377 | 
            +
                        images = (image * 255).round().astype("uint8")
         | 
| 378 | 
            +
             | 
| 379 | 
            +
             | 
| 380 | 
            +
                        pil_images = [Image.fromarray(image) for image in images]
         | 
| 381 | 
            +
                        pred_images[scale]+=pil_images
         | 
| 382 | 
            +
                        all_seeds[scale] += [single_seed] * bcz
         | 
| 383 | 
            +
             | 
| 384 | 
            +
                        end_time = time.time()
         | 
| 385 | 
            +
                        print(f"Time taken for one batch, Art Adapter scale={scale}: {end_time-start_time}", flush=True)
         | 
| 386 | 
            +
             | 
| 387 | 
            +
                    if save_dir is not None or show:
         | 
| 388 | 
            +
                        end_idx = len(list(pred_images.values())[0])
         | 
| 389 | 
            +
                        for i in range(end_idx-bcz, end_idx):
         | 
| 390 | 
            +
                            keys = list(pred_images.keys())
         | 
| 391 | 
            +
                            images_list = [pred_images[key][i] for key in keys]
         | 
| 392 | 
            +
                            prompt = prompts[i]
         | 
| 393 | 
            +
                            if len(scales)==1:
         | 
| 394 | 
            +
                                plt.imshow(images_list[0])
         | 
| 395 | 
            +
                                plt.axis('off')
         | 
| 396 | 
            +
                                plt.title(f"{prompt}_{single_seed}_start{start_noise}", fontsize=20)
         | 
| 397 | 
            +
                            else:
         | 
| 398 | 
            +
                                fig, ax = plt.subplots(1, len(images_list), figsize=(len(scales)*5,6), layout="constrained")
         | 
| 399 | 
            +
                                for id, a in enumerate(ax):
         | 
| 400 | 
            +
                                    a.imshow(images_list[id])
         | 
| 401 | 
            +
                                    if isinstance(scales[id], float) or isinstance(scales[id], int):
         | 
| 402 | 
            +
                                        a.set_title(f"Art Adapter scale={scales[id]}", fontsize=20)
         | 
| 403 | 
            +
                                    else:
         | 
| 404 | 
            +
                                        a.set_title(f"{keys[id]}", fontsize=20)
         | 
| 405 | 
            +
                                    a.axis('off')
         | 
| 406 | 
            +
             | 
| 407 | 
            +
                                # plt.suptitle(f"{os.path.basename(lora_weight).replace('.pt','')}", fontsize=20)
         | 
| 408 | 
            +
             | 
| 409 | 
            +
                                # plt.tight_layout()
         | 
| 410 | 
            +
                                # if do_convert:
         | 
| 411 | 
            +
                                #     plt.suptitle(f"{prompt}\nseed{single_seed}_start{start_noise}_guidance{guidance_scale}", fontsize=20)
         | 
| 412 | 
            +
                                # else:
         | 
| 413 | 
            +
                                #     plt.suptitle(f"{prompt}\nseed{single_seed}_from_scratch_guidance{guidance_scale}", fontsize=20)
         | 
| 414 | 
            +
             | 
| 415 | 
            +
                            if save_dir is not None:
         | 
| 416 | 
            +
                                plt.savefig(f"{img_output_dir}/{prompt.replace(' ', '_')[:100]}_seed{single_seed}_start{start_noise}.png")
         | 
| 417 | 
            +
                            if show:
         | 
| 418 | 
            +
                                plt.show()
         | 
| 419 | 
            +
                            plt.close()
         | 
| 420 | 
            +
             | 
| 421 | 
            +
                    flush()
         | 
| 422 | 
            +
             | 
| 423 | 
            +
                if save_dir is not None:
         | 
| 424 | 
            +
                    with open(os.path.join(save_dir, "infer_imgs.pickle" ), 'wb') as f:
         | 
| 425 | 
            +
                        pickle.dump(pred_images, f)
         | 
| 426 | 
            +
                    with open(os.path.join(save_dir, "all_seeds.pickle"), 'wb') as f:
         | 
| 427 | 
            +
                        to_save={"all_seeds":all_seeds, "batch_size":batch_size}
         | 
| 428 | 
            +
                        pickle.dump(to_save, f)
         | 
| 429 | 
            +
                    for scale, images in pred_images.items():
         | 
| 430 | 
            +
                        subfolder = os.path.join(save_dir,"images", f"{scale}")
         | 
| 431 | 
            +
                        os.makedirs(subfolder, exist_ok=True)
         | 
| 432 | 
            +
             | 
| 433 | 
            +
                        used_prompt = ori_prompts
         | 
| 434 | 
            +
                        if (isinstance(scale, float) or isinstance(scale, int)): #and scale != 0:
         | 
| 435 | 
            +
                            used_prompt = prompts
         | 
| 436 | 
            +
                        for i, image in enumerate(images):
         | 
| 437 | 
            +
                            if scale == "Real Image":
         | 
| 438 | 
            +
                                suffix = ""
         | 
| 439 | 
            +
                            else:
         | 
| 440 | 
            +
                                suffix = f"_seed{all_seeds[scale][i]}"
         | 
| 441 | 
            +
                            image.save(os.path.join(subfolder, f"{used_prompt[i].replace(' ', '_')[:100]}{suffix}.jpg"))
         | 
| 442 | 
            +
                    with open(os.path.join(save_dir, "infer_prompts.txt"), 'w') as f:
         | 
| 443 | 
            +
                        for prompt in prompts:
         | 
| 444 | 
            +
                            f.write(f"{prompt}\n")
         | 
| 445 | 
            +
                    with open(os.path.join(save_dir, "ori_prompts.txt"), 'w') as f:
         | 
| 446 | 
            +
                        for prompt in ori_prompts:
         | 
| 447 | 
            +
                            f.write(f"{prompt}\n")
         | 
| 448 | 
            +
                    print(f"Saved inference results to {save_dir}", flush=True)
         | 
| 449 | 
            +
                return pred_images, prompts
         | 
| 450 | 
            +
             | 
| 451 | 
            +
            @torch.no_grad()
         | 
| 452 | 
            +
            def infer_metric(ref_image_folder,pred_images, prompts, save_dir, start_noise=""):
         | 
| 453 | 
            +
                prompts = [prompt.split(" in the style of ")[0] for prompt in prompts]
         | 
| 454 | 
            +
                scores = {}
         | 
| 455 | 
            +
                original_images = pred_images["Real Image"] if "Real Image" in pred_images else None
         | 
| 456 | 
            +
                metric = StyleContentMetric(ref_image_folder)
         | 
| 457 | 
            +
                for scale, images in pred_images.items():
         | 
| 458 | 
            +
                    score = metric(images, original_images, prompts)
         | 
| 459 | 
            +
             | 
| 460 | 
            +
                    scores[scale] = score
         | 
| 461 | 
            +
                    print(f"Style transfer score at scale {scale}: {score}")
         | 
| 462 | 
            +
                scores["ref_path"] = ref_image_folder
         | 
| 463 | 
            +
                save_name = f"scores_start{start_noise}.json"
         | 
| 464 | 
            +
                os.makedirs(save_dir, exist_ok=True)
         | 
| 465 | 
            +
                with open(os.path.join(save_dir, save_name), 'w') as f:
         | 
| 466 | 
            +
                    json.dump(scores, f, indent=2)
         | 
| 467 | 
            +
                return scores
         | 
| 468 | 
            +
             | 
| 469 | 
            +
            def parse_args():
         | 
| 470 | 
            +
                parser = argparse.ArgumentParser(description='Inference with LoRA')
         | 
| 471 | 
            +
                parser.add_argument('--lora_weights', type=str, default=["None"],
         | 
| 472 | 
            +
                                    nargs='+', help='path to your model file')
         | 
| 473 | 
            +
                parser.add_argument('--prompts', type=str, default=[],
         | 
| 474 | 
            +
                                    nargs='+', help='prompts to try')
         | 
| 475 | 
            +
                parser.add_argument("--prompt_file", type=str, default=None, help="path to the prompt file")
         | 
| 476 | 
            +
                parser.add_argument("--prompt_file_key", type=str, default="prompts", help="key to the prompt file")
         | 
| 477 | 
            +
                parser.add_argument('--resolution', type=int, default=512, help='resolution of the image')
         | 
| 478 | 
            +
                parser.add_argument('--seed', type=int, default=None, help='seed for the random number generator')
         | 
| 479 | 
            +
                parser.add_argument("--start_noise", type=int, default=800, help="start noise")
         | 
| 480 | 
            +
                parser.add_argument("--from_scratch", default=False, action="store_true", help="from scratch")
         | 
| 481 | 
            +
                parser.add_argument("--ref_image_folder", type=str, default=None, help="folder containing reference images")
         | 
| 482 | 
            +
                parser.add_argument("--show", action="store_true", help="show the image")
         | 
| 483 | 
            +
                parser.add_argument("--batch_size", type=int, default=1, help="batch size")
         | 
| 484 | 
            +
                parser.add_argument("--scales", type=float, default=[0.,1.], nargs='+', help="scales to test")
         | 
| 485 | 
            +
                parser.add_argument("--train_method", type=str, default=None, help="train method")
         | 
| 486 | 
            +
             | 
| 487 | 
            +
                # parser.add_argument("--vae_path", type=str, default="CompVis/stable-diffusion-v1-4", help="Path to the VAE model.")
         | 
| 488 | 
            +
                # parser.add_argument("--text_encoder_path", type=str, default="CompVis/stable-diffusion-v1-4", help="Path to the text encoder model.")
         | 
| 489 | 
            +
                parser.add_argument("--pretrained_model_name_or_path", type=str, default="rhfeiyang/art-free-diffusion-v1", help="Path to the pretrained model.")
         | 
| 490 | 
            +
                parser.add_argument("--unet_ckpt", default=None, type=str, help="Path to the unet checkpoint")
         | 
| 491 | 
            +
                parser.add_argument("--guidance_scale", type=float, default=5.0, help="guidance scale")
         | 
| 492 | 
            +
                parser.add_argument("--infer_mode", default="sks_art",  help="inference mode") #, choices=["style", "ori", "artist", "sks_art","Peter"]
         | 
| 493 | 
            +
                parser.add_argument("--save_dir", type=str, default="inference_output", help="save directory")
         | 
| 494 | 
            +
                parser.add_argument("--num_workers", type=int, default=4, help="number of workers")
         | 
| 495 | 
            +
                parser.add_argument("--no_load", action="store_true", help="no load the pre-inferred results")
         | 
| 496 | 
            +
                parser.add_argument("--infer_prompts", type=str, default=None, nargs="+", help="prompts to infer")
         | 
| 497 | 
            +
                parser.add_argument("--infer_images", type=str, default=None, nargs="+", help="images to infer")
         | 
| 498 | 
            +
                parser.add_argument("--rank", type=int, default=1, help="rank of the lora")
         | 
| 499 | 
            +
                parser.add_argument("--val_set", type=str, default="laion_pop500",  help="validation set")
         | 
| 500 | 
            +
                parser.add_argument("--folder_name", type=str, default=None, help="folder name")
         | 
| 501 | 
            +
                parser.add_argument("--scheduler_type",type=str, choices=["ddpm", "ddim", "pndm","lms"], default="ddpm", help="scheduler type")
         | 
| 502 | 
            +
                parser.add_argument("--infer_steps", type=int, default=50, help="inference steps")
         | 
| 503 | 
            +
                parser.add_argument("--weight_dtype", type=str, default="fp32", help="weight dtype")
         | 
| 504 | 
            +
                parser.add_argument("--custom_coco_cap", action="store_true", help="use custom coco caption")
         | 
| 505 | 
            +
                args = parser.parse_args()
         | 
| 506 | 
            +
                if args.infer_prompts is not None and len(args.infer_prompts) == 1 and os.path.isfile(args.infer_prompts[0]):
         | 
| 507 | 
            +
                    if args.infer_prompts[0].endswith(".txt") and args.custom_coco_cap:
         | 
| 508 | 
            +
                        args.infer_prompts = CustomCocoCaptions(custom_file=args.infer_prompts[0])
         | 
| 509 | 
            +
                    elif args.infer_prompts[0].endswith(".txt"):
         | 
| 510 | 
            +
                        with open(args.infer_prompts[0], 'r') as f:
         | 
| 511 | 
            +
                            args.infer_prompts = f.readlines()
         | 
| 512 | 
            +
                            args.infer_prompts = [prompt.strip() for prompt in args.infer_prompts]
         | 
| 513 | 
            +
                    elif args.infer_prompts[0].endswith(".csv"):
         | 
| 514 | 
            +
                        from custom_datasets.custom_caption import Caption_set
         | 
| 515 | 
            +
                        caption_set = Caption_set(args.infer_prompts[0])
         | 
| 516 | 
            +
                        args.infer_prompts = caption_set
         | 
| 517 | 
            +
             | 
| 518 | 
            +
             | 
| 519 | 
            +
                if args.infer_mode == "style":
         | 
| 520 | 
            +
                    with open(os.path.join(args.ref_image_folder, "style_label.txt"), 'r') as f:
         | 
| 521 | 
            +
                        args.style_label = f.readlines()[0].strip()
         | 
| 522 | 
            +
                elif args.infer_mode == "artist":
         | 
| 523 | 
            +
                    with open(os.path.join(args.ref_image_folder, "style_label.txt"), 'r') as f:
         | 
| 524 | 
            +
                        args.style_label = f.readlines()[0].strip()
         | 
| 525 | 
            +
                        args.style_label = args.style_label.split(",")[0].strip()
         | 
| 526 | 
            +
                elif args.infer_mode == "ori":
         | 
| 527 | 
            +
                    args.style_label = None
         | 
| 528 | 
            +
                else:
         | 
| 529 | 
            +
                    args.style_label = args.infer_mode.replace("_", " ")
         | 
| 530 | 
            +
                if args.ref_image_folder is not None:
         | 
| 531 | 
            +
                    args.ref_image_folder = os.path.join(args.ref_image_folder, "paintings")
         | 
| 532 | 
            +
             | 
| 533 | 
            +
                if args.start_noise < 0:
         | 
| 534 | 
            +
                    args.from_scratch = True
         | 
| 535 | 
            +
             | 
| 536 | 
            +
             | 
| 537 | 
            +
                print(args.__dict__)
         | 
| 538 | 
            +
                return args
         | 
| 539 | 
            +
             | 
| 540 | 
            +
             | 
| 541 | 
            +
            def main(args):
         | 
| 542 | 
            +
                lora_weights = args.lora_weights
         | 
| 543 | 
            +
             | 
| 544 | 
            +
                if len(lora_weights) == 1 and isinstance(lora_weights[0], str) and os.path.isdir(lora_weights[0]):
         | 
| 545 | 
            +
                    lora_weights = glob.glob(os.path.join(lora_weights[0], "*.pt"))
         | 
| 546 | 
            +
                    lora_weights=sorted(lora_weights, reverse=True)
         | 
| 547 | 
            +
             | 
| 548 | 
            +
                width = args.resolution
         | 
| 549 | 
            +
                height = args.resolution
         | 
| 550 | 
            +
                steps = args.infer_steps
         | 
| 551 | 
            +
             | 
| 552 | 
            +
                revision = None
         | 
| 553 | 
            +
                device = 'cuda'
         | 
| 554 | 
            +
                rank = args.rank
         | 
| 555 | 
            +
                if args.weight_dtype == "fp32":
         | 
| 556 | 
            +
                    weight_dtype = torch.float32
         | 
| 557 | 
            +
                elif args.weight_dtype=="fp16":
         | 
| 558 | 
            +
                    weight_dtype = torch.float16
         | 
| 559 | 
            +
                elif args.weight_dtype=="bf16":
         | 
| 560 | 
            +
                    weight_dtype = torch.bfloat16
         | 
| 561 | 
            +
             | 
| 562 | 
            +
                modules = get_model(args.pretrained_model_name_or_path, unet_ckpt=args.unet_ckpt, revision=revision, variant=None, lora_path=None, weight_dtype=weight_dtype, device=device, )
         | 
| 563 | 
            +
                if args.scheduler_type == "pndm":
         | 
| 564 | 
            +
                    noise_scheduler = PNDMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
         | 
| 565 | 
            +
             | 
| 566 | 
            +
                elif args.scheduler_type == "ddpm":
         | 
| 567 | 
            +
                    noise_scheduler = DDPMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
         | 
| 568 | 
            +
                elif args.scheduler_type == "ddim":
         | 
| 569 | 
            +
                    noise_scheduler = DDIMScheduler(
         | 
| 570 | 
            +
                        beta_start=0.00085,
         | 
| 571 | 
            +
                        beta_end=0.012,
         | 
| 572 | 
            +
                        beta_schedule="scaled_linear",
         | 
| 573 | 
            +
                        num_train_timesteps=1000,
         | 
| 574 | 
            +
                        clip_sample=False,
         | 
| 575 | 
            +
                        prediction_type="epsilon",
         | 
| 576 | 
            +
                    )
         | 
| 577 | 
            +
                elif args.scheduler_type == "lms":
         | 
| 578 | 
            +
                    noise_scheduler = LMSDiscreteScheduler(beta_start=0.00085,
         | 
| 579 | 
            +
                                         beta_end=0.012,
         | 
| 580 | 
            +
                                         beta_schedule="scaled_linear",
         | 
| 581 | 
            +
                                         num_train_timesteps=1000)
         | 
| 582 | 
            +
                else:
         | 
| 583 | 
            +
                    raise ValueError("Unknown scheduler type")
         | 
| 584 | 
            +
                cache=EasyDict()
         | 
| 585 | 
            +
                cache.modules = modules
         | 
| 586 | 
            +
             | 
| 587 | 
            +
                unet = modules["unet"]
         | 
| 588 | 
            +
                vae = modules["vae"]
         | 
| 589 | 
            +
                text_encoder = modules["text_encoder"]
         | 
| 590 | 
            +
                tokenizer = modules["tokenizer"]
         | 
| 591 | 
            +
             | 
| 592 | 
            +
                unet.requires_grad_(False)
         | 
| 593 | 
            +
             | 
| 594 | 
            +
                # Move unet, vae and text_encoder to device and cast to weight_dtype
         | 
| 595 | 
            +
                vae.requires_grad_(False)
         | 
| 596 | 
            +
                text_encoder.requires_grad_(False)
         | 
| 597 | 
            +
             | 
| 598 | 
            +
                ## dataloader
         | 
| 599 | 
            +
                dataloader = get_validation_dataloader(infer_prompts=args.infer_prompts, infer_images=args.infer_images,
         | 
| 600 | 
            +
                                                       resolution=args.resolution,
         | 
| 601 | 
            +
                                                       batch_size=args.batch_size, num_workers=args.num_workers,
         | 
| 602 | 
            +
                                                       val_set=args.val_set)
         | 
| 603 | 
            +
             | 
| 604 | 
            +
             | 
| 605 | 
            +
                for lora_weight in lora_weights:
         | 
| 606 | 
            +
                    print(f"Testing {lora_weight}")
         | 
| 607 | 
            +
                    # for different seeds on same prompt
         | 
| 608 | 
            +
                    seed = args.seed
         | 
| 609 | 
            +
             | 
| 610 | 
            +
                    network_ret = get_lora_network(unet, lora_weight, train_method=args.train_method, rank=rank, alpha=1.0, device=device, weight_dtype=weight_dtype)
         | 
| 611 | 
            +
                    network = network_ret["network"]
         | 
| 612 | 
            +
                    train_method = network_ret["train_method"]
         | 
| 613 | 
            +
                    if args.save_dir is not None:
         | 
| 614 | 
            +
                        save_dir = args.save_dir
         | 
| 615 | 
            +
                        if args.style_label is not None:
         | 
| 616 | 
            +
                            save_dir = os.path.join(save_dir, f"{args.style_label.replace(' ', '_')}")
         | 
| 617 | 
            +
                        else:
         | 
| 618 | 
            +
                            save_dir = os.path.join(save_dir, f"ori/{args.start_noise}")
         | 
| 619 | 
            +
                    else:
         | 
| 620 | 
            +
                        if args.folder_name is not None:
         | 
| 621 | 
            +
                            folder_name = args.folder_name
         | 
| 622 | 
            +
                        else:
         | 
| 623 | 
            +
                            folder_name = "validation" if args.infer_prompts is None else "validation_prompts"
         | 
| 624 | 
            +
                        save_dir = os.path.join(os.path.dirname(lora_weight), f"{folder_name}/{train_method}", os.path.basename(lora_weight).replace('.pt','').split('_')[-1])
         | 
| 625 | 
            +
                    if args.infer_prompts is None:
         | 
| 626 | 
            +
                        save_dir = os.path.join(save_dir, f"{args.val_set}")
         | 
| 627 | 
            +
             | 
| 628 | 
            +
                    infer_config = f"{args.scheduler_type}{args.infer_steps}_{args.weight_dtype}_guidance{args.guidance_scale}"
         | 
| 629 | 
            +
                    save_dir = os.path.join(save_dir, infer_config)
         | 
| 630 | 
            +
                    os.makedirs(save_dir, exist_ok=True)
         | 
| 631 | 
            +
                    if args.from_scratch:
         | 
| 632 | 
            +
                        save_dir = os.path.join(save_dir, "from_scratch")
         | 
| 633 | 
            +
                    else:
         | 
| 634 | 
            +
                        save_dir = os.path.join(save_dir, "transfer")
         | 
| 635 | 
            +
                    save_dir = os.path.join(save_dir, f"start{args.start_noise}")
         | 
| 636 | 
            +
                    os.makedirs(save_dir, exist_ok=True)
         | 
| 637 | 
            +
                    with open(os.path.join(save_dir, "infer_args.yaml"), 'w') as f:
         | 
| 638 | 
            +
                        yaml.dump(vars(args), f)
         | 
| 639 | 
            +
                    # save code
         | 
| 640 | 
            +
                    code_dir = os.path.join(save_dir, "code")
         | 
| 641 | 
            +
                    os.makedirs(code_dir, exist_ok=True)
         | 
| 642 | 
            +
                    current_file = os.path.basename(__file__)
         | 
| 643 | 
            +
                    shutil.copy(__file__, os.path.join(code_dir, current_file))
         | 
| 644 | 
            +
                    with torch.no_grad():
         | 
| 645 | 
            +
                        pred_images, prompts = inference(network, tokenizer, text_encoder, vae, unet, noise_scheduler, dataloader, height, width,
         | 
| 646 | 
            +
                                                args.scales, save_dir, seed, weight_dtype, device, args.batch_size, steps, guidance_scale=args.guidance_scale,
         | 
| 647 | 
            +
                                                start_noise=args.start_noise, show=args.show, style_prompt=args.style_label, no_load=args.no_load,
         | 
| 648 | 
            +
                                                from_scratch=args.from_scratch)
         | 
| 649 | 
            +
             | 
| 650 | 
            +
                        if args.ref_image_folder is not None:
         | 
| 651 | 
            +
                            flush()
         | 
| 652 | 
            +
                            print("Calculating metrics")
         | 
| 653 | 
            +
                            infer_metric(args.ref_image_folder, pred_images, save_dir, args.start_noise)
         | 
| 654 | 
            +
             | 
| 655 | 
            +
            if __name__ == "__main__":
         | 
| 656 | 
            +
                args = parse_args()
         | 
| 657 | 
            +
                main(args)
         | 
    	
        utils/__init__.py
    ADDED
    
    | @@ -0,0 +1 @@ | |
|  | 
|  | |
| 1 | 
            +
            # Authors: Hui Ren (rhfeiyang.github.io)
         | 
    	
        utils/__pycache__/__init__.cpython-39.pyc
    ADDED
    
    | Binary file (203 Bytes). View file | 
|  | 
    	
        utils/__pycache__/lora.cpython-39.pyc
    ADDED
    
    | Binary file (6.29 kB). View file | 
|  | 
    	
        utils/__pycache__/metrics.cpython-39.pyc
    ADDED
    
    | Binary file (19.3 kB). View file | 
|  | 
    	
        utils/__pycache__/train_util.cpython-39.pyc
    ADDED
    
    | Binary file (10.9 kB). View file | 
|  | 
    	
        utils/art_filter.py
    ADDED
    
    | @@ -0,0 +1,210 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Authors: Hui Ren (rhfeiyang.github.io)
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            from transformers import CLIPProcessor, CLIPModel
         | 
| 4 | 
            +
            import torch
         | 
| 5 | 
            +
            import numpy as np
         | 
| 6 | 
            +
            import os
         | 
| 7 | 
            +
            os.environ["TOKENIZERS_PARALLELISM"] = "false"
         | 
| 8 | 
            +
            from tqdm import tqdm
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            class Caption_filter:
         | 
| 11 | 
            +
                def __init__(self, filter_prompts=["painting", "paintings", "art", "artwork", "drawings", "sketch", "sketches", "illustration", "illustrations",
         | 
| 12 | 
            +
                                                   "sculpture","sculptures", "installation", "printmaking", "digital art", "conceptual art", "mosaic", "tapestry",
         | 
| 13 | 
            +
                                                   "abstract", "realism", "surrealism", "impressionism", "expressionism", "cubism", "minimalism", "baroque", "rococo",
         | 
| 14 | 
            +
                                                   "pop art", "art nouveau", "art deco", "futurism", "dadaism",
         | 
| 15 | 
            +
                                                    "stamp", "stamps", "advertisement", "advertisements","logo", "logos"
         | 
| 16 | 
            +
                                                   ],):
         | 
| 17 | 
            +
                    self.filter_prompts = filter_prompts
         | 
| 18 | 
            +
                    self.total_count=0
         | 
| 19 | 
            +
                    self.filter_count=[0]*len(filter_prompts)
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                def reset(self):
         | 
| 22 | 
            +
                    self.total_count=0
         | 
| 23 | 
            +
                    self.filter_count=[0]*len(self.filter_prompts)
         | 
| 24 | 
            +
                def filter(self, captions):
         | 
| 25 | 
            +
                    filter_result = []
         | 
| 26 | 
            +
                    for caption in captions:
         | 
| 27 | 
            +
                        words = caption[0]
         | 
| 28 | 
            +
                        if words == None:
         | 
| 29 | 
            +
                            filter_result.append((True, "None"))
         | 
| 30 | 
            +
                            continue
         | 
| 31 | 
            +
                        words = words.lower()
         | 
| 32 | 
            +
                        words = words.split()
         | 
| 33 | 
            +
                        filt = False
         | 
| 34 | 
            +
                        reason=None
         | 
| 35 | 
            +
                        for i, filter_keyword in enumerate(self.filter_prompts):
         | 
| 36 | 
            +
                            key_len = len(filter_keyword.split())
         | 
| 37 | 
            +
                            for j in range(len(words)-key_len+1):
         | 
| 38 | 
            +
                                if " ".join(words[j:j+key_len]) == filter_keyword:
         | 
| 39 | 
            +
                                    self.filter_count[i] += 1
         | 
| 40 | 
            +
                                    filt = True
         | 
| 41 | 
            +
                                    reason = filter_keyword
         | 
| 42 | 
            +
                                    break
         | 
| 43 | 
            +
                            if filt:
         | 
| 44 | 
            +
                                break
         | 
| 45 | 
            +
                        filter_result.append((filt, reason))
         | 
| 46 | 
            +
                        self.total_count += 1
         | 
| 47 | 
            +
                    return filter_result
         | 
| 48 | 
            +
             | 
| 49 | 
            +
            class Clip_filter:
         | 
| 50 | 
            +
                prompt_threshold = {
         | 
| 51 | 
            +
                    "painting": 17,
         | 
| 52 | 
            +
                    "art": 17.5,
         | 
| 53 | 
            +
                    "artwork": 19,
         | 
| 54 | 
            +
                    "drawing": 15.8,
         | 
| 55 | 
            +
                    "sketch": 17,
         | 
| 56 | 
            +
                    "illustration": 15,
         | 
| 57 | 
            +
                    "sculpture": 19.2,
         | 
| 58 | 
            +
                    "installation art": 20,
         | 
| 59 | 
            +
                    "printmaking art": 16.3,
         | 
| 60 | 
            +
                    "digital art": 15,
         | 
| 61 | 
            +
                    "conceptual art": 18,
         | 
| 62 | 
            +
                    "mosaic art": 19,
         | 
| 63 | 
            +
                    "tapestry": 16,
         | 
| 64 | 
            +
                    "abstract art":16.5,
         | 
| 65 | 
            +
                    "realism art": 16,
         | 
| 66 | 
            +
                    "surrealism art": 15,
         | 
| 67 | 
            +
                    "impressionism art": 17,
         | 
| 68 | 
            +
                    "expressionism art": 17,
         | 
| 69 | 
            +
                    "cubism art": 15,
         | 
| 70 | 
            +
                    "minimalism art": 16,
         | 
| 71 | 
            +
                    "baroque art": 17.5,
         | 
| 72 | 
            +
                    "rococo art": 17,
         | 
| 73 | 
            +
                    "pop art": 16,
         | 
| 74 | 
            +
                    "art nouveau": 19,
         | 
| 75 | 
            +
                    "art deco": 19,
         | 
| 76 | 
            +
                    "futurism art": 16.5,
         | 
| 77 | 
            +
                    "dadaism art": 16.5,
         | 
| 78 | 
            +
                    "stamp": 18,
         | 
| 79 | 
            +
                    "advertisement": 16.5,
         | 
| 80 | 
            +
                    "logo": 15.5,
         | 
| 81 | 
            +
                }
         | 
| 82 | 
            +
                @torch.no_grad()
         | 
| 83 | 
            +
                def __init__(self, positive_prompt=["painting", "art", "artwork", "drawing", "sketch", "illustration",
         | 
| 84 | 
            +
                                                    "sculpture", "installation art", "printmaking art", "digital art", "conceptual art", "mosaic art", "tapestry",
         | 
| 85 | 
            +
                                                    "abstract art", "realism art", "surrealism art", "impressionism art", "expressionism art", "cubism art",
         | 
| 86 | 
            +
                                                    "minimalism art", "baroque art", "rococo art",
         | 
| 87 | 
            +
                                                    "pop art", "art nouveau", "art deco", "futurism art", "dadaism art",
         | 
| 88 | 
            +
                                                    "stamp", "advertisement",
         | 
| 89 | 
            +
                                                    "logo"
         | 
| 90 | 
            +
                                                    ],
         | 
| 91 | 
            +
                              device="cuda"):
         | 
| 92 | 
            +
                    self.device = device
         | 
| 93 | 
            +
                    self.model = (CLIPModel.from_pretrained("openai/clip-vit-large-patch14")).to(device)
         | 
| 94 | 
            +
                    self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
         | 
| 95 | 
            +
                    self.positive_prompt = positive_prompt
         | 
| 96 | 
            +
                    self.text = self.positive_prompt
         | 
| 97 | 
            +
                    self.tokenizer = self.processor.tokenizer
         | 
| 98 | 
            +
                    self.image_processor = self.processor.image_processor
         | 
| 99 | 
            +
                    self.text_encoding = self.tokenizer(self.text, return_tensors="pt", padding=True).to(device)
         | 
| 100 | 
            +
                    self.text_features = self.model.get_text_features(**self.text_encoding)
         | 
| 101 | 
            +
                    self.text_features = self.text_features / self.text_features.norm(p=2, dim=-1, keepdim=True)
         | 
| 102 | 
            +
                @torch.no_grad()
         | 
| 103 | 
            +
                def similarity(self, image):
         | 
| 104 | 
            +
                    # inputs = self.processor(text=self.text, images=image, return_tensors="pt", padding=True)
         | 
| 105 | 
            +
                    image_processed = self.image_processor(image, return_tensors="pt", padding=True).to(self.device, non_blocking=True)
         | 
| 106 | 
            +
                    inputs = {**self.text_encoding, **image_processed}
         | 
| 107 | 
            +
                    outputs = self.model(**inputs)
         | 
| 108 | 
            +
                    logits_per_image = outputs.logits_per_image
         | 
| 109 | 
            +
                    return logits_per_image
         | 
| 110 | 
            +
             | 
| 111 | 
            +
                def get_logits(self, image):
         | 
| 112 | 
            +
                    logits_per_image = self.similarity(image)
         | 
| 113 | 
            +
                    return logits_per_image.cpu()
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                def get_image_features(self, image):
         | 
| 116 | 
            +
                    image_processed = self.image_processor(image, return_tensors="pt", padding=True).to(self.device, non_blocking=True)
         | 
| 117 | 
            +
                    image_features = self.model.get_image_features(**image_processed)
         | 
| 118 | 
            +
                    return image_features
         | 
| 119 | 
            +
             | 
| 120 | 
            +
             | 
| 121 | 
            +
            class Art_filter:
         | 
| 122 | 
            +
                def __init__(self):
         | 
| 123 | 
            +
                    self.caption_filter = Caption_filter()
         | 
| 124 | 
            +
                    self.clip_filter = Clip_filter()
         | 
| 125 | 
            +
                def caption_filt(self, dataloader):
         | 
| 126 | 
            +
                    self.caption_filter.reset()
         | 
| 127 | 
            +
                    dataloader.dataset.get_img = False
         | 
| 128 | 
            +
                    dataloader.dataset.get_cap = True
         | 
| 129 | 
            +
                    remain_ids = []
         | 
| 130 | 
            +
                    filtered_ids = []
         | 
| 131 | 
            +
                    for i, batch in tqdm(enumerate(dataloader)):
         | 
| 132 | 
            +
                        captions = batch["text"]
         | 
| 133 | 
            +
                        filter_result = self.caption_filter.filter(captions)
         | 
| 134 | 
            +
                        for j, (filt, reason) in enumerate(filter_result):
         | 
| 135 | 
            +
                            if filt:
         | 
| 136 | 
            +
                                filtered_ids.append((batch["ids"][j], reason))
         | 
| 137 | 
            +
                                if i%10==0:
         | 
| 138 | 
            +
                                    print(f"Filtered caption: {captions[j]}, reason: {reason}")
         | 
| 139 | 
            +
                            else:
         | 
| 140 | 
            +
                                remain_ids.append(batch["ids"][j])
         | 
| 141 | 
            +
                    return {"remain_ids":remain_ids, "filtered_ids":filtered_ids, "total_count":self.caption_filter.total_count, "filter_count":self.caption_filter.filter_count, "filter_prompts":self.caption_filter.filter_prompts}
         | 
| 142 | 
            +
             | 
| 143 | 
            +
                def clip_filt(self, clip_logits_ckpt:dict):
         | 
| 144 | 
            +
                    logits = clip_logits_ckpt["clip_logits"]
         | 
| 145 | 
            +
                    ids = clip_logits_ckpt["ids"]
         | 
| 146 | 
            +
                    text = clip_logits_ckpt["text"]
         | 
| 147 | 
            +
                    filt_mask = torch.zeros(logits.shape[0], dtype=torch.bool)
         | 
| 148 | 
            +
                    for i, prompt in enumerate(text):
         | 
| 149 | 
            +
                        threshold = Clip_filter.prompt_threshold[prompt]
         | 
| 150 | 
            +
                        filt_mask = filt_mask | (logits[:,i] >= threshold)
         | 
| 151 | 
            +
                    filt_ids = []
         | 
| 152 | 
            +
                    remain_ids = []
         | 
| 153 | 
            +
                    for i, id in enumerate(ids):
         | 
| 154 | 
            +
                        if filt_mask[i]:
         | 
| 155 | 
            +
                            filt_ids.append(id)
         | 
| 156 | 
            +
                        else:
         | 
| 157 | 
            +
                            remain_ids.append(id)
         | 
| 158 | 
            +
                    return {"remain_ids":remain_ids, "filtered_ids":filt_ids}
         | 
| 159 | 
            +
             | 
| 160 | 
            +
                def clip_feature(self, dataloader):
         | 
| 161 | 
            +
                    dataloader.dataset.get_img = True
         | 
| 162 | 
            +
                    dataloader.dataset.get_cap = False
         | 
| 163 | 
            +
                    clip_features = []
         | 
| 164 | 
            +
                    ids = []
         | 
| 165 | 
            +
                    for i, batch in enumerate(dataloader):
         | 
| 166 | 
            +
                        images = batch["images"]
         | 
| 167 | 
            +
                        features = self.clip_filter.get_image_features(images).cpu()
         | 
| 168 | 
            +
                        clip_features.append(features)
         | 
| 169 | 
            +
                        ids.extend(batch["ids"])
         | 
| 170 | 
            +
                    clip_features = torch.cat(clip_features)
         | 
| 171 | 
            +
                    return {"clip_features":clip_features, "ids":ids}
         | 
| 172 | 
            +
             | 
| 173 | 
            +
             | 
| 174 | 
            +
                def clip_logit(self, dataloader):
         | 
| 175 | 
            +
                    dataloader.dataset.get_img = True
         | 
| 176 | 
            +
                    dataloader.dataset.get_cap = False
         | 
| 177 | 
            +
                    clip_features = []
         | 
| 178 | 
            +
                    clip_logits = []
         | 
| 179 | 
            +
                    ids = []
         | 
| 180 | 
            +
                    for i, batch in enumerate(dataloader):
         | 
| 181 | 
            +
                        images = batch["images"]
         | 
| 182 | 
            +
                        # logits = self.clip_filter.get_logits(images)
         | 
| 183 | 
            +
                        feature = self.clip_filter.get_image_features(images)
         | 
| 184 | 
            +
                        logits = self.clip_logit_by_feat(feature)["clip_logits"]
         | 
| 185 | 
            +
             | 
| 186 | 
            +
                        clip_features.append(feature)
         | 
| 187 | 
            +
                        clip_logits.append(logits)
         | 
| 188 | 
            +
                        ids.extend(batch["ids"])
         | 
| 189 | 
            +
             | 
| 190 | 
            +
                    clip_features = torch.cat(clip_features)
         | 
| 191 | 
            +
                    clip_logits = torch.cat(clip_logits)
         | 
| 192 | 
            +
                    return {"clip_features":clip_features, "clip_logits":clip_logits, "ids":ids, "text": self.clip_filter.text}
         | 
| 193 | 
            +
             | 
| 194 | 
            +
                def clip_logit_by_feat(self, feature):
         | 
| 195 | 
            +
                    feature = feature.clone().to(self.clip_filter.device)
         | 
| 196 | 
            +
                    feature = feature / feature.norm(p=2, dim=-1, keepdim=True)
         | 
| 197 | 
            +
                    logit_scale = self.clip_filter.model.logit_scale.exp()
         | 
| 198 | 
            +
                    logits = ((feature @ self.clip_filter.text_features.T) * logit_scale).cpu()
         | 
| 199 | 
            +
                    return {"clip_logits":logits, "text": self.clip_filter.text}
         | 
| 200 | 
            +
             | 
| 201 | 
            +
             | 
| 202 | 
            +
             | 
| 203 | 
            +
            if __name__ == "__main__":
         | 
| 204 | 
            +
                import pickle
         | 
| 205 | 
            +
                with open("/vision-nfs/torralba/scratch/jomat/sam_dataset/filt_result/sa_000000/clip_logits_result.pickle","rb") as f:
         | 
| 206 | 
            +
                    result=pickle.load(f)
         | 
| 207 | 
            +
                feat = result['clip_features']
         | 
| 208 | 
            +
                logits =Art_filter().clip_logit_by_feat(feat)
         | 
| 209 | 
            +
                print(logits)
         | 
| 210 | 
            +
             | 
    	
        utils/config_util.py
    ADDED
    
    | @@ -0,0 +1,105 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from typing import Literal, Optional
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import yaml
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            from pydantic import BaseModel
         | 
| 6 | 
            +
            import torch
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            from lora import TRAINING_METHODS
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            PRECISION_TYPES = Literal["fp32", "fp16", "bf16", "float32", "float16", "bfloat16"]
         | 
| 11 | 
            +
            NETWORK_TYPES = Literal["lierla", "c3lier"]
         | 
| 12 | 
            +
             | 
| 13 | 
            +
             | 
| 14 | 
            +
            class PretrainedModelConfig(BaseModel):
         | 
| 15 | 
            +
                name_or_path: str
         | 
| 16 | 
            +
                ckpt_path: Optional[str] = None
         | 
| 17 | 
            +
                v2: bool = False
         | 
| 18 | 
            +
                v_pred: bool = False
         | 
| 19 | 
            +
             | 
| 20 | 
            +
                clip_skip: Optional[int] = None
         | 
| 21 | 
            +
             | 
| 22 | 
            +
             | 
| 23 | 
            +
            class NetworkConfig(BaseModel):
         | 
| 24 | 
            +
                type: NETWORK_TYPES = "lierla"
         | 
| 25 | 
            +
                rank: int = 4
         | 
| 26 | 
            +
                alpha: float = 1.0
         | 
| 27 | 
            +
             | 
| 28 | 
            +
                training_method: TRAINING_METHODS = "full"
         | 
| 29 | 
            +
             | 
| 30 | 
            +
             | 
| 31 | 
            +
            class TrainConfig(BaseModel):
         | 
| 32 | 
            +
                precision: PRECISION_TYPES = "bfloat16"
         | 
| 33 | 
            +
                noise_scheduler: Literal["ddim", "ddpm", "lms", "euler_a"] = "ddim"
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                iterations: int = 500
         | 
| 36 | 
            +
                lr: float = 1e-4
         | 
| 37 | 
            +
                optimizer: str = "adamw"
         | 
| 38 | 
            +
                optimizer_args: str = ""
         | 
| 39 | 
            +
                lr_scheduler: str = "constant"
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                max_denoising_steps: int = 50
         | 
| 42 | 
            +
             | 
| 43 | 
            +
             | 
| 44 | 
            +
            class SaveConfig(BaseModel):
         | 
| 45 | 
            +
                name: str = "untitled"
         | 
| 46 | 
            +
                path: str = "./output"
         | 
| 47 | 
            +
                per_steps: int = 200
         | 
| 48 | 
            +
                precision: PRECISION_TYPES = "float32"
         | 
| 49 | 
            +
             | 
| 50 | 
            +
             | 
| 51 | 
            +
            class LoggingConfig(BaseModel):
         | 
| 52 | 
            +
                use_wandb: bool = False
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                verbose: bool = False
         | 
| 55 | 
            +
             | 
| 56 | 
            +
             | 
| 57 | 
            +
            class OtherConfig(BaseModel):
         | 
| 58 | 
            +
                use_xformers: bool = False
         | 
| 59 | 
            +
             | 
| 60 | 
            +
             | 
| 61 | 
            +
            class RootConfig(BaseModel):
         | 
| 62 | 
            +
                # prompts_file: str
         | 
| 63 | 
            +
                pretrained_model: PretrainedModelConfig
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                network: NetworkConfig
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                train: Optional[TrainConfig]
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                save: Optional[SaveConfig]
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                logging: Optional[LoggingConfig]
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                other: Optional[OtherConfig]
         | 
| 74 | 
            +
             | 
| 75 | 
            +
             | 
| 76 | 
            +
            def parse_precision(precision: str) -> torch.dtype:
         | 
| 77 | 
            +
                if precision == "fp32" or precision == "float32":
         | 
| 78 | 
            +
                    return torch.float32
         | 
| 79 | 
            +
                elif precision == "fp16" or precision == "float16":
         | 
| 80 | 
            +
                    return torch.float16
         | 
| 81 | 
            +
                elif precision == "bf16" or precision == "bfloat16":
         | 
| 82 | 
            +
                    return torch.bfloat16
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                raise ValueError(f"Invalid precision type: {precision}")
         | 
| 85 | 
            +
             | 
| 86 | 
            +
             | 
| 87 | 
            +
            def load_config_from_yaml(config_path: str) -> RootConfig:
         | 
| 88 | 
            +
                with open(config_path, "r") as f:
         | 
| 89 | 
            +
                    config = yaml.load(f, Loader=yaml.FullLoader)
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                root = RootConfig(**config)
         | 
| 92 | 
            +
             | 
| 93 | 
            +
                if root.train is None:
         | 
| 94 | 
            +
                    root.train = TrainConfig()
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                if root.save is None:
         | 
| 97 | 
            +
                    root.save = SaveConfig()
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                if root.logging is None:
         | 
| 100 | 
            +
                    root.logging = LoggingConfig()
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                if root.other is None:
         | 
| 103 | 
            +
                    root.other = OtherConfig()
         | 
| 104 | 
            +
             | 
| 105 | 
            +
                return root
         | 
    	
        utils/debug_util.py
    ADDED
    
    | @@ -0,0 +1,16 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # デバッグ用...
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
             | 
| 5 | 
            +
             | 
| 6 | 
            +
            def check_requires_grad(model: torch.nn.Module):
         | 
| 7 | 
            +
                for name, module in list(model.named_modules())[:5]:
         | 
| 8 | 
            +
                    if len(list(module.parameters())) > 0:
         | 
| 9 | 
            +
                        print(f"Module: {name}")
         | 
| 10 | 
            +
                        for name, param in list(module.named_parameters())[:2]:
         | 
| 11 | 
            +
                            print(f"    Parameter: {name}, Requires Grad: {param.requires_grad}")
         | 
| 12 | 
            +
             | 
| 13 | 
            +
             | 
| 14 | 
            +
            def check_training_mode(model: torch.nn.Module):
         | 
| 15 | 
            +
                for name, module in list(model.named_modules())[:5]:
         | 
| 16 | 
            +
                    print(f"Module: {name}, Training Mode: {module.training}")
         | 
    	
        utils/lora.py
    ADDED
    
    | @@ -0,0 +1,282 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # ref:
         | 
| 2 | 
            +
            # - https://github.com/cloneofsimo/lora/blob/master/lora_diffusion/lora.py
         | 
| 3 | 
            +
            # - https://github.com/kohya-ss/sd-scripts/blob/main/networks/lora.py
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            import os
         | 
| 6 | 
            +
            import math
         | 
| 7 | 
            +
            from typing import Optional, List, Type, Set, Literal
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            import torch
         | 
| 10 | 
            +
            import torch.nn as nn
         | 
| 11 | 
            +
            from diffusers import UNet2DConditionModel
         | 
| 12 | 
            +
            from safetensors.torch import save_file
         | 
| 13 | 
            +
             | 
| 14 | 
            +
             | 
| 15 | 
            +
            UNET_TARGET_REPLACE_MODULE_TRANSFORMER = [
         | 
| 16 | 
            +
            #     "Transformer2DModel",  # どうやらこっちの方らしい? # attn1, 2
         | 
| 17 | 
            +
                "Attention"
         | 
| 18 | 
            +
            ]
         | 
| 19 | 
            +
            UNET_TARGET_REPLACE_MODULE_CONV = [
         | 
| 20 | 
            +
                "ResnetBlock2D",
         | 
| 21 | 
            +
                "Downsample2D",
         | 
| 22 | 
            +
                "Upsample2D",
         | 
| 23 | 
            +
                #     "DownBlock2D",
         | 
| 24 | 
            +
                #     "UpBlock2D"
         | 
| 25 | 
            +
            ]  # locon, 3clier
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            LORA_PREFIX_UNET = "lora_unet"
         | 
| 28 | 
            +
             | 
| 29 | 
            +
            DEFAULT_TARGET_REPLACE = UNET_TARGET_REPLACE_MODULE_TRANSFORMER
         | 
| 30 | 
            +
             | 
| 31 | 
            +
            TRAINING_METHODS = Literal[
         | 
| 32 | 
            +
                "noxattn",  # train all layers except x-attns and time_embed layers
         | 
| 33 | 
            +
                "innoxattn",  # train all layers except self attention layers
         | 
| 34 | 
            +
                "selfattn",  # ESD-u, train only self attention layers
         | 
| 35 | 
            +
                "xattn",  # ESD-x, train only x attention layers
         | 
| 36 | 
            +
                "full",  #  train all layers
         | 
| 37 | 
            +
                "xattn-strict", # q and k values
         | 
| 38 | 
            +
                "noxattn-hspace",
         | 
| 39 | 
            +
                "noxattn-hspace-last",
         | 
| 40 | 
            +
                # "xlayer",
         | 
| 41 | 
            +
                # "outxattn",
         | 
| 42 | 
            +
                # "outsattn",
         | 
| 43 | 
            +
                # "inxattn",
         | 
| 44 | 
            +
                # "inmidsattn",
         | 
| 45 | 
            +
                # "selflayer",
         | 
| 46 | 
            +
            ]
         | 
| 47 | 
            +
             | 
| 48 | 
            +
             | 
| 49 | 
            +
            class LoRAModule(nn.Module):
         | 
| 50 | 
            +
                """
         | 
| 51 | 
            +
                replaces forward method of the original Linear, instead of replacing the original Linear module.
         | 
| 52 | 
            +
                """
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                def __init__(
         | 
| 55 | 
            +
                    self,
         | 
| 56 | 
            +
                    lora_name,
         | 
| 57 | 
            +
                    org_module: nn.Module,
         | 
| 58 | 
            +
                    multiplier=1.0,
         | 
| 59 | 
            +
                    lora_dim=4,
         | 
| 60 | 
            +
                    alpha=1,
         | 
| 61 | 
            +
                ):
         | 
| 62 | 
            +
                    """if alpha == 0 or None, alpha is rank (no scaling)."""
         | 
| 63 | 
            +
                    super().__init__()
         | 
| 64 | 
            +
                    self.lora_name = lora_name
         | 
| 65 | 
            +
                    self.lora_dim = lora_dim
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                    if "Linear" in org_module.__class__.__name__:
         | 
| 68 | 
            +
                        in_dim = org_module.in_features
         | 
| 69 | 
            +
                        out_dim = org_module.out_features
         | 
| 70 | 
            +
                        self.lora_down = nn.Linear(in_dim, lora_dim, bias=False)
         | 
| 71 | 
            +
                        self.lora_up = nn.Linear(lora_dim, out_dim, bias=False)
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                    elif "Conv" in org_module.__class__.__name__:  # 一応
         | 
| 74 | 
            +
                        in_dim = org_module.in_channels
         | 
| 75 | 
            +
                        out_dim = org_module.out_channels
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                        self.lora_dim = min(self.lora_dim, in_dim, out_dim)
         | 
| 78 | 
            +
                        if self.lora_dim != lora_dim:
         | 
| 79 | 
            +
                            print(f"{lora_name} dim (rank) is changed to: {self.lora_dim}")
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                        kernel_size = org_module.kernel_size
         | 
| 82 | 
            +
                        stride = org_module.stride
         | 
| 83 | 
            +
                        padding = org_module.padding
         | 
| 84 | 
            +
                        self.lora_down = nn.Conv2d(
         | 
| 85 | 
            +
                            in_dim, self.lora_dim, kernel_size, stride, padding, bias=False
         | 
| 86 | 
            +
                        )
         | 
| 87 | 
            +
                        self.lora_up = nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False)
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                    if type(alpha) == torch.Tensor:
         | 
| 90 | 
            +
                        alpha = alpha.detach().numpy()
         | 
| 91 | 
            +
                    alpha = lora_dim if alpha is None or alpha == 0 else alpha
         | 
| 92 | 
            +
                    self.scale = alpha / self.lora_dim
         | 
| 93 | 
            +
                    self.register_buffer("alpha", torch.tensor(alpha))  # 定数として扱える
         | 
| 94 | 
            +
             | 
| 95 | 
            +
                    # same as microsoft's
         | 
| 96 | 
            +
                    nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5))
         | 
| 97 | 
            +
                    nn.init.zeros_(self.lora_up.weight)
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                    self.multiplier = multiplier
         | 
| 100 | 
            +
                    self.org_module = org_module  # remove in applying
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                def apply_to(self):
         | 
| 103 | 
            +
                    self.org_forward = self.org_module.forward
         | 
| 104 | 
            +
                    self.org_module.forward = self.forward
         | 
| 105 | 
            +
                    del self.org_module
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                def forward(self, x):
         | 
| 108 | 
            +
                    return (
         | 
| 109 | 
            +
                        self.org_forward(x)
         | 
| 110 | 
            +
                        + self.lora_up(self.lora_down(x)) * self.multiplier * self.scale
         | 
| 111 | 
            +
                    )
         | 
| 112 | 
            +
             | 
| 113 | 
            +
             | 
| 114 | 
            +
            class LoRANetwork(nn.Module):
         | 
| 115 | 
            +
                def __init__(
         | 
| 116 | 
            +
                    self,
         | 
| 117 | 
            +
                    unet: UNet2DConditionModel,
         | 
| 118 | 
            +
                    rank: int = 4,
         | 
| 119 | 
            +
                    multiplier: float = 1.0,
         | 
| 120 | 
            +
                    alpha: float = 1.0,
         | 
| 121 | 
            +
                    train_method: TRAINING_METHODS = "full",
         | 
| 122 | 
            +
                ) -> None:
         | 
| 123 | 
            +
                    super().__init__()
         | 
| 124 | 
            +
                    self.lora_scale = 1
         | 
| 125 | 
            +
                    self.multiplier = multiplier
         | 
| 126 | 
            +
                    self.lora_dim = rank
         | 
| 127 | 
            +
                    self.alpha = alpha
         | 
| 128 | 
            +
             | 
| 129 | 
            +
             | 
| 130 | 
            +
                    self.module = LoRAModule
         | 
| 131 | 
            +
             | 
| 132 | 
            +
             | 
| 133 | 
            +
                    self.unet_loras = self.create_modules(
         | 
| 134 | 
            +
                        LORA_PREFIX_UNET,
         | 
| 135 | 
            +
                        unet,
         | 
| 136 | 
            +
                        DEFAULT_TARGET_REPLACE,
         | 
| 137 | 
            +
                        self.lora_dim,
         | 
| 138 | 
            +
                        self.multiplier,
         | 
| 139 | 
            +
                        train_method=train_method,
         | 
| 140 | 
            +
                    )
         | 
| 141 | 
            +
                    print(f"create LoRA for U-Net: {len(self.unet_loras)} modules.")
         | 
| 142 | 
            +
             | 
| 143 | 
            +
             | 
| 144 | 
            +
                    lora_names = set()
         | 
| 145 | 
            +
                    for lora in self.unet_loras:
         | 
| 146 | 
            +
                        assert (
         | 
| 147 | 
            +
                            lora.lora_name not in lora_names
         | 
| 148 | 
            +
                        ), f"duplicated lora name: {lora.lora_name}. {lora_names}"
         | 
| 149 | 
            +
                        lora_names.add(lora.lora_name)
         | 
| 150 | 
            +
             | 
| 151 | 
            +
             | 
| 152 | 
            +
                    for lora in self.unet_loras:
         | 
| 153 | 
            +
                        lora.apply_to()
         | 
| 154 | 
            +
                        self.add_module(
         | 
| 155 | 
            +
                            lora.lora_name,
         | 
| 156 | 
            +
                            lora,
         | 
| 157 | 
            +
                        )
         | 
| 158 | 
            +
             | 
| 159 | 
            +
                    del unet
         | 
| 160 | 
            +
             | 
| 161 | 
            +
                    torch.cuda.empty_cache()
         | 
| 162 | 
            +
             | 
| 163 | 
            +
                def create_modules(
         | 
| 164 | 
            +
                    self,
         | 
| 165 | 
            +
                    prefix: str,
         | 
| 166 | 
            +
                    root_module: nn.Module,
         | 
| 167 | 
            +
                    target_replace_modules: List[str],
         | 
| 168 | 
            +
                    rank: int,
         | 
| 169 | 
            +
                    multiplier: float,
         | 
| 170 | 
            +
                    train_method: TRAINING_METHODS,
         | 
| 171 | 
            +
                ) -> list:
         | 
| 172 | 
            +
                    loras = []
         | 
| 173 | 
            +
                    names = []
         | 
| 174 | 
            +
                    for name, module in root_module.named_modules():
         | 
| 175 | 
            +
                        if train_method == "noxattn" or train_method == "noxattn-hspace" or train_method == "noxattn-hspace-last":  # Cross Attention と Time Embed 以外学習
         | 
| 176 | 
            +
                            if "attn2" in name or "time_embed" in name:
         | 
| 177 | 
            +
                                continue
         | 
| 178 | 
            +
                        elif train_method == "innoxattn":  # Cross Attention 以外学習
         | 
| 179 | 
            +
                            if "attn2" in name:
         | 
| 180 | 
            +
                                continue
         | 
| 181 | 
            +
                        elif train_method == "selfattn":  # Self Attention のみ学習
         | 
| 182 | 
            +
                            if "attn1" not in name:
         | 
| 183 | 
            +
                                continue
         | 
| 184 | 
            +
                        elif train_method == "xattn" or train_method == "xattn-strict":  # Cross Attention のみ学習
         | 
| 185 | 
            +
                            if "attn2" not in name:
         | 
| 186 | 
            +
                                continue
         | 
| 187 | 
            +
                        elif train_method == "attn":
         | 
| 188 | 
            +
                            if "attn1" not in name and "attn2" not in name:
         | 
| 189 | 
            +
                                continue
         | 
| 190 | 
            +
                        elif train_method == "full":
         | 
| 191 | 
            +
                            pass
         | 
| 192 | 
            +
                        # else:
         | 
| 193 | 
            +
                        #     raise NotImplementedError(
         | 
| 194 | 
            +
                        #         f"train_method: {train_method} is not implemented."
         | 
| 195 | 
            +
                        #     )
         | 
| 196 | 
            +
                        ##
         | 
| 197 | 
            +
                        # union condition(b-lora)
         | 
| 198 | 
            +
                        else:
         | 
| 199 | 
            +
                            discard = True
         | 
| 200 | 
            +
                            if "all_up" in train_method:
         | 
| 201 | 
            +
                                if "up_blocks" in name:
         | 
| 202 | 
            +
                                    discard = False
         | 
| 203 | 
            +
                            if "down_1" in train_method:
         | 
| 204 | 
            +
                                if not ("down_blocks.1" not in name or "attentions" not in name):
         | 
| 205 | 
            +
                                    discard = False
         | 
| 206 | 
            +
                            if "down_2" in train_method:
         | 
| 207 | 
            +
                                if not ("down_blocks.2" not in name or "attentions" not in name):
         | 
| 208 | 
            +
                                    discard = False
         | 
| 209 | 
            +
                            if "up_1" in train_method:
         | 
| 210 | 
            +
                                if not ("up_blocks.1" not in name or "attentions" not in name):
         | 
| 211 | 
            +
                                    discard = False
         | 
| 212 | 
            +
                            if "up_2" in train_method:
         | 
| 213 | 
            +
                                if not ("up_blocks.2" not in name or "attentions" not in name):
         | 
| 214 | 
            +
                                    discard = False
         | 
| 215 | 
            +
                            if discard:
         | 
| 216 | 
            +
                                continue
         | 
| 217 | 
            +
             | 
| 218 | 
            +
                        ##
         | 
| 219 | 
            +
                        if module.__class__.__name__ in target_replace_modules:
         | 
| 220 | 
            +
                            for child_name, child_module in module.named_modules():
         | 
| 221 | 
            +
                                if child_module.__class__.__name__ in ["Linear", "Conv2d", "LoRACompatibleLinear", "LoRACompatibleConv"]:
         | 
| 222 | 
            +
                                    if train_method == 'xattn-strict':
         | 
| 223 | 
            +
                                        if 'out' in child_name:
         | 
| 224 | 
            +
                                            continue
         | 
| 225 | 
            +
                                    if train_method == 'noxattn-hspace':
         | 
| 226 | 
            +
                                        if 'mid_block' not in name:
         | 
| 227 | 
            +
                                            continue
         | 
| 228 | 
            +
                                    if train_method == 'noxattn-hspace-last':
         | 
| 229 | 
            +
                                        if 'mid_block' not in name or '.1' not in name or 'conv2' not in child_name:
         | 
| 230 | 
            +
                                            continue
         | 
| 231 | 
            +
                                    lora_name = prefix + "." + name + "." + child_name
         | 
| 232 | 
            +
                                    lora_name = lora_name.replace(".", "_")
         | 
| 233 | 
            +
                                    # print(f"{lora_name}")
         | 
| 234 | 
            +
                                    lora = self.module(
         | 
| 235 | 
            +
                                        lora_name, child_module, multiplier, rank, self.alpha
         | 
| 236 | 
            +
                                    )
         | 
| 237 | 
            +
            #                         print(name, child_name)
         | 
| 238 | 
            +
            #                         print(child_module.weight.shape)
         | 
| 239 | 
            +
                                    loras.append(lora)
         | 
| 240 | 
            +
                                    names.append(lora_name)
         | 
| 241 | 
            +
            #         print(f'@@@@@@@@@@@@@@@@@@@@@@@@@@@@ \n {names}')
         | 
| 242 | 
            +
                    return loras
         | 
| 243 | 
            +
             | 
| 244 | 
            +
                def prepare_optimizer_params(self):
         | 
| 245 | 
            +
                    all_params = []
         | 
| 246 | 
            +
             | 
| 247 | 
            +
                    if self.unet_loras:  # 実質これしかない
         | 
| 248 | 
            +
                        params = []
         | 
| 249 | 
            +
                        [params.extend(lora.parameters()) for lora in self.unet_loras]
         | 
| 250 | 
            +
                        param_data = {"params": params}
         | 
| 251 | 
            +
                        all_params.append(param_data)
         | 
| 252 | 
            +
             | 
| 253 | 
            +
                    return all_params
         | 
| 254 | 
            +
             | 
| 255 | 
            +
                def save_weights(self, file, dtype=None, metadata: Optional[dict] = None):
         | 
| 256 | 
            +
                    state_dict = self.state_dict()
         | 
| 257 | 
            +
             | 
| 258 | 
            +
                    if dtype is not None:
         | 
| 259 | 
            +
                        for key in list(state_dict.keys()):
         | 
| 260 | 
            +
                            v = state_dict[key]
         | 
| 261 | 
            +
                            v = v.detach().clone().to("cpu").to(dtype)
         | 
| 262 | 
            +
                            state_dict[key] = v
         | 
| 263 | 
            +
             | 
| 264 | 
            +
            #         for key in list(state_dict.keys()):
         | 
| 265 | 
            +
            #             if not key.startswith("lora"):
         | 
| 266 | 
            +
            #                 # lora以外除外
         | 
| 267 | 
            +
            #                 del state_dict[key]
         | 
| 268 | 
            +
             | 
| 269 | 
            +
                    if os.path.splitext(file)[1] == ".safetensors":
         | 
| 270 | 
            +
                        save_file(state_dict, file, metadata)
         | 
| 271 | 
            +
                    else:
         | 
| 272 | 
            +
                        torch.save(state_dict, file)
         | 
| 273 | 
            +
                def set_lora_slider(self, scale):
         | 
| 274 | 
            +
                    self.lora_scale = scale
         | 
| 275 | 
            +
             | 
| 276 | 
            +
                def __enter__(self):
         | 
| 277 | 
            +
                    for lora in self.unet_loras:
         | 
| 278 | 
            +
                        lora.multiplier = 1.0 * self.lora_scale
         | 
| 279 | 
            +
             | 
| 280 | 
            +
                def __exit__(self, exc_type, exc_value, tb):
         | 
| 281 | 
            +
                    for lora in self.unet_loras:
         | 
| 282 | 
            +
                        lora.multiplier = 0
         | 
    	
        utils/metrics.py
    ADDED
    
    | @@ -0,0 +1,577 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Authors: Hui Ren (rhfeiyang.github.io)
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import os
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            import numpy as np
         | 
| 6 | 
            +
            from torchvision import transforms
         | 
| 7 | 
            +
            import torch
         | 
| 8 | 
            +
            import torch.nn.functional as F
         | 
| 9 | 
            +
            import torch.nn as nn
         | 
| 10 | 
            +
            from torch.autograd import Function
         | 
| 11 | 
            +
            from PIL import Image
         | 
| 12 | 
            +
            from transformers import CLIPProcessor, CLIPModel
         | 
| 13 | 
            +
            from collections import OrderedDict
         | 
| 14 | 
            +
            from transformers import BatchFeature
         | 
| 15 | 
            +
            import clip
         | 
| 16 | 
            +
            import copy
         | 
| 17 | 
            +
            import lpips
         | 
| 18 | 
            +
            from transformers import ViTImageProcessor, ViTModel
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            ## CSD_CLIP
         | 
| 21 | 
            +
            def convert_weights_float(model: nn.Module):
         | 
| 22 | 
            +
                """Convert applicable model parameters to fp32"""
         | 
| 23 | 
            +
             | 
| 24 | 
            +
                def _convert_weights_to_fp32(l):
         | 
| 25 | 
            +
                    if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
         | 
| 26 | 
            +
                        l.weight.data = l.weight.data.float()
         | 
| 27 | 
            +
                        if l.bias is not None:
         | 
| 28 | 
            +
                            l.bias.data = l.bias.data.float()
         | 
| 29 | 
            +
             | 
| 30 | 
            +
                    if isinstance(l, nn.MultiheadAttention):
         | 
| 31 | 
            +
                        for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
         | 
| 32 | 
            +
                            tensor = getattr(l, attr)
         | 
| 33 | 
            +
                            if tensor is not None:
         | 
| 34 | 
            +
                                tensor.data = tensor.data.float()
         | 
| 35 | 
            +
             | 
| 36 | 
            +
                    for name in ["text_projection", "proj"]:
         | 
| 37 | 
            +
                        if hasattr(l, name):
         | 
| 38 | 
            +
                            attr = getattr(l, name)
         | 
| 39 | 
            +
                            if attr is not None:
         | 
| 40 | 
            +
                                attr.data = attr.data.float()
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                model.apply(_convert_weights_to_fp32)
         | 
| 43 | 
            +
             | 
| 44 | 
            +
            class ReverseLayerF(Function):
         | 
| 45 | 
            +
             | 
| 46 | 
            +
                @staticmethod
         | 
| 47 | 
            +
                def forward(ctx, x, alpha):
         | 
| 48 | 
            +
                    ctx.alpha = alpha
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                    return x.view_as(x)
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                @staticmethod
         | 
| 53 | 
            +
                def backward(ctx, grad_output):
         | 
| 54 | 
            +
                    output = grad_output.neg() * ctx.alpha
         | 
| 55 | 
            +
             | 
| 56 | 
            +
                    return output, None
         | 
| 57 | 
            +
             | 
| 58 | 
            +
             | 
| 59 | 
            +
            ## taken from https://github.com/moein-shariatnia/OpenAI-CLIP/blob/master/modules.py
         | 
| 60 | 
            +
            class ProjectionHead(nn.Module):
         | 
| 61 | 
            +
                def __init__(
         | 
| 62 | 
            +
                        self,
         | 
| 63 | 
            +
                        embedding_dim,
         | 
| 64 | 
            +
                        projection_dim,
         | 
| 65 | 
            +
                        dropout=0
         | 
| 66 | 
            +
                ):
         | 
| 67 | 
            +
                    super().__init__()
         | 
| 68 | 
            +
                    self.projection = nn.Linear(embedding_dim, projection_dim)
         | 
| 69 | 
            +
                    self.gelu = nn.GELU()
         | 
| 70 | 
            +
                    self.fc = nn.Linear(projection_dim, projection_dim)
         | 
| 71 | 
            +
                    self.dropout = nn.Dropout(dropout)
         | 
| 72 | 
            +
                    self.layer_norm = nn.LayerNorm(projection_dim)
         | 
| 73 | 
            +
             | 
| 74 | 
            +
                def forward(self, x):
         | 
| 75 | 
            +
                    projected = self.projection(x)
         | 
| 76 | 
            +
                    x = self.gelu(projected)
         | 
| 77 | 
            +
                    x = self.fc(x)
         | 
| 78 | 
            +
                    x = self.dropout(x)
         | 
| 79 | 
            +
                    x = x + projected
         | 
| 80 | 
            +
                    x = self.layer_norm(x)
         | 
| 81 | 
            +
                    return x
         | 
| 82 | 
            +
             | 
| 83 | 
            +
            def convert_state_dict(state_dict):
         | 
| 84 | 
            +
                new_state_dict = OrderedDict()
         | 
| 85 | 
            +
                for k, v in state_dict.items():
         | 
| 86 | 
            +
                    if k.startswith("module."):
         | 
| 87 | 
            +
                        k = k.replace("module.", "")
         | 
| 88 | 
            +
                    new_state_dict[k] = v
         | 
| 89 | 
            +
                return new_state_dict
         | 
| 90 | 
            +
            def init_weights(m):
         | 
| 91 | 
            +
                if isinstance(m, nn.Linear):
         | 
| 92 | 
            +
                    torch.nn.init.xavier_uniform_(m.weight)
         | 
| 93 | 
            +
                    if m.bias is not None:
         | 
| 94 | 
            +
                        nn.init.normal_(m.bias, std=1e-6)
         | 
| 95 | 
            +
             | 
| 96 | 
            +
            class Metric(nn.Module):
         | 
| 97 | 
            +
                def __init__(self):
         | 
| 98 | 
            +
                    super().__init__()
         | 
| 99 | 
            +
                    self.image_preprocess = None
         | 
| 100 | 
            +
             | 
| 101 | 
            +
                def load_image(self, image_path):
         | 
| 102 | 
            +
                    with open(image_path, 'rb') as f:
         | 
| 103 | 
            +
                        image = Image.open(f).convert("RGB")
         | 
| 104 | 
            +
                    return image
         | 
| 105 | 
            +
             | 
| 106 | 
            +
                def load_image_path(self, image_path):
         | 
| 107 | 
            +
                    if isinstance(image_path, str):
         | 
| 108 | 
            +
                        # should be a image folder path
         | 
| 109 | 
            +
                        images_file = os.listdir(image_path)
         | 
| 110 | 
            +
                        images = [os.path.join(image_path, image) for image in images_file if
         | 
| 111 | 
            +
                                  image.endswith(".jpg") or image.endswith(".png")]
         | 
| 112 | 
            +
                    if isinstance(image_path[0], str):
         | 
| 113 | 
            +
                        images = [self.load_image(image) for image in image_path]
         | 
| 114 | 
            +
                    elif isinstance(image_path[0], np.ndarray):
         | 
| 115 | 
            +
                        images = [Image.fromarray(image) for image in image_path]
         | 
| 116 | 
            +
                    elif isinstance(image_path[0], Image.Image):
         | 
| 117 | 
            +
                        images = image_path
         | 
| 118 | 
            +
                    else:
         | 
| 119 | 
            +
                        raise Exception("Invalid input")
         | 
| 120 | 
            +
                    return images
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                def preprocess_image(self, image, **kwargs):
         | 
| 123 | 
            +
                    if (isinstance(image, str) and os.path.isdir(image)) or (isinstance(image, list) and (isinstance(image[0], Image.Image) or isinstance(image[0], np.ndarray) or os.path.isfile(image[0]))):
         | 
| 124 | 
            +
                        input_data = self.load_image_path(image)
         | 
| 125 | 
            +
                        input_data = [self.image_preprocess(image, **kwargs) for image in input_data]
         | 
| 126 | 
            +
                        input_data = torch.stack(input_data)
         | 
| 127 | 
            +
                    elif os.path.isfile(image):
         | 
| 128 | 
            +
                        input_data = self.load_image(image)
         | 
| 129 | 
            +
                        input_data = self.image_preprocess(input_data, **kwargs)
         | 
| 130 | 
            +
                        input_data = input_data.unsqueeze(0)
         | 
| 131 | 
            +
                    elif isinstance(image, torch.Tensor):
         | 
| 132 | 
            +
                        raise Exception("Unsupported input")
         | 
| 133 | 
            +
                    return input_data
         | 
| 134 | 
            +
             | 
| 135 | 
            +
            class Clip_Basic_Metric(Metric):
         | 
| 136 | 
            +
                def __init__(self):
         | 
| 137 | 
            +
                    super().__init__()
         | 
| 138 | 
            +
                    self.tensor_preprocess = transforms.Compose([
         | 
| 139 | 
            +
                        transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
         | 
| 140 | 
            +
                        # transforms.rescale
         | 
| 141 | 
            +
                        transforms.Normalize(mean=[-1.0, -1.0, -1.0], std=[2.0, 2.0, 2.0]),
         | 
| 142 | 
            +
                        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
         | 
| 143 | 
            +
                    ])
         | 
| 144 | 
            +
                    self.image_preprocess = transforms.Compose([
         | 
| 145 | 
            +
                        transforms.Resize(size=224, interpolation=transforms.InterpolationMode.BICUBIC),
         | 
| 146 | 
            +
                        transforms.CenterCrop(224),
         | 
| 147 | 
            +
                        transforms.ToTensor(),
         | 
| 148 | 
            +
                        transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
         | 
| 149 | 
            +
                    ])
         | 
| 150 | 
            +
             | 
| 151 | 
            +
            class Clip_metric(Clip_Basic_Metric):
         | 
| 152 | 
            +
             | 
| 153 | 
            +
                @torch.no_grad()
         | 
| 154 | 
            +
                def __init__(self, target_style_prompt: str=None, clip_model_name="openai/clip-vit-large-patch14", device="cuda",
         | 
| 155 | 
            +
                             bath_size=8, alpha=0.5):
         | 
| 156 | 
            +
                    super().__init__()
         | 
| 157 | 
            +
                    self.device = device
         | 
| 158 | 
            +
                    self.alpha = alpha
         | 
| 159 | 
            +
                    self.model = (CLIPModel.from_pretrained(clip_model_name)).to(device)
         | 
| 160 | 
            +
                    self.processor = CLIPProcessor.from_pretrained(clip_model_name)
         | 
| 161 | 
            +
                    self.tokenizer = self.processor.tokenizer
         | 
| 162 | 
            +
                    self.image_processor = self.processor.image_processor
         | 
| 163 | 
            +
                    # self.style_class_features = self.get_text_features(self.styles).cpu()
         | 
| 164 | 
            +
                    self.style_class_features=[]
         | 
| 165 | 
            +
                    # self.noise_prompt_features = self.get_text_features("Noise")
         | 
| 166 | 
            +
                    self.model.eval()
         | 
| 167 | 
            +
                    self.batch_size = bath_size
         | 
| 168 | 
            +
                    if target_style_prompt is not None:
         | 
| 169 | 
            +
                        self.ref_style_features = self.get_text_features(target_style_prompt)
         | 
| 170 | 
            +
                    else:
         | 
| 171 | 
            +
                        self.ref_style_features = None
         | 
| 172 | 
            +
             | 
| 173 | 
            +
                    self.ref_image_style_prototype = None
         | 
| 174 | 
            +
             | 
| 175 | 
            +
                def get_text_features(self, text):
         | 
| 176 | 
            +
                    prompt_encoding = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(self.device)
         | 
| 177 | 
            +
                    prompt_features = self.model.get_text_features(**prompt_encoding).to(self.device)
         | 
| 178 | 
            +
                    prompt_features = F.normalize(prompt_features, p=2, dim=-1)
         | 
| 179 | 
            +
                    return prompt_features
         | 
| 180 | 
            +
             | 
| 181 | 
            +
                def get_image_features(self, images):
         | 
| 182 | 
            +
                    # if isinstance(image, torch.Tensor):
         | 
| 183 | 
            +
                    #     self.tensor_transform(image)
         | 
| 184 | 
            +
                    # else:
         | 
| 185 | 
            +
                    #     image_features = self.image_processor(image, return_tensors="pt", padding=True).to(self.device, non_blocking=True)
         | 
| 186 | 
            +
                    images = self.load_image_path(images)
         | 
| 187 | 
            +
                    if isinstance(images, torch.Tensor):
         | 
| 188 | 
            +
                        images = self.tensor_preprocess(images)
         | 
| 189 | 
            +
                        data = {"pixel_values": images}
         | 
| 190 | 
            +
                        image_features = BatchFeature(data=data, tensor_type="pt")
         | 
| 191 | 
            +
                    else:
         | 
| 192 | 
            +
                        image_features = self.image_processor(images, return_tensors="pt", padding=True).to(self.device,
         | 
| 193 | 
            +
                                                                                                            non_blocking=True)
         | 
| 194 | 
            +
             | 
| 195 | 
            +
                    image_features = self.model.get_image_features(**image_features).to(self.device)
         | 
| 196 | 
            +
                    image_features = F.normalize(image_features, p=2, dim=-1)
         | 
| 197 | 
            +
                    return image_features
         | 
| 198 | 
            +
             | 
| 199 | 
            +
                def img_text_similarity(self, image_features, text=None):
         | 
| 200 | 
            +
                    if text is not None:
         | 
| 201 | 
            +
                        prompt_feature = self.get_text_features(text)
         | 
| 202 | 
            +
                        if isinstance(text, str):
         | 
| 203 | 
            +
                            prompt_feature = prompt_feature.repeat(len(image_features), 1)
         | 
| 204 | 
            +
                    else:
         | 
| 205 | 
            +
                        prompt_feature = self.ref_style_features
         | 
| 206 | 
            +
             | 
| 207 | 
            +
                    similarity_each = torch.einsum("nc, nc -> n", image_features, prompt_feature)
         | 
| 208 | 
            +
                    return similarity_each
         | 
| 209 | 
            +
             | 
| 210 | 
            +
                def forward(self, output_imgs, prompt=None):
         | 
| 211 | 
            +
                    image_features = self.get_image_features(output_imgs)
         | 
| 212 | 
            +
                    # print(image_features)
         | 
| 213 | 
            +
                    style_score = self.img_text_similarity(image_features.mean(dim=0, keepdim=True))
         | 
| 214 | 
            +
                    if prompt is not None:
         | 
| 215 | 
            +
                        content_score = self.img_text_similarity(image_features, prompt)
         | 
| 216 | 
            +
             | 
| 217 | 
            +
                        score = self.alpha * style_score + (1 - self.alpha) * content_score
         | 
| 218 | 
            +
                        return {"score": score, "style_score": style_score, "content_score": content_score}
         | 
| 219 | 
            +
                    else:
         | 
| 220 | 
            +
                        return {"style_score": style_score}
         | 
| 221 | 
            +
             | 
| 222 | 
            +
                def content_score(self, output_imgs, prompt):
         | 
| 223 | 
            +
                    self.to(self.device)
         | 
| 224 | 
            +
                    image_features = self.get_image_features(output_imgs)
         | 
| 225 | 
            +
                    content_score_details = self.img_text_similarity(image_features, prompt)
         | 
| 226 | 
            +
                    self.to("cpu")
         | 
| 227 | 
            +
                    return {"CLIP_content_score": content_score_details.mean().cpu(), "CLIP_content_score_details": content_score_details.cpu()}
         | 
| 228 | 
            +
             | 
| 229 | 
            +
             | 
| 230 | 
            +
            class CSD_CLIP(Clip_Basic_Metric):
         | 
| 231 | 
            +
                """backbone + projection head"""
         | 
| 232 | 
            +
                def __init__(self, name='vit_large',content_proj_head='default', ckpt_path = "data/weights/CSD-checkpoint.pth", device="cuda",
         | 
| 233 | 
            +
                             alpha=0.5, **kwargs):
         | 
| 234 | 
            +
                    super(CSD_CLIP, self).__init__()
         | 
| 235 | 
            +
                    self.alpha = alpha
         | 
| 236 | 
            +
                    self.content_proj_head = content_proj_head
         | 
| 237 | 
            +
                    self.device = device
         | 
| 238 | 
            +
                    if name == 'vit_large':
         | 
| 239 | 
            +
                        clipmodel, _ = clip.load("ViT-L/14")
         | 
| 240 | 
            +
                        self.backbone = clipmodel.visual
         | 
| 241 | 
            +
                        self.embedding_dim = 1024
         | 
| 242 | 
            +
                    elif name == 'vit_base':
         | 
| 243 | 
            +
                        clipmodel, _ = clip.load("ViT-B/16")
         | 
| 244 | 
            +
                        self.backbone = clipmodel.visual
         | 
| 245 | 
            +
                        self.embedding_dim = 768
         | 
| 246 | 
            +
                        self.feat_dim = 512
         | 
| 247 | 
            +
                    else:
         | 
| 248 | 
            +
                        raise Exception('This model is not implemented')
         | 
| 249 | 
            +
             | 
| 250 | 
            +
                    convert_weights_float(self.backbone)
         | 
| 251 | 
            +
                    self.last_layer_style = copy.deepcopy(self.backbone.proj)
         | 
| 252 | 
            +
                    if content_proj_head == 'custom':
         | 
| 253 | 
            +
                        self.last_layer_content = ProjectionHead(self.embedding_dim,self.feat_dim)
         | 
| 254 | 
            +
                        self.last_layer_content.apply(init_weights)
         | 
| 255 | 
            +
             | 
| 256 | 
            +
                    else:
         | 
| 257 | 
            +
                        self.last_layer_content = copy.deepcopy(self.backbone.proj)
         | 
| 258 | 
            +
             | 
| 259 | 
            +
                    self.backbone.proj = None
         | 
| 260 | 
            +
                    self.backbone.requires_grad_(False)
         | 
| 261 | 
            +
                    self.last_layer_style.requires_grad_(False)
         | 
| 262 | 
            +
                    self.last_layer_content.requires_grad_(False)
         | 
| 263 | 
            +
                    self.backbone.eval()
         | 
| 264 | 
            +
             | 
| 265 | 
            +
                    if ckpt_path is not None:
         | 
| 266 | 
            +
                        self.load_ckpt(ckpt_path)
         | 
| 267 | 
            +
                    self.to("cpu")
         | 
| 268 | 
            +
             | 
| 269 | 
            +
                def load_ckpt(self, ckpt_path):
         | 
| 270 | 
            +
                    checkpoint = torch.load(ckpt_path, map_location="cpu")
         | 
| 271 | 
            +
                    state_dict = convert_state_dict(checkpoint['model_state_dict'])
         | 
| 272 | 
            +
                    msg = self.load_state_dict(state_dict, strict=False)
         | 
| 273 | 
            +
                    print(f"=> loaded CSD_CLIP checkpoint with msg {msg}")
         | 
| 274 | 
            +
             | 
| 275 | 
            +
                @property
         | 
| 276 | 
            +
                def dtype(self):
         | 
| 277 | 
            +
                    return self.backbone.conv1.weight.dtype
         | 
| 278 | 
            +
             | 
| 279 | 
            +
                def get_image_features(self, input_data, get_style=True,get_content=False,feature_alpha=None):
         | 
| 280 | 
            +
                    if isinstance(input_data, torch.Tensor):
         | 
| 281 | 
            +
                        input_data = self.tensor_preprocess(input_data)
         | 
| 282 | 
            +
                    elif (isinstance(input_data, str) and os.path.isdir(input_data)) or (isinstance(input_data, list) and (isinstance(input_data[0], Image.Image) or isinstance(input_data[0], np.ndarray) or os.path.isfile(input_data[0]))):
         | 
| 283 | 
            +
                        input_data = self.load_image_path(input_data)
         | 
| 284 | 
            +
                        input_data = [self.image_preprocess(image) for image in input_data]
         | 
| 285 | 
            +
                        input_data = torch.stack(input_data)
         | 
| 286 | 
            +
                    elif os.path.isfile(input_data):
         | 
| 287 | 
            +
                        input_data = self.load_image(input_data)
         | 
| 288 | 
            +
                        input_data = self.image_preprocess(input_data)
         | 
| 289 | 
            +
                        input_data = input_data.unsqueeze(0)
         | 
| 290 | 
            +
                    input_data = input_data.to(self.device)
         | 
| 291 | 
            +
                    style_output = None
         | 
| 292 | 
            +
             | 
| 293 | 
            +
                    feature = self.backbone(input_data)
         | 
| 294 | 
            +
                    if get_style:
         | 
| 295 | 
            +
                        style_output = feature @ self.last_layer_style
         | 
| 296 | 
            +
                        # style_output = style_output.mean(dim=0)
         | 
| 297 | 
            +
                        style_output = nn.functional.normalize(style_output, dim=-1, p=2)
         | 
| 298 | 
            +
             | 
| 299 | 
            +
                    content_output=None
         | 
| 300 | 
            +
                    if get_content:
         | 
| 301 | 
            +
                        if feature_alpha is not None:
         | 
| 302 | 
            +
                            reverse_feature = ReverseLayerF.apply(feature, feature_alpha)
         | 
| 303 | 
            +
                        else:
         | 
| 304 | 
            +
                            reverse_feature = feature
         | 
| 305 | 
            +
                        # if alpha is not None:
         | 
| 306 | 
            +
                        if self.content_proj_head == 'custom':
         | 
| 307 | 
            +
                            content_output =  self.last_layer_content(reverse_feature)
         | 
| 308 | 
            +
                        else:
         | 
| 309 | 
            +
                            content_output = reverse_feature @ self.last_layer_content
         | 
| 310 | 
            +
                        content_output = nn.functional.normalize(content_output, dim=-1, p=2)
         | 
| 311 | 
            +
             | 
| 312 | 
            +
                    return feature, content_output, style_output
         | 
| 313 | 
            +
             | 
| 314 | 
            +
             | 
| 315 | 
            +
                @torch.no_grad()
         | 
| 316 | 
            +
                def define_ref_image_style_prototype(self, ref_image_path: str):
         | 
| 317 | 
            +
                    self.to(self.device)
         | 
| 318 | 
            +
                    _, _, self.ref_style_feature = self.get_image_features(ref_image_path)
         | 
| 319 | 
            +
                    self.to("cpu")
         | 
| 320 | 
            +
                    # self.ref_style_feature = self.ref_style_feature.mean(dim=0)
         | 
| 321 | 
            +
                @torch.no_grad()
         | 
| 322 | 
            +
                def forward(self, styled_data):
         | 
| 323 | 
            +
                    self.to(self.device)
         | 
| 324 | 
            +
                    # get_content_feature = original_data is not None
         | 
| 325 | 
            +
                    _, content_output, style_output = self.get_image_features(styled_data, get_content=False)
         | 
| 326 | 
            +
                    style_similarities = style_output @ self.ref_style_feature.T
         | 
| 327 | 
            +
                    mean_style_similarities = style_similarities.mean(dim=-1)
         | 
| 328 | 
            +
                    mean_style_similarity = mean_style_similarities.mean()
         | 
| 329 | 
            +
             | 
| 330 | 
            +
                    max_style_similarities_v, max_style_similarities_id = style_similarities.max(dim=-1)
         | 
| 331 | 
            +
                    max_style_similarity = max_style_similarities_v.mean()
         | 
| 332 | 
            +
             | 
| 333 | 
            +
             | 
| 334 | 
            +
                    self.to("cpu")
         | 
| 335 | 
            +
                    return {"CSD_similarity_mean": mean_style_similarity, "CSD_similarity_max": max_style_similarity, "CSD_similarity_mean_details": mean_style_similarities,
         | 
| 336 | 
            +
                            "CSD_similarity_max_v_details": max_style_similarities_v, "CSD_similarity_max_id_details": max_style_similarities_id}
         | 
| 337 | 
            +
             | 
| 338 | 
            +
                def get_style_loss(self, styled_data):
         | 
| 339 | 
            +
                    _, _, style_output = self.get_image_features(styled_data, get_style=True, get_content=False)
         | 
| 340 | 
            +
                    style_similarity = (style_output @ self.ref_style_feature).mean()
         | 
| 341 | 
            +
                    loss = 1 - style_similarity
         | 
| 342 | 
            +
                    return loss.mean()
         | 
| 343 | 
            +
             | 
| 344 | 
            +
            class LPIPS_metric(Metric):
         | 
| 345 | 
            +
                def __init__(self, type="vgg", device="cuda"):
         | 
| 346 | 
            +
                    super(LPIPS_metric, self).__init__()
         | 
| 347 | 
            +
                    self.lpips = lpips.LPIPS(net=type)
         | 
| 348 | 
            +
                    self.device = device
         | 
| 349 | 
            +
                    self.image_preprocess = transforms.Compose([
         | 
| 350 | 
            +
                        transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
         | 
| 351 | 
            +
                        transforms.CenterCrop(256),
         | 
| 352 | 
            +
                        transforms.ToTensor(),
         | 
| 353 | 
            +
                        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
         | 
| 354 | 
            +
                    ])
         | 
| 355 | 
            +
                    self.to("cpu")
         | 
| 356 | 
            +
             | 
| 357 | 
            +
                @torch.no_grad()
         | 
| 358 | 
            +
                def forward(self, img1, img2):
         | 
| 359 | 
            +
                    self.to(self.device)
         | 
| 360 | 
            +
                    differences = []
         | 
| 361 | 
            +
                    for i in range(0, len(img1), 50):
         | 
| 362 | 
            +
                        img1_batch = img1[i:i+50]
         | 
| 363 | 
            +
                        img2_batch = img2[i:i+50]
         | 
| 364 | 
            +
                        img1_batch = self.preprocess_image(img1_batch).to(self.device)
         | 
| 365 | 
            +
                        img2_batch = self.preprocess_image(img2_batch).to(self.device)
         | 
| 366 | 
            +
                        differences.append(self.lpips(img1_batch, img2_batch).squeeze())
         | 
| 367 | 
            +
                    differences = torch.cat(differences)
         | 
| 368 | 
            +
                    difference = differences.mean()
         | 
| 369 | 
            +
                    # similarity = 1 - difference
         | 
| 370 | 
            +
                    self.to("cpu")
         | 
| 371 | 
            +
                    return {"LPIPS_content_difference": difference,  "LPIPS_content_difference_details": differences}
         | 
| 372 | 
            +
             | 
| 373 | 
            +
            class Vit_metric(Metric):
         | 
| 374 | 
            +
                def __init__(self, device="cuda"):
         | 
| 375 | 
            +
                    super(Vit_metric, self).__init__()
         | 
| 376 | 
            +
                    self.device = device
         | 
| 377 | 
            +
                    self.model = ViTModel.from_pretrained('facebook/dino-vitb8').eval()
         | 
| 378 | 
            +
                    self.image_processor = ViTImageProcessor.from_pretrained('facebook/dino-vitb8')
         | 
| 379 | 
            +
                    self.to("cpu")
         | 
| 380 | 
            +
                def get_image_features(self, images):
         | 
| 381 | 
            +
                    # if isinstance(image, torch.Tensor):
         | 
| 382 | 
            +
                    #     self.tensor_transform(image)
         | 
| 383 | 
            +
                    # else:
         | 
| 384 | 
            +
                    #     image_features = self.image_processor(image, return_tensors="pt", padding=True).to(self.device, non_blocking=True)
         | 
| 385 | 
            +
                    images = self.load_image_path(images)
         | 
| 386 | 
            +
                    batch_size = 20
         | 
| 387 | 
            +
                    all_image_features = []
         | 
| 388 | 
            +
                    for i in range(0, len(images), batch_size):
         | 
| 389 | 
            +
                        image_batch = images[i:i+batch_size]
         | 
| 390 | 
            +
                        if isinstance(image_batch, torch.Tensor):
         | 
| 391 | 
            +
                            image_batch = self.tensor_preprocess(image_batch)
         | 
| 392 | 
            +
                            data = {"pixel_values": image_batch}
         | 
| 393 | 
            +
                            image_processed = BatchFeature(data=data, tensor_type="pt")
         | 
| 394 | 
            +
                        else:
         | 
| 395 | 
            +
                            image_processed = self.image_processor(image_batch, return_tensors="pt").to(self.device)
         | 
| 396 | 
            +
                        image_features = self.model(**image_processed).last_hidden_state.flatten(start_dim=1)
         | 
| 397 | 
            +
                        image_features = F.normalize(image_features, p=2, dim=-1)
         | 
| 398 | 
            +
                        all_image_features.append(image_features)
         | 
| 399 | 
            +
                    all_image_features = torch.cat(all_image_features)
         | 
| 400 | 
            +
                    return all_image_features
         | 
| 401 | 
            +
             | 
| 402 | 
            +
                @torch.no_grad()
         | 
| 403 | 
            +
                def content_metric(self, img1, img2):
         | 
| 404 | 
            +
                    self.to(self.device)
         | 
| 405 | 
            +
                    if not(isinstance(img1, torch.Tensor) and len(img1.shape) == 2):
         | 
| 406 | 
            +
                        img1 = self.get_image_features(img1)
         | 
| 407 | 
            +
                    if not(isinstance(img2, torch.Tensor) and len(img2.shape) == 2):
         | 
| 408 | 
            +
                        img2 = self.get_image_features(img2)
         | 
| 409 | 
            +
                    similarities = torch.einsum("nc, nc -> n", img1, img2)
         | 
| 410 | 
            +
                    similarity = similarities.mean()
         | 
| 411 | 
            +
                    # self.to("cpu")
         | 
| 412 | 
            +
                    return {"Vit_content_similarity": similarity, "Vit_content_similarity_details": similarities}
         | 
| 413 | 
            +
             | 
| 414 | 
            +
                # style
         | 
| 415 | 
            +
                @torch.no_grad()
         | 
| 416 | 
            +
                def define_ref_image_style_prototype(self, ref_image_path: str):
         | 
| 417 | 
            +
                    self.to(self.device)
         | 
| 418 | 
            +
                    self.ref_style_feature = self.get_image_features(ref_image_path)
         | 
| 419 | 
            +
                    self.to("cpu")
         | 
| 420 | 
            +
                @torch.no_grad()
         | 
| 421 | 
            +
                def style_metric(self, styled_data):
         | 
| 422 | 
            +
                    self.to(self.device)
         | 
| 423 | 
            +
                    if isinstance(styled_data, torch.Tensor) and len(styled_data.shape) == 2:
         | 
| 424 | 
            +
                        style_output = styled_data
         | 
| 425 | 
            +
                    else:
         | 
| 426 | 
            +
                        style_output = self.get_image_features(styled_data)
         | 
| 427 | 
            +
                    style_similarities = style_output @ self.ref_style_feature.T
         | 
| 428 | 
            +
                    mean_style_similarities = style_similarities.mean(dim=-1)
         | 
| 429 | 
            +
                    mean_style_similarity = mean_style_similarities.mean()
         | 
| 430 | 
            +
             | 
| 431 | 
            +
                    max_style_similarities_v, max_style_similarities_id = style_similarities.max(dim=-1)
         | 
| 432 | 
            +
                    max_style_similarity = max_style_similarities_v.mean()
         | 
| 433 | 
            +
             | 
| 434 | 
            +
                    # self.to("cpu")
         | 
| 435 | 
            +
                    return {"Vit_style_similarity_mean": mean_style_similarity, "Vit_style_similarity_max": max_style_similarity, "Vit_style_similarity_mean_details": mean_style_similarities,
         | 
| 436 | 
            +
                            "Vit_style_similarity_max_v_details": max_style_similarities_v, "Vit_style_similarity_max_id_details": max_style_similarities_id}
         | 
| 437 | 
            +
                @torch.no_grad()
         | 
| 438 | 
            +
                def forward(self, styled_data, original_data=None):
         | 
| 439 | 
            +
                    self.to(self.device)
         | 
| 440 | 
            +
                    styled_features = self.get_image_features(styled_data)
         | 
| 441 | 
            +
                    ret ={}
         | 
| 442 | 
            +
                    if original_data is not None:
         | 
| 443 | 
            +
                        content_metric = self.content_metric(styled_features, original_data)
         | 
| 444 | 
            +
                        ret["Vit_content"] = content_metric
         | 
| 445 | 
            +
                    style_metric = self.style_metric(styled_features)
         | 
| 446 | 
            +
                    ret["Vit_style"] = style_metric
         | 
| 447 | 
            +
                    self.to("cpu")
         | 
| 448 | 
            +
                    return ret
         | 
| 449 | 
            +
             | 
| 450 | 
            +
             | 
| 451 | 
            +
             | 
| 452 | 
            +
            class StyleContentMetric(nn.Module):
         | 
| 453 | 
            +
                def __init__(self, style_ref_image_folder, device="cuda"):
         | 
| 454 | 
            +
                    super(StyleContentMetric, self).__init__()
         | 
| 455 | 
            +
                    self.device = device
         | 
| 456 | 
            +
                    self.clip_style_metric = CSD_CLIP(device=device)
         | 
| 457 | 
            +
                    self.ref_image_file = os.listdir(style_ref_image_folder)
         | 
| 458 | 
            +
                    self.ref_image_file = [i for i in self.ref_image_file if i.endswith(".jpg") or i.endswith(".png")]
         | 
| 459 | 
            +
                    self.ref_image_file.sort()
         | 
| 460 | 
            +
                    self.ref_image_file = np.array(self.ref_image_file)
         | 
| 461 | 
            +
                    ref_image_file_path = [os.path.join(style_ref_image_folder, i) for i in self.ref_image_file]
         | 
| 462 | 
            +
             | 
| 463 | 
            +
                    self.clip_style_metric.define_ref_image_style_prototype(ref_image_file_path)
         | 
| 464 | 
            +
                    self.vit_metric = Vit_metric(device=device)
         | 
| 465 | 
            +
                    self.vit_metric.define_ref_image_style_prototype(ref_image_file_path)
         | 
| 466 | 
            +
                    self.lpips_metric = LPIPS_metric(device=device)
         | 
| 467 | 
            +
             | 
| 468 | 
            +
                    self.clip_content_metric = Clip_metric(alpha=0, target_style_prompt=None)
         | 
| 469 | 
            +
             | 
| 470 | 
            +
                    self.to("cpu")
         | 
| 471 | 
            +
             | 
| 472 | 
            +
                def forward(self, styled_data, original_data=None, content_caption=None):
         | 
| 473 | 
            +
                    ret ={}
         | 
| 474 | 
            +
                    csd_score = self.clip_style_metric(styled_data)
         | 
| 475 | 
            +
                    csd_score["max_query"] = self.ref_image_file[csd_score["CSD_similarity_max_id_details"].cpu()].tolist()
         | 
| 476 | 
            +
                    torch.cuda.empty_cache()
         | 
| 477 | 
            +
                    ret["Style_CSD"] = csd_score
         | 
| 478 | 
            +
                    vit_score = self.vit_metric(styled_data, original_data)
         | 
| 479 | 
            +
                    torch.cuda.empty_cache()
         | 
| 480 | 
            +
                    vit_style = vit_score["Vit_style"]
         | 
| 481 | 
            +
                    vit_style["max_query"] = self.ref_image_file[vit_style["Vit_style_similarity_max_id_details"].cpu()].tolist()
         | 
| 482 | 
            +
                    ret["Style_VIT"] = vit_style
         | 
| 483 | 
            +
             | 
| 484 | 
            +
                    if original_data is not None:
         | 
| 485 | 
            +
                        vit_content = vit_score["Vit_content"]
         | 
| 486 | 
            +
                        ret["Content_VIT"] = vit_content
         | 
| 487 | 
            +
                        lpips_score = self.lpips_metric(styled_data, original_data)
         | 
| 488 | 
            +
                        torch.cuda.empty_cache()
         | 
| 489 | 
            +
                        ret["Content_LPIPS"] = lpips_score
         | 
| 490 | 
            +
             | 
| 491 | 
            +
                    if content_caption is not None:
         | 
| 492 | 
            +
                        clip_content = self.clip_content_metric.content_score(styled_data, content_caption)
         | 
| 493 | 
            +
                        ret["Content_CLIP"] = clip_content
         | 
| 494 | 
            +
                        torch.cuda.empty_cache()
         | 
| 495 | 
            +
             | 
| 496 | 
            +
                    for type_key, type_value in ret.items():
         | 
| 497 | 
            +
                        for key, value in type_value.items():
         | 
| 498 | 
            +
                            if isinstance(value, torch.Tensor):
         | 
| 499 | 
            +
                                if value.numel() == 1:
         | 
| 500 | 
            +
                                    ret[type_key][key] = round(value.item(), 4)
         | 
| 501 | 
            +
                                else:
         | 
| 502 | 
            +
                                    ret[type_key][key] = value.tolist()
         | 
| 503 | 
            +
                                    ret[type_key][key] = [round(v, 4) for v in ret[type_key][key]]
         | 
| 504 | 
            +
             | 
| 505 | 
            +
                    self.to("cpu")
         | 
| 506 | 
            +
                    ret["ref_image_file"] = self.ref_image_file.tolist()
         | 
| 507 | 
            +
                    return ret
         | 
| 508 | 
            +
             | 
| 509 | 
            +
             | 
| 510 | 
            +
            if __name__ == "__main__":
         | 
| 511 | 
            +
                with torch.no_grad():
         | 
| 512 | 
            +
                    metric = StyleContentMetric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/clip_dissection/Art_styles/camille-pissarro/impressionism/split_5/paintings")
         | 
| 513 | 
            +
                    score = metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/converted_photo/500",
         | 
| 514 | 
            +
                                   "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/paintings")
         | 
| 515 | 
            +
                    print(score)
         | 
| 516 | 
            +
             | 
| 517 | 
            +
             | 
| 518 | 
            +
             | 
| 519 | 
            +
                    lpips = LPIPS_metric()
         | 
| 520 | 
            +
                    score = lpips("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/paintings",
         | 
| 521 | 
            +
                                  "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/converted_photo/500")
         | 
| 522 | 
            +
             | 
| 523 | 
            +
                    print("lpips", score)
         | 
| 524 | 
            +
             | 
| 525 | 
            +
             | 
| 526 | 
            +
                    clip_metric = CSD_CLIP()
         | 
| 527 | 
            +
                    clip_metric.define_ref_image_style_prototype(
         | 
| 528 | 
            +
                        "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset1/paintings")
         | 
| 529 | 
            +
             | 
| 530 | 
            +
                    score = clip_metric(
         | 
| 531 | 
            +
                        "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/converted_photo/500")
         | 
| 532 | 
            +
                    print("subset3-subset3_sd14_converted", score)
         | 
| 533 | 
            +
             | 
| 534 | 
            +
                    score = clip_metric(
         | 
| 535 | 
            +
                        "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/clip_filtered_remain_500")
         | 
| 536 | 
            +
                    print("subset3-photo", score)
         | 
| 537 | 
            +
             | 
| 538 | 
            +
             | 
| 539 | 
            +
             | 
| 540 | 
            +
                    score = clip_metric(
         | 
| 541 | 
            +
                        "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset1/paintings")
         | 
| 542 | 
            +
                    print("subset3-subset1", score)
         | 
| 543 | 
            +
             | 
| 544 | 
            +
                    score = clip_metric(
         | 
| 545 | 
            +
                        "/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/andy-warhol/pop_art/subset1/paintings")
         | 
| 546 | 
            +
                    print("subset3-andy", score)
         | 
| 547 | 
            +
                    # score = clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/paintings", "A painting")
         | 
| 548 | 
            +
             | 
| 549 | 
            +
                    # print("subset3",score)
         | 
| 550 | 
            +
                    # score_subset2 = clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset2/paintings")
         | 
| 551 | 
            +
                    # print("subset2",score_subset2)
         | 
| 552 | 
            +
                    # score_subset3 = clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/paintings")
         | 
| 553 | 
            +
                    # print("subset3",score_subset3)
         | 
| 554 | 
            +
                    #
         | 
| 555 | 
            +
                    # score_subset3_converted = clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/converted_photo/500")
         | 
| 556 | 
            +
                    # print("subset3-subset3_sd14_converted" , score_subset3_converted)
         | 
| 557 | 
            +
                    #
         | 
| 558 | 
            +
                    # score_subset3_coco_converted = clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/coco_converted_photo/500")
         | 
| 559 | 
            +
                    # print("subset3-subset3_coco_converted" , score_subset3_coco_converted)
         | 
| 560 | 
            +
                    #
         | 
| 561 | 
            +
                    # clip_metric = Clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/sketch_500")
         | 
| 562 | 
            +
                    # score = clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/clip_filtered_remain_500")
         | 
| 563 | 
            +
                    # print("photo500_1-sketch" ,score)
         | 
| 564 | 
            +
                    #
         | 
| 565 | 
            +
                    # clip_metric = Clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/clip_filtered_remain_500")
         | 
| 566 | 
            +
                    # score = clip_metric("/afs/csail.mit.edu/u/h/huiren/code/diffusion/stable_diffusion/imgFolder/clip_filtered_remain_500_new")
         | 
| 567 | 
            +
                    # print("photo500_1-photo500_2" ,score)
         | 
| 568 | 
            +
                    # from custom_datasets.imagepair import ImageSet
         | 
| 569 | 
            +
                    # import matplotlib.pyplot as plt
         | 
| 570 | 
            +
                    # dataset = ImageSet(folder = "/data/vision/torralba/scratch/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/paintings",
         | 
| 571 | 
            +
                    #                    caption_path="/data/vision/torralba/scratch/huiren/code/diffusion/stable_diffusion/custom_datasets/wikiart/data/gustav-klimt_Art_Nouveau/subset3/captions",
         | 
| 572 | 
            +
                    #                     keep_in_mem=False)
         | 
| 573 | 
            +
                    # for sample in dataset:
         | 
| 574 | 
            +
                    #     score = clip_metric.content_score(sample["image"], sample["caption"][0])
         | 
| 575 | 
            +
                    #     plt.imshow(sample["image"])
         | 
| 576 | 
            +
                    #     plt.title(f"score: {round(score.item(),2)}\n prompt: {sample['caption'][0]}")
         | 
| 577 | 
            +
                    #     plt.show()
         | 
    	
        utils/model_util.py
    ADDED
    
    | @@ -0,0 +1,291 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from typing import Literal, Union, Optional
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
            from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
         | 
| 5 | 
            +
            from diffusers import (
         | 
| 6 | 
            +
                UNet2DConditionModel,
         | 
| 7 | 
            +
                SchedulerMixin,
         | 
| 8 | 
            +
                StableDiffusionPipeline,
         | 
| 9 | 
            +
                StableDiffusionXLPipeline,
         | 
| 10 | 
            +
                AutoencoderKL,
         | 
| 11 | 
            +
            )
         | 
| 12 | 
            +
            from diffusers.schedulers import (
         | 
| 13 | 
            +
                DDIMScheduler,
         | 
| 14 | 
            +
                DDPMScheduler,
         | 
| 15 | 
            +
                LMSDiscreteScheduler,
         | 
| 16 | 
            +
                EulerAncestralDiscreteScheduler,
         | 
| 17 | 
            +
            )
         | 
| 18 | 
            +
             | 
| 19 | 
            +
             | 
| 20 | 
            +
            TOKENIZER_V1_MODEL_NAME = "CompVis/stable-diffusion-v1-4"
         | 
| 21 | 
            +
            TOKENIZER_V2_MODEL_NAME = "stabilityai/stable-diffusion-2-1"
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            AVAILABLE_SCHEDULERS = Literal["ddim", "ddpm", "lms", "euler_a"]
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            SDXL_TEXT_ENCODER_TYPE = Union[CLIPTextModel, CLIPTextModelWithProjection]
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            DIFFUSERS_CACHE_DIR = None  # if you want to change the cache dir, change this
         | 
| 28 | 
            +
            from diffusers.training_utils import EMAModel
         | 
| 29 | 
            +
            import os
         | 
| 30 | 
            +
            import sys
         | 
| 31 | 
            +
             | 
| 32 | 
            +
            # from utils.modules import get_diffusion_modules
         | 
| 33 | 
            +
            def load_diffusers_model(
         | 
| 34 | 
            +
                pretrained_model_name_or_path: str,
         | 
| 35 | 
            +
                v2: bool = False,
         | 
| 36 | 
            +
                clip_skip: Optional[int] = None,
         | 
| 37 | 
            +
                weight_dtype: torch.dtype = torch.float32,
         | 
| 38 | 
            +
            ) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]:
         | 
| 39 | 
            +
                # VAE はいらない
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                if v2:
         | 
| 42 | 
            +
                    tokenizer = CLIPTokenizer.from_pretrained(
         | 
| 43 | 
            +
                        TOKENIZER_V2_MODEL_NAME,
         | 
| 44 | 
            +
                        subfolder="tokenizer",
         | 
| 45 | 
            +
                        torch_dtype=weight_dtype,
         | 
| 46 | 
            +
                        cache_dir=DIFFUSERS_CACHE_DIR,
         | 
| 47 | 
            +
                    )
         | 
| 48 | 
            +
                    text_encoder = CLIPTextModel.from_pretrained(
         | 
| 49 | 
            +
                        pretrained_model_name_or_path,
         | 
| 50 | 
            +
                        subfolder="text_encoder",
         | 
| 51 | 
            +
                        # default is clip skip 2
         | 
| 52 | 
            +
                        num_hidden_layers=24 - (clip_skip - 1) if clip_skip is not None else 23,
         | 
| 53 | 
            +
                        torch_dtype=weight_dtype,
         | 
| 54 | 
            +
                        cache_dir=DIFFUSERS_CACHE_DIR,
         | 
| 55 | 
            +
                    )
         | 
| 56 | 
            +
                else:
         | 
| 57 | 
            +
                    tokenizer = CLIPTokenizer.from_pretrained(
         | 
| 58 | 
            +
                        TOKENIZER_V1_MODEL_NAME,
         | 
| 59 | 
            +
                        subfolder="tokenizer",
         | 
| 60 | 
            +
                        torch_dtype=weight_dtype,
         | 
| 61 | 
            +
                        cache_dir=DIFFUSERS_CACHE_DIR,
         | 
| 62 | 
            +
                    )
         | 
| 63 | 
            +
                    text_encoder = CLIPTextModel.from_pretrained(
         | 
| 64 | 
            +
                        pretrained_model_name_or_path,
         | 
| 65 | 
            +
                        subfolder="text_encoder",
         | 
| 66 | 
            +
                        num_hidden_layers=12 - (clip_skip - 1) if clip_skip is not None else 12,
         | 
| 67 | 
            +
                        torch_dtype=weight_dtype,
         | 
| 68 | 
            +
                        cache_dir=DIFFUSERS_CACHE_DIR,
         | 
| 69 | 
            +
                    )
         | 
| 70 | 
            +
             | 
| 71 | 
            +
                unet = UNet2DConditionModel.from_pretrained(
         | 
| 72 | 
            +
                    pretrained_model_name_or_path,
         | 
| 73 | 
            +
                    subfolder="unet",
         | 
| 74 | 
            +
                    torch_dtype=weight_dtype,
         | 
| 75 | 
            +
                    cache_dir=DIFFUSERS_CACHE_DIR,
         | 
| 76 | 
            +
                )
         | 
| 77 | 
            +
                
         | 
| 78 | 
            +
                vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
         | 
| 79 | 
            +
             | 
| 80 | 
            +
                return tokenizer, text_encoder, unet, vae
         | 
| 81 | 
            +
             | 
| 82 | 
            +
             | 
| 83 | 
            +
            def load_checkpoint_model(
         | 
| 84 | 
            +
                checkpoint_path: str,
         | 
| 85 | 
            +
                v2: bool = False,
         | 
| 86 | 
            +
                clip_skip: Optional[int] = None,
         | 
| 87 | 
            +
                weight_dtype: torch.dtype = torch.float32,
         | 
| 88 | 
            +
            ) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel,]:
         | 
| 89 | 
            +
                pipe = StableDiffusionPipeline.from_ckpt(
         | 
| 90 | 
            +
                    checkpoint_path,
         | 
| 91 | 
            +
                    upcast_attention=True if v2 else False,
         | 
| 92 | 
            +
                    torch_dtype=weight_dtype,
         | 
| 93 | 
            +
                    cache_dir=DIFFUSERS_CACHE_DIR,
         | 
| 94 | 
            +
                )
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                unet = pipe.unet
         | 
| 97 | 
            +
                tokenizer = pipe.tokenizer
         | 
| 98 | 
            +
                text_encoder = pipe.text_encoder
         | 
| 99 | 
            +
                vae = pipe.vae
         | 
| 100 | 
            +
                if clip_skip is not None:
         | 
| 101 | 
            +
                    if v2:
         | 
| 102 | 
            +
                        text_encoder.config.num_hidden_layers = 24 - (clip_skip - 1)
         | 
| 103 | 
            +
                    else:
         | 
| 104 | 
            +
                        text_encoder.config.num_hidden_layers = 12 - (clip_skip - 1)
         | 
| 105 | 
            +
             | 
| 106 | 
            +
                del pipe
         | 
| 107 | 
            +
             | 
| 108 | 
            +
                return tokenizer, text_encoder, unet, vae
         | 
| 109 | 
            +
             | 
| 110 | 
            +
             | 
| 111 | 
            +
            def load_models(
         | 
| 112 | 
            +
                pretrained_model_name_or_path: str,
         | 
| 113 | 
            +
                ckpt_path: str,
         | 
| 114 | 
            +
                scheduler_name: AVAILABLE_SCHEDULERS,
         | 
| 115 | 
            +
                v2: bool = False,
         | 
| 116 | 
            +
                v_pred: bool = False,
         | 
| 117 | 
            +
                weight_dtype: torch.dtype = torch.float32,
         | 
| 118 | 
            +
            ) -> tuple[CLIPTokenizer, CLIPTextModel, UNet2DConditionModel, SchedulerMixin,]:
         | 
| 119 | 
            +
                if pretrained_model_name_or_path.endswith(
         | 
| 120 | 
            +
                    ".ckpt"
         | 
| 121 | 
            +
                ) or pretrained_model_name_or_path.endswith(".safetensors"):
         | 
| 122 | 
            +
                    tokenizer, text_encoder, unet, vae = load_checkpoint_model(
         | 
| 123 | 
            +
                        pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
         | 
| 124 | 
            +
                    )
         | 
| 125 | 
            +
                else:  # diffusers
         | 
| 126 | 
            +
                    tokenizer, text_encoder, unet, vae = load_diffusers_model(
         | 
| 127 | 
            +
                        pretrained_model_name_or_path, v2=v2, weight_dtype=weight_dtype
         | 
| 128 | 
            +
                    )
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                # VAE はいらない
         | 
| 131 | 
            +
             | 
| 132 | 
            +
                scheduler = create_noise_scheduler(
         | 
| 133 | 
            +
                    scheduler_name,
         | 
| 134 | 
            +
                    prediction_type="v_prediction" if v_pred else "epsilon",
         | 
| 135 | 
            +
                )
         | 
| 136 | 
            +
                # trained unet_ema
         | 
| 137 | 
            +
                if ckpt_path not in [None, "None"]:
         | 
| 138 | 
            +
                    ema_unet = EMAModel.from_pretrained(os.path.join(ckpt_path, "unet_ema"), UNet2DConditionModel)
         | 
| 139 | 
            +
                    ema_unet.copy_to(unet.parameters())
         | 
| 140 | 
            +
                return tokenizer, text_encoder, unet, scheduler, vae
         | 
| 141 | 
            +
             | 
| 142 | 
            +
             | 
| 143 | 
            +
            def load_diffusers_model_xl(
         | 
| 144 | 
            +
                pretrained_model_name_or_path: str,
         | 
| 145 | 
            +
                weight_dtype: torch.dtype = torch.float32,
         | 
| 146 | 
            +
            ) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel,]:
         | 
| 147 | 
            +
                # returns tokenizer, tokenizer_2, text_encoder, text_encoder_2, unet
         | 
| 148 | 
            +
             | 
| 149 | 
            +
                tokenizers = [
         | 
| 150 | 
            +
                    CLIPTokenizer.from_pretrained(
         | 
| 151 | 
            +
                        pretrained_model_name_or_path,
         | 
| 152 | 
            +
                        subfolder="tokenizer",
         | 
| 153 | 
            +
                        torch_dtype=weight_dtype,
         | 
| 154 | 
            +
                        cache_dir=DIFFUSERS_CACHE_DIR,
         | 
| 155 | 
            +
                    ),
         | 
| 156 | 
            +
                    CLIPTokenizer.from_pretrained(
         | 
| 157 | 
            +
                        pretrained_model_name_or_path,
         | 
| 158 | 
            +
                        subfolder="tokenizer_2",
         | 
| 159 | 
            +
                        torch_dtype=weight_dtype,
         | 
| 160 | 
            +
                        cache_dir=DIFFUSERS_CACHE_DIR,
         | 
| 161 | 
            +
                        pad_token_id=0,  # same as open clip
         | 
| 162 | 
            +
                    ),
         | 
| 163 | 
            +
                ]
         | 
| 164 | 
            +
             | 
| 165 | 
            +
                text_encoders = [
         | 
| 166 | 
            +
                    CLIPTextModel.from_pretrained(
         | 
| 167 | 
            +
                        pretrained_model_name_or_path,
         | 
| 168 | 
            +
                        subfolder="text_encoder",
         | 
| 169 | 
            +
                        torch_dtype=weight_dtype,
         | 
| 170 | 
            +
                        cache_dir=DIFFUSERS_CACHE_DIR,
         | 
| 171 | 
            +
                    ),
         | 
| 172 | 
            +
                    CLIPTextModelWithProjection.from_pretrained(
         | 
| 173 | 
            +
                        pretrained_model_name_or_path,
         | 
| 174 | 
            +
                        subfolder="text_encoder_2",
         | 
| 175 | 
            +
                        torch_dtype=weight_dtype,
         | 
| 176 | 
            +
                        cache_dir=DIFFUSERS_CACHE_DIR,
         | 
| 177 | 
            +
                    ),
         | 
| 178 | 
            +
                ]
         | 
| 179 | 
            +
             | 
| 180 | 
            +
                unet = UNet2DConditionModel.from_pretrained(
         | 
| 181 | 
            +
                    pretrained_model_name_or_path,
         | 
| 182 | 
            +
                    subfolder="unet",
         | 
| 183 | 
            +
                    torch_dtype=weight_dtype,
         | 
| 184 | 
            +
                    cache_dir=DIFFUSERS_CACHE_DIR,
         | 
| 185 | 
            +
                )
         | 
| 186 | 
            +
                vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
         | 
| 187 | 
            +
                return tokenizers, text_encoders, unet, vae
         | 
| 188 | 
            +
             | 
| 189 | 
            +
             | 
| 190 | 
            +
            def load_checkpoint_model_xl(
         | 
| 191 | 
            +
                checkpoint_path: str,
         | 
| 192 | 
            +
                weight_dtype: torch.dtype = torch.float32,
         | 
| 193 | 
            +
            ) -> tuple[list[CLIPTokenizer], list[SDXL_TEXT_ENCODER_TYPE], UNet2DConditionModel,]:
         | 
| 194 | 
            +
                pipe = StableDiffusionXLPipeline.from_single_file(
         | 
| 195 | 
            +
                    checkpoint_path,
         | 
| 196 | 
            +
                    torch_dtype=weight_dtype,
         | 
| 197 | 
            +
                    cache_dir=DIFFUSERS_CACHE_DIR,
         | 
| 198 | 
            +
                )
         | 
| 199 | 
            +
             | 
| 200 | 
            +
                unet = pipe.unet
         | 
| 201 | 
            +
                tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
         | 
| 202 | 
            +
                text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
         | 
| 203 | 
            +
                if len(text_encoders) == 2:
         | 
| 204 | 
            +
                    text_encoders[1].pad_token_id = 0
         | 
| 205 | 
            +
                vae = pipe.vae
         | 
| 206 | 
            +
                del pipe
         | 
| 207 | 
            +
             | 
| 208 | 
            +
                return tokenizers, text_encoders, unet, vae
         | 
| 209 | 
            +
             | 
| 210 | 
            +
             | 
| 211 | 
            +
            def load_models_xl(
         | 
| 212 | 
            +
                pretrained_model_name_or_path: str,
         | 
| 213 | 
            +
                scheduler_name: AVAILABLE_SCHEDULERS,
         | 
| 214 | 
            +
                weight_dtype: torch.dtype = torch.float32,
         | 
| 215 | 
            +
            ) -> tuple[
         | 
| 216 | 
            +
                list[CLIPTokenizer],
         | 
| 217 | 
            +
                list[SDXL_TEXT_ENCODER_TYPE],
         | 
| 218 | 
            +
                UNet2DConditionModel,
         | 
| 219 | 
            +
                SchedulerMixin,
         | 
| 220 | 
            +
            ]:
         | 
| 221 | 
            +
                if pretrained_model_name_or_path.endswith(
         | 
| 222 | 
            +
                    ".ckpt"
         | 
| 223 | 
            +
                ) or pretrained_model_name_or_path.endswith(".safetensors"):
         | 
| 224 | 
            +
                    (
         | 
| 225 | 
            +
                        tokenizers,
         | 
| 226 | 
            +
                        text_encoders,
         | 
| 227 | 
            +
                        unet,
         | 
| 228 | 
            +
                        vae
         | 
| 229 | 
            +
                    ) = load_checkpoint_model_xl(pretrained_model_name_or_path, weight_dtype)
         | 
| 230 | 
            +
                else:  # diffusers
         | 
| 231 | 
            +
                    (
         | 
| 232 | 
            +
                        tokenizers,
         | 
| 233 | 
            +
                        text_encoders,
         | 
| 234 | 
            +
                        unet,
         | 
| 235 | 
            +
                        vae
         | 
| 236 | 
            +
                    ) = load_diffusers_model_xl(pretrained_model_name_or_path, weight_dtype)
         | 
| 237 | 
            +
             | 
| 238 | 
            +
                scheduler = create_noise_scheduler(scheduler_name)
         | 
| 239 | 
            +
             | 
| 240 | 
            +
                return tokenizers, text_encoders, unet, scheduler, vae
         | 
| 241 | 
            +
             | 
| 242 | 
            +
             | 
| 243 | 
            +
            def create_noise_scheduler(
         | 
| 244 | 
            +
                scheduler_name: AVAILABLE_SCHEDULERS = "ddpm",
         | 
| 245 | 
            +
                prediction_type: Literal["epsilon", "v_prediction"] = "epsilon",
         | 
| 246 | 
            +
            ) -> SchedulerMixin:
         | 
| 247 | 
            +
             | 
| 248 | 
            +
             | 
| 249 | 
            +
                name = scheduler_name.lower().replace(" ", "_")
         | 
| 250 | 
            +
                if name == "ddim":
         | 
| 251 | 
            +
                    # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddim
         | 
| 252 | 
            +
                    scheduler = DDIMScheduler(
         | 
| 253 | 
            +
                        beta_start=0.00085,
         | 
| 254 | 
            +
                        beta_end=0.012,
         | 
| 255 | 
            +
                        beta_schedule="scaled_linear",
         | 
| 256 | 
            +
                        num_train_timesteps=1000,
         | 
| 257 | 
            +
                        clip_sample=False,
         | 
| 258 | 
            +
                        prediction_type=prediction_type,  # これでいいの?
         | 
| 259 | 
            +
                    )
         | 
| 260 | 
            +
                elif name == "ddpm":
         | 
| 261 | 
            +
                    # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/ddpm
         | 
| 262 | 
            +
                    scheduler = DDPMScheduler(
         | 
| 263 | 
            +
                        beta_start=0.00085,
         | 
| 264 | 
            +
                        beta_end=0.012,
         | 
| 265 | 
            +
                        beta_schedule="scaled_linear",
         | 
| 266 | 
            +
                        num_train_timesteps=1000,
         | 
| 267 | 
            +
                        clip_sample=False,
         | 
| 268 | 
            +
                        prediction_type=prediction_type,
         | 
| 269 | 
            +
                    )
         | 
| 270 | 
            +
                elif name == "lms":
         | 
| 271 | 
            +
                    # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/lms_discrete
         | 
| 272 | 
            +
                    scheduler = LMSDiscreteScheduler(
         | 
| 273 | 
            +
                        beta_start=0.00085,
         | 
| 274 | 
            +
                        beta_end=0.012,
         | 
| 275 | 
            +
                        beta_schedule="scaled_linear",
         | 
| 276 | 
            +
                        num_train_timesteps=1000,
         | 
| 277 | 
            +
                        prediction_type=prediction_type,
         | 
| 278 | 
            +
                    )
         | 
| 279 | 
            +
                elif name == "euler_a":
         | 
| 280 | 
            +
                    # https://huggingface.co/docs/diffusers/v0.17.1/en/api/schedulers/euler_ancestral
         | 
| 281 | 
            +
                    scheduler = EulerAncestralDiscreteScheduler(
         | 
| 282 | 
            +
                        beta_start=0.00085,
         | 
| 283 | 
            +
                        beta_end=0.012,
         | 
| 284 | 
            +
                        beta_schedule="scaled_linear",
         | 
| 285 | 
            +
                        num_train_timesteps=1000,
         | 
| 286 | 
            +
                        prediction_type=prediction_type,
         | 
| 287 | 
            +
                    )
         | 
| 288 | 
            +
                else:
         | 
| 289 | 
            +
                    raise ValueError(f"Unknown scheduler name: {name}")
         | 
| 290 | 
            +
             | 
| 291 | 
            +
                return scheduler
         | 
    	
        utils/prompt_util.py
    ADDED
    
    | @@ -0,0 +1,174 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from typing import Literal, Optional, Union, List
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import yaml
         | 
| 4 | 
            +
            from pathlib import Path
         | 
| 5 | 
            +
             | 
| 6 | 
            +
             | 
| 7 | 
            +
            from pydantic import BaseModel, root_validator
         | 
| 8 | 
            +
            import torch
         | 
| 9 | 
            +
            import copy
         | 
| 10 | 
            +
             | 
| 11 | 
            +
            ACTION_TYPES = Literal[
         | 
| 12 | 
            +
                "erase",
         | 
| 13 | 
            +
                "enhance",
         | 
| 14 | 
            +
            ]
         | 
| 15 | 
            +
             | 
| 16 | 
            +
             | 
| 17 | 
            +
            # XL は二種類必要なので
         | 
| 18 | 
            +
            class PromptEmbedsXL:
         | 
| 19 | 
            +
                text_embeds: torch.FloatTensor
         | 
| 20 | 
            +
                pooled_embeds: torch.FloatTensor
         | 
| 21 | 
            +
             | 
| 22 | 
            +
                def __init__(self, *args) -> None:
         | 
| 23 | 
            +
                    self.text_embeds = args[0]
         | 
| 24 | 
            +
                    self.pooled_embeds = args[1]
         | 
| 25 | 
            +
             | 
| 26 | 
            +
             | 
| 27 | 
            +
            # SDv1.x, SDv2.x は FloatTensor、XL は PromptEmbedsXL
         | 
| 28 | 
            +
            PROMPT_EMBEDDING = Union[torch.FloatTensor, PromptEmbedsXL]
         | 
| 29 | 
            +
             | 
| 30 | 
            +
             | 
| 31 | 
            +
            class PromptEmbedsCache:  # 使いまわしたいので
         | 
| 32 | 
            +
                prompts: dict[str, PROMPT_EMBEDDING] = {}
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                def __setitem__(self, __name: str, __value: PROMPT_EMBEDDING) -> None:
         | 
| 35 | 
            +
                    self.prompts[__name] = __value
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                def __getitem__(self, __name: str) -> Optional[PROMPT_EMBEDDING]:
         | 
| 38 | 
            +
                    if __name in self.prompts:
         | 
| 39 | 
            +
                        return self.prompts[__name]
         | 
| 40 | 
            +
                    else:
         | 
| 41 | 
            +
                        return None
         | 
| 42 | 
            +
             | 
| 43 | 
            +
             | 
| 44 | 
            +
            class PromptSettings(BaseModel):  # yaml のやつ
         | 
| 45 | 
            +
                target: str
         | 
| 46 | 
            +
                positive: str = None   # if None, target will be used
         | 
| 47 | 
            +
                unconditional: str = ""  # default is ""
         | 
| 48 | 
            +
                neutral: str = None  # if None, unconditional will be used
         | 
| 49 | 
            +
                action: ACTION_TYPES = "erase"  # default is "erase"
         | 
| 50 | 
            +
                guidance_scale: float = 1.0  # default is 1.0
         | 
| 51 | 
            +
                resolution: int = 512  # default is 512
         | 
| 52 | 
            +
                dynamic_resolution: bool = False  # default is False
         | 
| 53 | 
            +
                batch_size: int = 1  # default is 1
         | 
| 54 | 
            +
                dynamic_crops: bool = False  # default is False. only used when model is XL
         | 
| 55 | 
            +
             | 
| 56 | 
            +
                @root_validator(pre=True)
         | 
| 57 | 
            +
                def fill_prompts(cls, values):
         | 
| 58 | 
            +
                    keys = values.keys()
         | 
| 59 | 
            +
                    if "target" not in keys:
         | 
| 60 | 
            +
                        raise ValueError("target must be specified")
         | 
| 61 | 
            +
                    if "positive" not in keys:
         | 
| 62 | 
            +
                        values["positive"] = values["target"]
         | 
| 63 | 
            +
                    if "unconditional" not in keys:
         | 
| 64 | 
            +
                        values["unconditional"] = ""
         | 
| 65 | 
            +
                    if "neutral" not in keys:
         | 
| 66 | 
            +
                        values["neutral"] = values["unconditional"]
         | 
| 67 | 
            +
             | 
| 68 | 
            +
                    return values
         | 
| 69 | 
            +
             | 
| 70 | 
            +
             | 
| 71 | 
            +
            class PromptEmbedsPair:
         | 
| 72 | 
            +
                target: PROMPT_EMBEDDING  # not want to generate the concept
         | 
| 73 | 
            +
                positive: PROMPT_EMBEDDING  # generate the concept
         | 
| 74 | 
            +
                unconditional: PROMPT_EMBEDDING  # uncondition (default should be empty)
         | 
| 75 | 
            +
                neutral: PROMPT_EMBEDDING  # base condition (default should be empty)
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                guidance_scale: float
         | 
| 78 | 
            +
                resolution: int
         | 
| 79 | 
            +
                dynamic_resolution: bool
         | 
| 80 | 
            +
                batch_size: int
         | 
| 81 | 
            +
                dynamic_crops: bool
         | 
| 82 | 
            +
             | 
| 83 | 
            +
                loss_fn: torch.nn.Module
         | 
| 84 | 
            +
                action: ACTION_TYPES
         | 
| 85 | 
            +
             | 
| 86 | 
            +
                def __init__(
         | 
| 87 | 
            +
                    self,
         | 
| 88 | 
            +
                    loss_fn: torch.nn.Module,
         | 
| 89 | 
            +
                    target: PROMPT_EMBEDDING,
         | 
| 90 | 
            +
                    positive: PROMPT_EMBEDDING,
         | 
| 91 | 
            +
                    unconditional: PROMPT_EMBEDDING,
         | 
| 92 | 
            +
                    neutral: PROMPT_EMBEDDING,
         | 
| 93 | 
            +
                    settings: PromptSettings,
         | 
| 94 | 
            +
                ) -> None:
         | 
| 95 | 
            +
                    self.loss_fn = loss_fn
         | 
| 96 | 
            +
                    self.target = target
         | 
| 97 | 
            +
                    self.positive = positive
         | 
| 98 | 
            +
                    self.unconditional = unconditional
         | 
| 99 | 
            +
                    self.neutral = neutral
         | 
| 100 | 
            +
                    
         | 
| 101 | 
            +
                    self.guidance_scale = settings.guidance_scale
         | 
| 102 | 
            +
                    self.resolution = settings.resolution
         | 
| 103 | 
            +
                    self.dynamic_resolution = settings.dynamic_resolution
         | 
| 104 | 
            +
                    self.batch_size = settings.batch_size
         | 
| 105 | 
            +
                    self.dynamic_crops = settings.dynamic_crops
         | 
| 106 | 
            +
                    self.action = settings.action
         | 
| 107 | 
            +
             | 
| 108 | 
            +
                def _erase(
         | 
| 109 | 
            +
                    self,
         | 
| 110 | 
            +
                    target_latents: torch.FloatTensor,  # "van gogh"
         | 
| 111 | 
            +
                    positive_latents: torch.FloatTensor,  # "van gogh"
         | 
| 112 | 
            +
                    unconditional_latents: torch.FloatTensor,  # ""
         | 
| 113 | 
            +
                    neutral_latents: torch.FloatTensor,  # ""
         | 
| 114 | 
            +
                ) -> torch.FloatTensor:
         | 
| 115 | 
            +
                    """Target latents are going not to have the positive concept."""
         | 
| 116 | 
            +
                    return self.loss_fn(
         | 
| 117 | 
            +
                        target_latents,
         | 
| 118 | 
            +
                        neutral_latents
         | 
| 119 | 
            +
                        - self.guidance_scale * (positive_latents - unconditional_latents)
         | 
| 120 | 
            +
                    )
         | 
| 121 | 
            +
                
         | 
| 122 | 
            +
             | 
| 123 | 
            +
                def _enhance(
         | 
| 124 | 
            +
                    self,
         | 
| 125 | 
            +
                    target_latents: torch.FloatTensor,  # "van gogh"
         | 
| 126 | 
            +
                    positive_latents: torch.FloatTensor,  # "van gogh"
         | 
| 127 | 
            +
                    unconditional_latents: torch.FloatTensor,  # ""
         | 
| 128 | 
            +
                    neutral_latents: torch.FloatTensor,  # ""
         | 
| 129 | 
            +
                ):
         | 
| 130 | 
            +
                    """Target latents are going to have the positive concept."""
         | 
| 131 | 
            +
                    return self.loss_fn(
         | 
| 132 | 
            +
                        target_latents,
         | 
| 133 | 
            +
                        neutral_latents
         | 
| 134 | 
            +
                        + self.guidance_scale * (positive_latents - unconditional_latents)
         | 
| 135 | 
            +
                    )
         | 
| 136 | 
            +
             | 
| 137 | 
            +
                def loss(
         | 
| 138 | 
            +
                    self,
         | 
| 139 | 
            +
                    **kwargs,
         | 
| 140 | 
            +
                ):
         | 
| 141 | 
            +
                    if self.action == "erase":
         | 
| 142 | 
            +
                        return self._erase(**kwargs)
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                    elif self.action == "enhance":
         | 
| 145 | 
            +
                        return self._enhance(**kwargs)
         | 
| 146 | 
            +
             | 
| 147 | 
            +
                    else:
         | 
| 148 | 
            +
                        raise ValueError("action must be erase or enhance")
         | 
| 149 | 
            +
             | 
| 150 | 
            +
             | 
| 151 | 
            +
            def load_prompts_from_yaml(path, attributes = []):
         | 
| 152 | 
            +
                with open(path, "r") as f:
         | 
| 153 | 
            +
                    prompts = yaml.safe_load(f)
         | 
| 154 | 
            +
                print(prompts)    
         | 
| 155 | 
            +
                if len(prompts) == 0:
         | 
| 156 | 
            +
                    raise ValueError("prompts file is empty")
         | 
| 157 | 
            +
                if len(attributes)!=0:
         | 
| 158 | 
            +
                    newprompts = []
         | 
| 159 | 
            +
                    for i in range(len(prompts)):
         | 
| 160 | 
            +
                        for att in attributes:
         | 
| 161 | 
            +
                            copy_ = copy.deepcopy(prompts[i])
         | 
| 162 | 
            +
                            copy_['target'] = att + ' ' + copy_['target']
         | 
| 163 | 
            +
                            copy_['positive'] = att + ' ' + copy_['positive']
         | 
| 164 | 
            +
                            copy_['neutral'] = att + ' ' + copy_['neutral']
         | 
| 165 | 
            +
                            copy_['unconditional'] = att + ' ' + copy_['unconditional']
         | 
| 166 | 
            +
                            newprompts.append(copy_)
         | 
| 167 | 
            +
                else:
         | 
| 168 | 
            +
                    newprompts = copy.deepcopy(prompts)
         | 
| 169 | 
            +
                
         | 
| 170 | 
            +
                print(newprompts)
         | 
| 171 | 
            +
                print(len(prompts), len(newprompts))
         | 
| 172 | 
            +
                prompt_settings = [PromptSettings(**prompt) for prompt in newprompts]
         | 
| 173 | 
            +
             | 
| 174 | 
            +
                return prompt_settings
         | 
    	
        utils/train_util.py
    ADDED
    
    | @@ -0,0 +1,526 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from typing import Optional, Union
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import torch
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            from transformers import CLIPTextModel, CLIPTokenizer, BertModel, BertTokenizer
         | 
| 6 | 
            +
            from diffusers import UNet2DConditionModel, SchedulerMixin
         | 
| 7 | 
            +
            from diffusers.image_processor import VaeImageProcessor
         | 
| 8 | 
            +
            import sys
         | 
| 9 | 
            +
            import os
         | 
| 10 | 
            +
            # sys.path.append(os.path.join(os.path.dirname(__file__), "../"))
         | 
| 11 | 
            +
            # from imagesliders.model_util import SDXL_TEXT_ENCODER_TYPE
         | 
| 12 | 
            +
            from diffusers.utils.torch_utils import randn_tensor
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
         | 
| 15 | 
            +
             | 
| 16 | 
            +
            SDXL_TEXT_ENCODER_TYPE = Union[CLIPTextModel, CLIPTextModelWithProjection]
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            from tqdm import tqdm
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            UNET_IN_CHANNELS = 4  # Stable Diffusion  in_channels
         | 
| 21 | 
            +
            VAE_SCALE_FACTOR = 8  # 2 ** (len(vae.config.block_out_channels) - 1) = 8
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            UNET_ATTENTION_TIME_EMBED_DIM = 256  # XL
         | 
| 24 | 
            +
            TEXT_ENCODER_2_PROJECTION_DIM = 1280
         | 
| 25 | 
            +
            UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM = 2816
         | 
| 26 | 
            +
             | 
| 27 | 
            +
             | 
| 28 | 
            +
            def get_random_noise(
         | 
| 29 | 
            +
                batch_size: int, height: int, width: int, generator: torch.Generator = None
         | 
| 30 | 
            +
            ) -> torch.Tensor:
         | 
| 31 | 
            +
                return torch.randn(
         | 
| 32 | 
            +
                    (
         | 
| 33 | 
            +
                        batch_size,
         | 
| 34 | 
            +
                        UNET_IN_CHANNELS,
         | 
| 35 | 
            +
                        height // VAE_SCALE_FACTOR,
         | 
| 36 | 
            +
                        width // VAE_SCALE_FACTOR,
         | 
| 37 | 
            +
                    ),
         | 
| 38 | 
            +
                    generator=generator,
         | 
| 39 | 
            +
                    device="cpu",
         | 
| 40 | 
            +
                )
         | 
| 41 | 
            +
             | 
| 42 | 
            +
             | 
| 43 | 
            +
             | 
| 44 | 
            +
            def apply_noise_offset(latents: torch.FloatTensor, noise_offset: float):
         | 
| 45 | 
            +
                latents = latents + noise_offset * torch.randn(
         | 
| 46 | 
            +
                    (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
         | 
| 47 | 
            +
                )
         | 
| 48 | 
            +
                return latents
         | 
| 49 | 
            +
             | 
| 50 | 
            +
             | 
| 51 | 
            +
            def get_initial_latents(
         | 
| 52 | 
            +
                scheduler: SchedulerMixin,
         | 
| 53 | 
            +
                n_imgs: int,
         | 
| 54 | 
            +
                height: int,
         | 
| 55 | 
            +
                width: int,
         | 
| 56 | 
            +
                n_prompts: int,
         | 
| 57 | 
            +
                generator=None,
         | 
| 58 | 
            +
            ) -> torch.Tensor:
         | 
| 59 | 
            +
                noise = get_random_noise(n_imgs, height, width, generator=generator).repeat(
         | 
| 60 | 
            +
                    n_prompts, 1, 1, 1
         | 
| 61 | 
            +
                )
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                latents = noise * scheduler.init_noise_sigma
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                return latents
         | 
| 66 | 
            +
             | 
| 67 | 
            +
             | 
| 68 | 
            +
            def text_tokenize(
         | 
| 69 | 
            +
                tokenizer,  # 普通ならひとつ、XLならふたつ!
         | 
| 70 | 
            +
                prompts,
         | 
| 71 | 
            +
            ):
         | 
| 72 | 
            +
                return tokenizer(
         | 
| 73 | 
            +
                    prompts,
         | 
| 74 | 
            +
                    padding="max_length",
         | 
| 75 | 
            +
                    max_length=tokenizer.model_max_length,
         | 
| 76 | 
            +
                    truncation=True,
         | 
| 77 | 
            +
                    return_tensors="pt",
         | 
| 78 | 
            +
                )
         | 
| 79 | 
            +
             | 
| 80 | 
            +
             | 
| 81 | 
            +
            def text_encode(text_encoder , tokens):
         | 
| 82 | 
            +
                tokens = tokens.to(text_encoder.device)
         | 
| 83 | 
            +
                if isinstance(text_encoder, BertModel):
         | 
| 84 | 
            +
                    embed = text_encoder(**tokens, return_dict=False)[0]
         | 
| 85 | 
            +
                elif isinstance(text_encoder, CLIPTextModel):
         | 
| 86 | 
            +
                    # embed = text_encoder(**tokens, return_dict=False)[0]
         | 
| 87 | 
            +
                    embed = text_encoder(tokens.input_ids, return_dict=False)[0]
         | 
| 88 | 
            +
                else:
         | 
| 89 | 
            +
                    raise ValueError("text_encoder must be BertModel or CLIPTextModel")
         | 
| 90 | 
            +
                return embed
         | 
| 91 | 
            +
             | 
| 92 | 
            +
            def encode_prompts(
         | 
| 93 | 
            +
                tokenizer,
         | 
| 94 | 
            +
                text_encoder,
         | 
| 95 | 
            +
                prompts: list[str],
         | 
| 96 | 
            +
            ):
         | 
| 97 | 
            +
                # print(f"prompts: {prompts}")
         | 
| 98 | 
            +
                text_tokens = text_tokenize(tokenizer, prompts)
         | 
| 99 | 
            +
                # print(f"text_tokens: {text_tokens}")
         | 
| 100 | 
            +
                text_embeddings = text_encode(text_encoder, text_tokens)
         | 
| 101 | 
            +
                # print(f"text_embeddings: {text_embeddings}")
         | 
| 102 | 
            +
                
         | 
| 103 | 
            +
             | 
| 104 | 
            +
                return text_embeddings
         | 
| 105 | 
            +
             | 
| 106 | 
            +
            def prompt_replace(original, key="{prompt}", prompt=""):
         | 
| 107 | 
            +
                if key not in original:
         | 
| 108 | 
            +
                    return original
         | 
| 109 | 
            +
             | 
| 110 | 
            +
                if isinstance(prompt, list):
         | 
| 111 | 
            +
                    ret =[]
         | 
| 112 | 
            +
                    for p in prompt:
         | 
| 113 | 
            +
                        p = p.replace(".", "")
         | 
| 114 | 
            +
                        r = original.replace(key, p)
         | 
| 115 | 
            +
                        r = r.capitalize()
         | 
| 116 | 
            +
                        ret.append(r)
         | 
| 117 | 
            +
                else:
         | 
| 118 | 
            +
                    prompt = prompt.replace(".", "")
         | 
| 119 | 
            +
                    ret = original.replace(key, prompt)
         | 
| 120 | 
            +
                    ret = ret.capitalize()
         | 
| 121 | 
            +
                return ret
         | 
| 122 | 
            +
             | 
| 123 | 
            +
             | 
| 124 | 
            +
             | 
| 125 | 
            +
            def text_encode_xl(
         | 
| 126 | 
            +
                text_encoder: SDXL_TEXT_ENCODER_TYPE,
         | 
| 127 | 
            +
                tokens: torch.FloatTensor,
         | 
| 128 | 
            +
                num_images_per_prompt: int = 1,
         | 
| 129 | 
            +
            ):
         | 
| 130 | 
            +
                prompt_embeds = text_encoder(
         | 
| 131 | 
            +
                    tokens.to(text_encoder.device), output_hidden_states=True
         | 
| 132 | 
            +
                )
         | 
| 133 | 
            +
                pooled_prompt_embeds = prompt_embeds[0]
         | 
| 134 | 
            +
                prompt_embeds = prompt_embeds.hidden_states[-2]  # always penultimate layer
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                bs_embed, seq_len, _ = prompt_embeds.shape
         | 
| 137 | 
            +
                prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         | 
| 138 | 
            +
                prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                return prompt_embeds, pooled_prompt_embeds
         | 
| 141 | 
            +
             | 
| 142 | 
            +
             | 
| 143 | 
            +
            def encode_prompts_xl(
         | 
| 144 | 
            +
                tokenizers: list[CLIPTokenizer],
         | 
| 145 | 
            +
                text_encoders: list[SDXL_TEXT_ENCODER_TYPE],
         | 
| 146 | 
            +
                prompts: list[str],
         | 
| 147 | 
            +
                num_images_per_prompt: int = 1,
         | 
| 148 | 
            +
            ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
         | 
| 149 | 
            +
                # text_encoder and text_encoder_2's penuultimate layer's output
         | 
| 150 | 
            +
                text_embeds_list = []
         | 
| 151 | 
            +
                pooled_text_embeds = None  # always text_encoder_2's pool
         | 
| 152 | 
            +
             | 
| 153 | 
            +
                for tokenizer, text_encoder in zip(tokenizers, text_encoders):
         | 
| 154 | 
            +
                    text_tokens_input_ids = text_tokenize(tokenizer, prompts)
         | 
| 155 | 
            +
                    text_embeds, pooled_text_embeds = text_encode_xl(
         | 
| 156 | 
            +
                        text_encoder, text_tokens_input_ids, num_images_per_prompt
         | 
| 157 | 
            +
                    )
         | 
| 158 | 
            +
             | 
| 159 | 
            +
                    text_embeds_list.append(text_embeds)
         | 
| 160 | 
            +
             | 
| 161 | 
            +
                bs_embed = pooled_text_embeds.shape[0]
         | 
| 162 | 
            +
                pooled_text_embeds = pooled_text_embeds.repeat(1, num_images_per_prompt).view(
         | 
| 163 | 
            +
                    bs_embed * num_images_per_prompt, -1
         | 
| 164 | 
            +
                )
         | 
| 165 | 
            +
             | 
| 166 | 
            +
                return torch.concat(text_embeds_list, dim=-1), pooled_text_embeds
         | 
| 167 | 
            +
             | 
| 168 | 
            +
             | 
| 169 | 
            +
            def concat_embeddings(
         | 
| 170 | 
            +
                unconditional: torch.FloatTensor,
         | 
| 171 | 
            +
                conditional: torch.FloatTensor,
         | 
| 172 | 
            +
                n_imgs: int,
         | 
| 173 | 
            +
            ):
         | 
| 174 | 
            +
                if conditional.shape[0] == n_imgs and unconditional.shape[0] == 1:
         | 
| 175 | 
            +
                    return torch.cat([unconditional.repeat(n_imgs, 1, 1), conditional], dim=0)
         | 
| 176 | 
            +
                return torch.cat([unconditional, conditional]).repeat_interleave(n_imgs, dim=0)
         | 
| 177 | 
            +
             | 
| 178 | 
            +
             | 
| 179 | 
            +
            def predict_noise(
         | 
| 180 | 
            +
                unet: UNet2DConditionModel,
         | 
| 181 | 
            +
                scheduler: SchedulerMixin,
         | 
| 182 | 
            +
                timestep: int,
         | 
| 183 | 
            +
                latents: torch.FloatTensor,
         | 
| 184 | 
            +
                text_embeddings: torch.FloatTensor,  # uncond な text embed と cond な text embed を結合したもの
         | 
| 185 | 
            +
                guidance_scale=7.5,
         | 
| 186 | 
            +
            ) -> torch.FloatTensor:
         | 
| 187 | 
            +
                # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
         | 
| 188 | 
            +
                latent_model_input = torch.cat([latents] * 2)
         | 
| 189 | 
            +
             | 
| 190 | 
            +
                latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)
         | 
| 191 | 
            +
                # batch_size = latents.shape[0]
         | 
| 192 | 
            +
                # text_embeddings = text_embeddings.repeat_interleave(batch_size, dim=0)
         | 
| 193 | 
            +
                # predict the noise residual
         | 
| 194 | 
            +
                noise_pred = unet(
         | 
| 195 | 
            +
                    latent_model_input,
         | 
| 196 | 
            +
                    timestep,
         | 
| 197 | 
            +
                    encoder_hidden_states=text_embeddings,
         | 
| 198 | 
            +
                ).sample
         | 
| 199 | 
            +
             | 
| 200 | 
            +
                # perform guidance
         | 
| 201 | 
            +
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
         | 
| 202 | 
            +
                guided_target = noise_pred_uncond + guidance_scale * (
         | 
| 203 | 
            +
                    noise_pred_text - noise_pred_uncond
         | 
| 204 | 
            +
                )
         | 
| 205 | 
            +
             | 
| 206 | 
            +
                return guided_target
         | 
| 207 | 
            +
             | 
| 208 | 
            +
             | 
| 209 | 
            +
             | 
| 210 | 
            +
            @torch.no_grad()
         | 
| 211 | 
            +
            def diffusion(
         | 
| 212 | 
            +
                unet: UNet2DConditionModel,
         | 
| 213 | 
            +
                scheduler: SchedulerMixin,
         | 
| 214 | 
            +
                latents: torch.FloatTensor,
         | 
| 215 | 
            +
                text_embeddings: torch.FloatTensor,
         | 
| 216 | 
            +
                total_timesteps: int = 1000,
         | 
| 217 | 
            +
                start_timesteps=0,
         | 
| 218 | 
            +
                **kwargs,
         | 
| 219 | 
            +
            ):
         | 
| 220 | 
            +
                # latents_steps = []
         | 
| 221 | 
            +
             | 
| 222 | 
            +
                for timestep in scheduler.timesteps[start_timesteps:total_timesteps]:
         | 
| 223 | 
            +
                    noise_pred = predict_noise(
         | 
| 224 | 
            +
                        unet, scheduler, timestep, latents, text_embeddings, **kwargs
         | 
| 225 | 
            +
                    )
         | 
| 226 | 
            +
             | 
| 227 | 
            +
                    # compute the previous noisy sample x_t -> x_t-1
         | 
| 228 | 
            +
                    latents = scheduler.step(noise_pred, timestep, latents).prev_sample
         | 
| 229 | 
            +
             | 
| 230 | 
            +
                # return latents_steps
         | 
| 231 | 
            +
                return latents
         | 
| 232 | 
            +
             | 
| 233 | 
            +
            @torch.no_grad()
         | 
| 234 | 
            +
            def get_noisy_image(
         | 
| 235 | 
            +
                img,
         | 
| 236 | 
            +
                vae,
         | 
| 237 | 
            +
                generator,
         | 
| 238 | 
            +
                unet: UNet2DConditionModel,
         | 
| 239 | 
            +
                scheduler: SchedulerMixin,
         | 
| 240 | 
            +
                total_timesteps: int = 1000,
         | 
| 241 | 
            +
                start_timesteps=0,
         | 
| 242 | 
            +
                
         | 
| 243 | 
            +
                **kwargs,
         | 
| 244 | 
            +
            ):
         | 
| 245 | 
            +
                # latents_steps = []
         | 
| 246 | 
            +
                vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
         | 
| 247 | 
            +
                image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
         | 
| 248 | 
            +
             | 
| 249 | 
            +
                image = img
         | 
| 250 | 
            +
                # im_orig = image
         | 
| 251 | 
            +
                device = vae.device
         | 
| 252 | 
            +
                image = image_processor.preprocess(image).to(device)
         | 
| 253 | 
            +
             | 
| 254 | 
            +
                init_latents = vae.encode(image).latent_dist.sample(None)
         | 
| 255 | 
            +
                init_latents = vae.config.scaling_factor * init_latents
         | 
| 256 | 
            +
             | 
| 257 | 
            +
                init_latents = torch.cat([init_latents], dim=0)
         | 
| 258 | 
            +
             | 
| 259 | 
            +
                shape = init_latents.shape
         | 
| 260 | 
            +
             | 
| 261 | 
            +
                noise = randn_tensor(shape, generator=generator, device=device)
         | 
| 262 | 
            +
             | 
| 263 | 
            +
                time_ = total_timesteps
         | 
| 264 | 
            +
                timestep = scheduler.timesteps[time_:time_+1]
         | 
| 265 | 
            +
                # get latents
         | 
| 266 | 
            +
                noised_latents = scheduler.add_noise(init_latents, noise, timestep)
         | 
| 267 | 
            +
                
         | 
| 268 | 
            +
                return noised_latents, noise, init_latents
         | 
| 269 | 
            +
             | 
| 270 | 
            +
            def subtract_noise(
         | 
| 271 | 
            +
                    latent: torch.FloatTensor,
         | 
| 272 | 
            +
                    noise: torch.FloatTensor,
         | 
| 273 | 
            +
                    timesteps: torch.IntTensor,
         | 
| 274 | 
            +
                    scheduler: SchedulerMixin,
         | 
| 275 | 
            +
            ) -> torch.FloatTensor:
         | 
| 276 | 
            +
                # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
         | 
| 277 | 
            +
                # Move the self.alphas_cumprod to device to avoid redundant CPU to GPU data movement
         | 
| 278 | 
            +
                # for the subsequent add_noise calls
         | 
| 279 | 
            +
                scheduler.alphas_cumprod = scheduler.alphas_cumprod.to(device=latent.device)
         | 
| 280 | 
            +
                alphas_cumprod = scheduler.alphas_cumprod.to(dtype=latent.dtype)
         | 
| 281 | 
            +
                timesteps = timesteps.to(latent.device)
         | 
| 282 | 
            +
             | 
| 283 | 
            +
                sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
         | 
| 284 | 
            +
                sqrt_alpha_prod = sqrt_alpha_prod.flatten()
         | 
| 285 | 
            +
                while len(sqrt_alpha_prod.shape) < len(latent.shape):
         | 
| 286 | 
            +
                    sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
         | 
| 287 | 
            +
             | 
| 288 | 
            +
                sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
         | 
| 289 | 
            +
                sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
         | 
| 290 | 
            +
                while len(sqrt_one_minus_alpha_prod.shape) < len(latent.shape):
         | 
| 291 | 
            +
                    sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
         | 
| 292 | 
            +
             | 
| 293 | 
            +
                denoised_latent =  (latent - sqrt_one_minus_alpha_prod * noise) / sqrt_alpha_prod
         | 
| 294 | 
            +
                return denoised_latent
         | 
| 295 | 
            +
            def get_denoised_image(
         | 
| 296 | 
            +
                    latents: torch.FloatTensor,
         | 
| 297 | 
            +
                    noise_pred: torch.FloatTensor,
         | 
| 298 | 
            +
                    timestep: int,
         | 
| 299 | 
            +
                    # total_timesteps: int,
         | 
| 300 | 
            +
                    scheduler: SchedulerMixin,
         | 
| 301 | 
            +
                    vae: VaeImageProcessor,
         | 
| 302 | 
            +
            ):
         | 
| 303 | 
            +
                denoised_latents = subtract_noise(latents, noise_pred, timestep, scheduler)
         | 
| 304 | 
            +
                denoised_latents = denoised_latents / vae.config.scaling_factor # 0.18215
         | 
| 305 | 
            +
                denoised_img = vae.decode(denoised_latents).sample
         | 
| 306 | 
            +
                # denoised_img = denoised_img.clamp(-1,1)
         | 
| 307 | 
            +
                return denoised_img
         | 
| 308 | 
            +
             | 
| 309 | 
            +
             | 
| 310 | 
            +
            def rescale_noise_cfg(
         | 
| 311 | 
            +
                noise_cfg: torch.FloatTensor, noise_pred_text, guidance_rescale=0.0
         | 
| 312 | 
            +
            ):
         | 
| 313 | 
            +
             | 
| 314 | 
            +
                std_text = noise_pred_text.std(
         | 
| 315 | 
            +
                    dim=list(range(1, noise_pred_text.ndim)), keepdim=True
         | 
| 316 | 
            +
                )
         | 
| 317 | 
            +
                std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
         | 
| 318 | 
            +
                # rescale the results from guidance (fixes overexposure)
         | 
| 319 | 
            +
                noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
         | 
| 320 | 
            +
                # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
         | 
| 321 | 
            +
                noise_cfg = (
         | 
| 322 | 
            +
                    guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
         | 
| 323 | 
            +
                )
         | 
| 324 | 
            +
             | 
| 325 | 
            +
                return noise_cfg
         | 
| 326 | 
            +
             | 
| 327 | 
            +
             | 
| 328 | 
            +
            def predict_noise_xl(
         | 
| 329 | 
            +
                unet: UNet2DConditionModel,
         | 
| 330 | 
            +
                scheduler: SchedulerMixin,
         | 
| 331 | 
            +
                timestep: int,
         | 
| 332 | 
            +
                latents: torch.FloatTensor,
         | 
| 333 | 
            +
                text_embeddings: torch.FloatTensor,  # uncond な text embed と cond な text embed を結合したもの
         | 
| 334 | 
            +
                add_text_embeddings: torch.FloatTensor,  # pooled なやつ
         | 
| 335 | 
            +
                add_time_ids: torch.FloatTensor,
         | 
| 336 | 
            +
                guidance_scale=7.5,
         | 
| 337 | 
            +
                guidance_rescale=0.7,
         | 
| 338 | 
            +
            ) -> torch.FloatTensor:
         | 
| 339 | 
            +
                # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
         | 
| 340 | 
            +
                latent_model_input = torch.cat([latents] * 2)
         | 
| 341 | 
            +
             | 
| 342 | 
            +
                latent_model_input = scheduler.scale_model_input(latent_model_input, timestep)
         | 
| 343 | 
            +
             | 
| 344 | 
            +
                added_cond_kwargs = {
         | 
| 345 | 
            +
                    "text_embeds": add_text_embeddings,
         | 
| 346 | 
            +
                    "time_ids": add_time_ids,
         | 
| 347 | 
            +
                }
         | 
| 348 | 
            +
             | 
| 349 | 
            +
                # predict the noise residual
         | 
| 350 | 
            +
                noise_pred = unet(
         | 
| 351 | 
            +
                    latent_model_input,
         | 
| 352 | 
            +
                    timestep,
         | 
| 353 | 
            +
                    encoder_hidden_states=text_embeddings,
         | 
| 354 | 
            +
                    added_cond_kwargs=added_cond_kwargs,
         | 
| 355 | 
            +
                ).sample
         | 
| 356 | 
            +
             | 
| 357 | 
            +
                # perform guidance
         | 
| 358 | 
            +
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
         | 
| 359 | 
            +
                guided_target = noise_pred_uncond + guidance_scale * (
         | 
| 360 | 
            +
                    noise_pred_text - noise_pred_uncond
         | 
| 361 | 
            +
                )
         | 
| 362 | 
            +
             | 
| 363 | 
            +
                noise_pred = rescale_noise_cfg(
         | 
| 364 | 
            +
                    noise_pred, noise_pred_text, guidance_rescale=guidance_rescale
         | 
| 365 | 
            +
                )
         | 
| 366 | 
            +
             | 
| 367 | 
            +
                return guided_target
         | 
| 368 | 
            +
             | 
| 369 | 
            +
             | 
| 370 | 
            +
            @torch.no_grad()
         | 
| 371 | 
            +
            def diffusion_xl(
         | 
| 372 | 
            +
                unet: UNet2DConditionModel,
         | 
| 373 | 
            +
                scheduler: SchedulerMixin,
         | 
| 374 | 
            +
                latents: torch.FloatTensor,
         | 
| 375 | 
            +
                text_embeddings: tuple[torch.FloatTensor, torch.FloatTensor],
         | 
| 376 | 
            +
                add_text_embeddings: torch.FloatTensor,
         | 
| 377 | 
            +
                add_time_ids: torch.FloatTensor,
         | 
| 378 | 
            +
                guidance_scale: float = 1.0,
         | 
| 379 | 
            +
                total_timesteps: int = 1000,
         | 
| 380 | 
            +
                start_timesteps=0,
         | 
| 381 | 
            +
            ):
         | 
| 382 | 
            +
                # latents_steps = []
         | 
| 383 | 
            +
             | 
| 384 | 
            +
                for timestep in tqdm(scheduler.timesteps[start_timesteps:total_timesteps]):
         | 
| 385 | 
            +
                    noise_pred = predict_noise_xl(
         | 
| 386 | 
            +
                        unet,
         | 
| 387 | 
            +
                        scheduler,
         | 
| 388 | 
            +
                        timestep,
         | 
| 389 | 
            +
                        latents,
         | 
| 390 | 
            +
                        text_embeddings,
         | 
| 391 | 
            +
                        add_text_embeddings,
         | 
| 392 | 
            +
                        add_time_ids,
         | 
| 393 | 
            +
                        guidance_scale=guidance_scale,
         | 
| 394 | 
            +
                        guidance_rescale=0.7,
         | 
| 395 | 
            +
                    )
         | 
| 396 | 
            +
             | 
| 397 | 
            +
                    # compute the previous noisy sample x_t -> x_t-1
         | 
| 398 | 
            +
                    latents = scheduler.step(noise_pred, timestep, latents).prev_sample
         | 
| 399 | 
            +
             | 
| 400 | 
            +
                # return latents_steps
         | 
| 401 | 
            +
                return latents
         | 
| 402 | 
            +
             | 
| 403 | 
            +
             | 
| 404 | 
            +
            # for XL
         | 
| 405 | 
            +
            def get_add_time_ids(
         | 
| 406 | 
            +
                height: int,
         | 
| 407 | 
            +
                width: int,
         | 
| 408 | 
            +
                dynamic_crops: bool = False,
         | 
| 409 | 
            +
                dtype: torch.dtype = torch.float32,
         | 
| 410 | 
            +
            ):
         | 
| 411 | 
            +
                if dynamic_crops:
         | 
| 412 | 
            +
                    # random float scale between 1 and 3
         | 
| 413 | 
            +
                    random_scale = torch.rand(1).item() * 2 + 1
         | 
| 414 | 
            +
                    original_size = (int(height * random_scale), int(width * random_scale))
         | 
| 415 | 
            +
                    # random position
         | 
| 416 | 
            +
                    crops_coords_top_left = (
         | 
| 417 | 
            +
                        torch.randint(0, original_size[0] - height, (1,)).item(),
         | 
| 418 | 
            +
                        torch.randint(0, original_size[1] - width, (1,)).item(),
         | 
| 419 | 
            +
                    )
         | 
| 420 | 
            +
                    target_size = (height, width)
         | 
| 421 | 
            +
                else:
         | 
| 422 | 
            +
                    original_size = (height, width)
         | 
| 423 | 
            +
                    crops_coords_top_left = (0, 0)
         | 
| 424 | 
            +
                    target_size = (height, width)
         | 
| 425 | 
            +
             | 
| 426 | 
            +
                # this is expected as 6
         | 
| 427 | 
            +
                add_time_ids = list(original_size + crops_coords_top_left + target_size)
         | 
| 428 | 
            +
             | 
| 429 | 
            +
                # this is expected as 2816
         | 
| 430 | 
            +
                passed_add_embed_dim = (
         | 
| 431 | 
            +
                    UNET_ATTENTION_TIME_EMBED_DIM * len(add_time_ids)  # 256 * 6
         | 
| 432 | 
            +
                    + TEXT_ENCODER_2_PROJECTION_DIM  # + 1280
         | 
| 433 | 
            +
                )
         | 
| 434 | 
            +
                if passed_add_embed_dim != UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM:
         | 
| 435 | 
            +
                    raise ValueError(
         | 
| 436 | 
            +
                        f"Model expects an added time embedding vector of length {UNET_PROJECTION_CLASS_EMBEDDING_INPUT_DIM}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
         | 
| 437 | 
            +
                    )
         | 
| 438 | 
            +
             | 
| 439 | 
            +
                add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
         | 
| 440 | 
            +
                return add_time_ids
         | 
| 441 | 
            +
             | 
| 442 | 
            +
             | 
| 443 | 
            +
            def get_optimizer(name: str):
         | 
| 444 | 
            +
                name = name.lower()
         | 
| 445 | 
            +
             | 
| 446 | 
            +
                if name.startswith("dadapt"):
         | 
| 447 | 
            +
                    import dadaptation
         | 
| 448 | 
            +
             | 
| 449 | 
            +
                    if name == "dadaptadam":
         | 
| 450 | 
            +
                        return dadaptation.DAdaptAdam
         | 
| 451 | 
            +
                    elif name == "dadaptlion":
         | 
| 452 | 
            +
                        return dadaptation.DAdaptLion
         | 
| 453 | 
            +
                    else:
         | 
| 454 | 
            +
                        raise ValueError("DAdapt optimizer must be dadaptadam or dadaptlion")
         | 
| 455 | 
            +
             | 
| 456 | 
            +
                elif name.endswith("8bit"):
         | 
| 457 | 
            +
                    import bitsandbytes as bnb
         | 
| 458 | 
            +
             | 
| 459 | 
            +
                    if name == "adam8bit":
         | 
| 460 | 
            +
                        return bnb.optim.Adam8bit
         | 
| 461 | 
            +
                    elif name == "lion8bit":
         | 
| 462 | 
            +
                        return bnb.optim.Lion8bit
         | 
| 463 | 
            +
                    else:
         | 
| 464 | 
            +
                        raise ValueError("8bit optimizer must be adam8bit or lion8bit")
         | 
| 465 | 
            +
             | 
| 466 | 
            +
                else:
         | 
| 467 | 
            +
                    if name == "adam":
         | 
| 468 | 
            +
                        return torch.optim.Adam
         | 
| 469 | 
            +
                    elif name == "adamw":
         | 
| 470 | 
            +
                        return torch.optim.AdamW
         | 
| 471 | 
            +
                    elif name == "lion":
         | 
| 472 | 
            +
                        from lion_pytorch import Lion
         | 
| 473 | 
            +
             | 
| 474 | 
            +
                        return Lion
         | 
| 475 | 
            +
                    elif name == "prodigy":
         | 
| 476 | 
            +
                        import prodigyopt
         | 
| 477 | 
            +
                        
         | 
| 478 | 
            +
                        return prodigyopt.Prodigy
         | 
| 479 | 
            +
                    else:
         | 
| 480 | 
            +
                        raise ValueError("Optimizer must be adam, adamw, lion or Prodigy")
         | 
| 481 | 
            +
             | 
| 482 | 
            +
             | 
| 483 | 
            +
            def get_lr_scheduler(
         | 
| 484 | 
            +
                name: Optional[str],
         | 
| 485 | 
            +
                optimizer: torch.optim.Optimizer,
         | 
| 486 | 
            +
                max_iterations: Optional[int],
         | 
| 487 | 
            +
                lr_min: Optional[float],
         | 
| 488 | 
            +
                **kwargs,
         | 
| 489 | 
            +
            ):
         | 
| 490 | 
            +
                if name == "cosine":
         | 
| 491 | 
            +
                    return torch.optim.lr_scheduler.CosineAnnealingLR(
         | 
| 492 | 
            +
                        optimizer, T_max=max_iterations, eta_min=lr_min, **kwargs
         | 
| 493 | 
            +
                    )
         | 
| 494 | 
            +
                elif name == "cosine_with_restarts":
         | 
| 495 | 
            +
                    return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
         | 
| 496 | 
            +
                        optimizer, T_0=max_iterations // 10, T_mult=2, eta_min=lr_min, **kwargs
         | 
| 497 | 
            +
                    )
         | 
| 498 | 
            +
                elif name == "step":
         | 
| 499 | 
            +
                    return torch.optim.lr_scheduler.StepLR(
         | 
| 500 | 
            +
                        optimizer, step_size=max_iterations // 100, gamma=0.999, **kwargs
         | 
| 501 | 
            +
                    )
         | 
| 502 | 
            +
                elif name == "constant":
         | 
| 503 | 
            +
                    return torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1, **kwargs)
         | 
| 504 | 
            +
                elif name == "linear":
         | 
| 505 | 
            +
                    return torch.optim.lr_scheduler.LinearLR(
         | 
| 506 | 
            +
                        optimizer, factor=0.5, total_iters=max_iterations // 100, **kwargs
         | 
| 507 | 
            +
                    )
         | 
| 508 | 
            +
                else:
         | 
| 509 | 
            +
                    raise ValueError(
         | 
| 510 | 
            +
                        "Scheduler must be cosine, cosine_with_restarts, step, linear or constant"
         | 
| 511 | 
            +
                    )
         | 
| 512 | 
            +
             | 
| 513 | 
            +
             | 
| 514 | 
            +
            def get_random_resolution_in_bucket(bucket_resolution: int = 512) -> tuple[int, int]:
         | 
| 515 | 
            +
                max_resolution = bucket_resolution
         | 
| 516 | 
            +
                min_resolution = bucket_resolution // 2
         | 
| 517 | 
            +
             | 
| 518 | 
            +
                step = 64
         | 
| 519 | 
            +
             | 
| 520 | 
            +
                min_step = min_resolution // step
         | 
| 521 | 
            +
                max_step = max_resolution // step
         | 
| 522 | 
            +
             | 
| 523 | 
            +
                height = torch.randint(min_step, max_step, (1,)).item() * step
         | 
| 524 | 
            +
                width = torch.randint(min_step, max_step, (1,)).item() * step
         | 
| 525 | 
            +
             | 
| 526 | 
            +
                return height, width
         |