{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c524f796-e657-4a59-abcf-540531a38995",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# `args` is never defined in this notebook; it is expected to be created by this script (confirm in get_parser.py).\n",
    "%run get_parser.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c1cf01e-8229-4d28-bcb2-01c07fa641c2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import requests\n",
    "\n",
    "# Define URLs and file paths\n",
    "files_to_download = {\n",
    "    \"https://huggingface.co/mahmed10/CAM-Seg/resolve/main/pretrained_models/vae/modelf16.ckpt\":\n",
    "        \"pretrained_models/vae/modelf16.ckpt\",\n",
    "    \"https://huggingface.co/mahmed10/CAM-Seg/resolve/main/pretrained_models/mar/city768.16.pth\":\n",
    "        \"pretrained_models/mar/city768.16.pth\"\n",
    "}\n",
    "\n",
    "# Seconds to wait for the connection and for each chunk of data, so a stalled\n",
    "# server cannot hang the notebook forever (this is not a cap on total download time).\n",
    "DOWNLOAD_TIMEOUT_S = 60\n",
    "\n",
    "for url, path in files_to_download.items():\n",
    "    os.makedirs(os.path.dirname(path), exist_ok=True)\n",
    "\n",
    "    if os.path.exists(path):\n",
    "        print(f\"File already exists: {path} — skipping download.\")\n",
    "        continue\n",
    "\n",
    "    print(f\"Downloading from {url}...\")\n",
    "    # Stream into a temporary sibling file and rename it only once the transfer has\n",
    "    # finished. Writing straight to `path` would leave a truncated checkpoint behind\n",
    "    # after an interrupted download, and the existence check above would then skip it.\n",
    "    partial_path = path + \".part\"\n",
    "    with requests.get(url, stream=True, timeout=DOWNLOAD_TIMEOUT_S) as response:\n",
    "        if response.status_code == 200:\n",
    "            try:\n",
    "                with open(partial_path, 'wb') as f:\n",
    "                    for chunk in response.iter_content(chunk_size=8192):\n",
    "                        f.write(chunk)\n",
    "                os.replace(partial_path, path)\n",
    "            finally:\n",
    "                # Only still present if the transfer failed part-way through.\n",
    "                if os.path.exists(partial_path):\n",
    "                    os.remove(partial_path)\n",
    "            print(f\"Saved to {path}\")\n",
    "        else:\n",
    "            print(f\"Failed to download from {url}, status code {response.status_code}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3a7ac93b-1cbc-45f3-8ec5-8e8257a39786",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "from PIL import Image\n",
    "import yaml\n",
    "import math\n",
    "\n",
    "import torch\n",
    "import torch.backends.cudnn as cudnn\n",
    "import torchvision.transforms as transforms\n",
    "\n",
    "from data import cityscapes, acdc, semantickitti, cadedgetune\n",
    "import util.misc as misc\n",
    "\n",
    "from models.vae import AutoencoderKL\n",
    "from models import mar"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e2bde6fd-9b39-40fd-8d4d-d0a5f9c8217a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def mask_by_order(mask_len, order, bsz, seq_len):\n",
    "    \"\"\"Return a boolean (bsz, seq_len) mask that is True at the first `mask_len` positions listed in each row of `order`.\n",
    "\n",
    "    `mask_len` must be a tensor (``.long()`` is called on it) and `order` an integer index\n",
    "    tensor with one row per batch element. The mask is allocated on `order`'s device.\n",
    "    \"\"\"\n",
    "    masking = torch.zeros(bsz, seq_len, device=order.device)\n",
    "    masking = torch.scatter(\n",
    "        masking, dim=-1, index=order[:, :mask_len.long()],\n",
    "        src=torch.ones(bsz, seq_len, device=order.device),\n",
    "    ).bool()\n",
    "    return masking\n",
    "\n",
    "\n",
    "def fast_hist(pred, label, n):\n",
    "    \"\"\"Build an n x n confusion matrix from flat integer arrays of predictions and labels.\n",
    "\n",
    "    Entry [i, j] counts the pixels whose label is i and whose prediction is j.\n",
    "    Pixels whose label lies outside [0, n) are ignored; predictions are not range-checked.\n",
    "    \"\"\"\n",
    "    k = (label >= 0) & (label < n)\n",
    "    bin_count = np.bincount(\n",
    "        n * label[k].astype(int) + pred[k], minlength=n ** 2)\n",
    "    return bin_count[:n ** 2].reshape(n, n)\n",
    "\n",
    "\n",
    "# One RGB triple per class index, scaled to [0, 1]. Index 0 is (0, 0, 0); the\n",
    "# remaining 19 rows line up with the class names printed by the evaluation cell.\n",
    "color_pallete = np.round(np.array([\n",
    "    0, 0, 0,\n",
    "    128, 64, 128,\n",
    "    244, 35, 232,\n",
    "    70, 70, 70,\n",
    "    102, 102, 156,\n",
    "    190, 153, 153,\n",
    "    153, 153, 153,\n",
    "    250, 170, 30,\n",
    "    220, 220, 0,\n",
    "    107, 142, 35,\n",
    "    152, 251, 152,\n",
    "    0, 130, 180,\n",
    "    220, 20, 60,\n",
    "    255, 0, 0,\n",
    "    0, 0, 142,\n",
    "    0, 0, 70,\n",
    "    0, 60, 100,\n",
    "    0, 80, 100,\n",
    "    0, 0, 230,\n",
    "    119, 11, 32,\n",
    "    ])/255.0, 4)\n",
    "\n",
    "color_pallete = color_pallete.reshape(-1, 3)\n",
    "\n",
    "# Number of palette entries (20): index 0 plus the 19 reported classes.\n",
    "NUM_CLASSES = len(color_pallete)\n",
    "\n",
    "\n",
    "def colors_to_class_ids(image_tensor):\n",
    "    \"\"\"Convert a (3, H, W) image tensor normalized with mean 0.5 / std 0.5 into an (H, W) array of class indices.\n",
    "\n",
    "    Each pixel is assigned the index of the closest `color_pallete` row (Euclidean distance in RGB).\n",
    "    \"\"\"\n",
    "    # Undo Normalize([0.5] * 3, [0.5] * 3) so values are comparable with the [0, 1] palette.\n",
    "    image_np = (image_tensor * 0.5 + 0.5).permute(1, 2, 0).detach().cpu().numpy()\n",
    "    height, width, _ = image_np.shape\n",
    "    pixels = image_np.reshape(-1, 3)\n",
    "    distances = np.linalg.norm(pixels[:, None, :] - color_pallete[None, :, :], axis=2)\n",
    "    return np.argmin(distances, axis=1).reshape(height, width)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c189ac7b-ccff-4745-af56-460ec88770b4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# NOTE: `args.device` is overridden here; evaluation is pinned to the first GPU.\n",
    "device = torch.device('cuda:0')\n",
    "# The evaluation loop writes the sampled latent into tokens[0], so it handles one image per step.\n",
    "args.batch_size = 1\n",
    "\n",
    "# fix the seed for reproducibility\n",
    "seed = args.seed + misc.get_rank()\n",
    "torch.manual_seed(seed)\n",
    "np.random.seed(seed)\n",
    "\n",
    "cudnn.benchmark = True\n",
    "\n",
    "num_tasks = misc.get_world_size()\n",
    "global_rank = misc.get_rank()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "28d13453-a3ac-4d2e-8906-0c179e85c2f9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Despite the `*_train` names, this cell builds the validation split used for evaluation.\n",
    "transform_train = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])\n",
    "])\n",
    "\n",
    "# Uncomment exactly one dataset line to choose what is evaluated.\n",
    "dataset_train = cityscapes.CityScapes('dataset/CityScapes/vallist.txt', data_set= 'val', transform=transform_train, seed=args.seed, img_size=args.img_size)\n",
    "# dataset_train = acdc.ACDC('dataset/ACDC/vallist_fog.txt', data_set= 'val', transform=transform_train, seed=args.seed, img_size=args.img_size)\n",
    "# dataset_train = semantickitti.SemanticKITTI('dataset/SemanticKitti/vallist.txt', data_set= 'val', transform=transform_train, seed=args.seed, img_size=args.img_size)\n",
    "# dataset_train = cadedgetune.CADEdgeTune('dataset/CADEdgeTune/all.txt', data_set= 'val', transform=transform_train, seed=args.seed, img_size=args.img_size)\n",
    "\n",
    "\n",
    "# Single replica, no shuffling: every sample is visited once, in dataset order.\n",
    "sampler_train = torch.utils.data.DistributedSampler(dataset_train, num_replicas=1, rank=0, shuffle=False)\n",
    "\n",
    "data_loader_train = torch.utils.data.DataLoader(\n",
    "    dataset_train, sampler=sampler_train,\n",
    "    batch_size=args.batch_size,\n",
    "    num_workers=args.num_workers,\n",
    "    pin_memory=args.pin_mem,\n",
    "    drop_last=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2e22d231-02db-4586-b489-01a97314aed9",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Frozen VAE used to move images and label maps to and from latent space.\n",
    "vae = AutoencoderKL(\n",
    "    ddconfig=args.ddconfig,\n",
    "    embed_dim=args.vae_embed_dim,\n",
    "    ckpt_path=args.vae_path\n",
    ").to(device).eval()\n",
    "\n",
    "for param in vae.parameters():\n",
    "    param.requires_grad = False\n",
    "\n",
    "model = mar.mar_base(\n",
    "    img_size=args.img_size,\n",
    "    vae_stride=args.vae_stride,\n",
    "    patch_size=args.patch_size,\n",
    "    vae_embed_dim=args.vae_embed_dim,\n",
    "    mask_ratio_min=args.mask_ratio_min,\n",
    "    label_drop_prob=args.label_drop_prob,\n",
    "    attn_dropout=args.attn_dropout,\n",
    "    proj_dropout=args.proj_dropout,\n",
    "    buffer_size=args.buffer_size,\n",
    "    diffloss_d=args.diffloss_d,\n",
    "    diffloss_w=args.diffloss_w,\n",
    "    num_sampling_steps=args.num_sampling_steps,\n",
    "    diffusion_batch_mul=args.diffusion_batch_mul,\n",
    "    grad_checkpointing=args.grad_checkpointing,\n",
    ")\n",
    "\n",
    "n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "print(\"Number of trainable parameters: {}M\".format(n_params / 1e6))\n",
    "\n",
    "\n",
    "# SECURITY NOTE: torch.load unpickles the file, which can execute arbitrary code.\n",
    "# Only load checkpoints obtained from a source you trust.\n",
    "checkpoint = torch.load(args.ckpt_path, map_location='cpu')\n",
    "model.load_state_dict(checkpoint['model'])\n",
    "model.to(device)\n",
    "\n",
    "eff_batch_size = args.batch_size * misc.get_world_size()\n",
    "\n",
    "print(\"effective batch size: %d\" % eff_batch_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "4c83c0eb-35a5-4241-b869-d52eb6cd31e0",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training Progress: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [13:11<00:00, 1.58s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "road : 98.06\n",
      "sidewalk : 86.32\n",
      "building : 89.23\n",
      "wall : 47.44\n",
      "fence : 43.78\n",
      "pole : 60.14\n",
      "tlight : 63.16\n",
      "tsign : 82.48\n",
      "vtation : 92.72\n",
      "terrain : 80.45\n",
      "sky : 95.99\n",
      "person : 70.83\n",
      "rider : 64.25\n",
      "car : 94.06\n",
      "truck : 44.90\n",
      "bus : 66.81\n",
      "train : 44.04\n",
      "motorcycle : 47.34\n",
      "bicycle : 62.50\n",
      "Avg Pre : 70.24\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Latents are multiplied by this factor after VAE encoding and divided by it again before decoding.\n",
    "LATENT_SCALE = 0.2325\n",
    "# Sampling settings passed to model.diffloss.sample.\n",
    "cfg_iter = 1.0\n",
    "temperature = 1.0\n",
    "\n",
    "assert args.batch_size == 1, \"the sampling loop writes into tokens[0] and only supports batch_size == 1\"\n",
    "\n",
    "# Confusion matrix accumulated over the whole dataset: rows are ground truth, columns are predictions.\n",
    "cm = np.zeros((NUM_CLASSES, NUM_CLASSES), dtype=np.int64)\n",
    "\n",
    "model.eval()\n",
    "for data_iter_step, (samples, labels, path) in enumerate(tqdm(data_loader_train, desc=\"Training Progress\")):\n",
    "    samples = samples.to(device, non_blocking=True)\n",
    "    labels = labels.to(device, non_blocking=True)\n",
    "\n",
    "    # Nothing in evaluation needs gradients, so the whole step runs without autograd bookkeeping.\n",
    "    with torch.no_grad():\n",
    "        posterior_x = vae.encode(samples)\n",
    "        posterior_y = vae.encode(labels)\n",
    "        x = posterior_x.sample().mul_(LATENT_SCALE)\n",
    "        # NOTE(review): `y` is not used for prediction. The call is kept because sample()\n",
    "        # presumably draws random noise, so dropping it would shift the random stream and\n",
    "        # change the sampled segmentations relative to the recorded run. Confirm before removing.\n",
    "        y = posterior_y.sample().mul_(LATENT_SCALE)\n",
    "        x = model.patchify(x)\n",
    "\n",
    "        # Mask over the concatenated sequence: first seq_len entries 0, last seq_len entries 1.\n",
    "        mask_actual = torch.cat([\n",
    "            torch.zeros(args.batch_size, model.seq_len, device=device),\n",
    "            torch.ones(args.batch_size, model.seq_len, device=device),\n",
    "        ], dim=1)\n",
    "        tokens = torch.zeros(args.batch_size, model.seq_len, model.token_embed_dim, device=device)\n",
    "\n",
    "        x1 = model.forward_mae_encoder(x, mask_actual, tokens)\n",
    "        z = model.forward_mae_decoder(x1, mask_actual)\n",
    "        z = z[0]\n",
    "        sampled_token_latent = model.diffloss.sample(z, temperature, cfg_iter)\n",
    "\n",
    "        # Keep the second half of the sampled sequence (the positions that were masked with ones).\n",
    "        tokens[0] = sampled_token_latent[model.seq_len:]\n",
    "        tokens = model.unpatchify(tokens)\n",
    "\n",
    "        sampled_images = vae.decode(tokens / LATENT_SCALE)\n",
    "\n",
    "    # Snap both colour images to palette indices and add this image's counts to the confusion matrix.\n",
    "    gt = colors_to_class_ids(labels[0])\n",
    "    output = colors_to_class_ids(sampled_images[0])\n",
    "    cm += fast_hist(output.reshape(-1), gt.reshape(-1), NUM_CLASSES)\n",
    "\n",
    "# Per-class precision: correct pixels divided by all pixels predicted as that class.\n",
    "# Index 0 (the black palette entry) is excluded from both rows and columns.\n",
    "epsilon = 1e-10\n",
    "class_precision = np.diag(cm[1:,1:]) / (np.sum(cm[1:,1:], axis=0) + epsilon)\n",
    "class_names = ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'tlight', 'tsign',\n",
    "               'vtation', 'terrain', 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train',\n",
    "               'motorcycle', 'bicycle']\n",
    "\n",
    "for name, precision in zip(class_names, class_precision):\n",
    "    print(f\"{name:<12}: {precision*100:6.2f}\")\n",
    "average_precision = np.mean(class_precision)\n",
    "print(f\"{'Avg Pre':<12}: {average_precision*100:6.2f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}