In [1]:
# Execute get_parser.py in this kernel; IPython copies the names it defines into
# the notebook namespace. Presumably this is where `args` comes from — later cells
# use it but never assign it in this notebook. Confirm in get_parser.py.
%run get_parser.py

In [None]:
import os
import requests

# Pretrained checkpoints needed by this notebook: Hugging Face URL -> local path.
files_to_download = {
    "https://huggingface.co/mahmed10/CAM-Seg/resolve/main/pretrained_models/vae/modelf16.ckpt":
        "pretrained_models/vae/modelf16.ckpt",
    "https://huggingface.co/mahmed10/CAM-Seg/resolve/main/pretrained_models/mar/city768.16.pth":
        "pretrained_models/mar/city768.16.pth",
}

# Seconds requests waits to connect / for the server to send data before raising.
DOWNLOAD_TIMEOUT = 60

for url, path in files_to_download.items():
    os.makedirs(os.path.dirname(path), exist_ok=True)

    # Only a completely downloaded file is ever moved to `path` (see os.replace
    # below), so its presence means an earlier download finished.
    if os.path.exists(path):
        print(f"File already exists: {path} — skipping download.")
        continue

    print(f"Downloading from {url}...")
    tmp_path = path + ".part"
    # The context manager releases the streamed connection on every exit path;
    # the timeout stops a stalled server from hanging the cell forever.
    with requests.get(url, stream=True, timeout=DOWNLOAD_TIMEOUT) as response:
        if response.status_code != 200:
            print(f"Failed to download from {url}, status code {response.status_code}")
            continue
        try:
            with open(tmp_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        except BaseException:
            # Never leave a truncated file behind (including on kernel interrupt).
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
    # Atomic rename: `path` is either absent or complete, never partial.
    os.replace(tmp_path, path)
    print(f"Saved to {path}")

In [3]:
import numpy as np
from tqdm import tqdm
from PIL import Image
import yaml
import math

import torch
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms

from data import cityscapes, acdc, semantickitti, cadedgetune
import util.misc as misc

from models.vae import AutoencoderKL
from models import mar

In [4]:
def mask_by_order(mask_len, order, bsz, seq_len):
    """Build a boolean mask that is True at the first ``mask_len`` positions of ``order``.

    Args:
        mask_len: number of positions to mask; any tensor supporting ``.long()``
            that converts to a single index (e.g. ``torch.tensor([k])``).
        order: 2-D integer index tensor, indexed as ``order[:, :k]``; presumably
            each row is a permutation of ``range(seq_len)`` — confirm at call site.
        bsz: number of rows in the mask.
        seq_len: number of columns in the mask.

    Returns:
        Bool tensor of shape ``(bsz, seq_len)`` on ``order``'s device, True at
        ``order[b, :mask_len]`` for every row ``b``.
    """
    # Allocate on the index tensor's device instead of hard-coding CUDA:
    # torch.scatter needs index and destination on the same device, and this
    # keeps the helper usable on CPU as well as on any GPU.
    target_device = order.device
    masking = torch.zeros(bsz, seq_len, device=target_device)
    masking = torch.scatter(
        masking,
        dim=-1,
        index=order[:, :mask_len.long()],
        src=torch.ones(bsz, seq_len, device=target_device),
    ).bool()
    return masking

def fast_hist(pred, label, n):
    """Accumulate an ``n x n`` confusion matrix from flat prediction/label arrays.

    Entry ``[i, j]`` counts elements whose label is ``i`` and whose prediction
    is ``j`` (rows = ground truth, columns = prediction). Elements whose label
    lies outside ``[0, n)`` are ignored; predictions are not range-checked.
    """
    valid = np.logical_and(label >= 0, label < n)
    # Encode each (label, pred) pair as one index into a flattened n*n grid.
    flat_idx = label[valid].astype(int) * n + pred[valid]
    counts = np.bincount(flat_idx, minlength=n * n)
    return counts[: n * n].reshape(n, n)

# RGB palette used to turn colour-coded label maps back into class ids.
# Row i is the colour of class id i, scaled to [0, 1] and rounded to 4 decimals.
# Row 0 (black) is the void colour excluded by the metric computation; rows
# 1..19 line up with `class_names` in the evaluation cell.
color_pallete = np.round(
    np.array(
        [
            (0, 0, 0),        # 0  void / unlabelled
            (128, 64, 128),   # 1  road
            (244, 35, 232),   # 2  sidewalk
            (70, 70, 70),     # 3  building
            (102, 102, 156),  # 4  wall
            (190, 153, 153),  # 5  fence
            (153, 153, 153),  # 6  pole
            (250, 170, 30),   # 7  traffic light
            (220, 220, 0),    # 8  traffic sign
            (107, 142, 35),   # 9  vegetation
            (152, 251, 152),  # 10 terrain
            (0, 130, 180),    # 11 sky -- NOTE(review): official Cityscapes sky is (70, 130, 180); presumably this matches the dataset's label encoding, confirm
            (220, 20, 60),    # 12 person
            (255, 0, 0),      # 13 rider
            (0, 0, 142),      # 14 car
            (0, 0, 70),       # 15 truck
            (0, 60, 100),     # 16 bus
            (0, 80, 100),     # 17 train
            (0, 0, 230),      # 18 motorcycle
            (119, 11, 32),    # 19 bicycle
        ]
    )
    / 255.0,
    4,
)

In [5]:
# Evaluation always runs on the first CUDA GPU; args.device is intentionally
# not consulted (reading it and then overwriting it had no effect).
device = torch.device('cuda:0')
# The evaluation cell fills only sample 0 of each batch, so force batch size 1.
args.batch_size = 1

# Seed the torch and numpy RNGs, offset by this process's rank.
# NOTE: cudnn.benchmark=True lets cuDNN choose kernels by timing them, so
# results may still differ slightly across runs/GPUs despite the fixed seed.
seed = args.seed + misc.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)

cudnn.benchmark = True

num_tasks = misc.get_world_size()
global_rank = misc.get_rank()

In [9]:
# Convert images to tensors, then normalise each channel with mean 0.5 and
# std 0.5 (for uint8/PIL input this maps [0, 1] to [-1, 1]).
transform_train = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)

# Evaluation data. Despite the `_train` suffix (kept because later cells use
# these names) this is the CityScapes validation list.
# Alternative datasets (commented out):
# dataset_train = acdc.ACDC('dataset/ACDC/vallist_fog.txt', data_set= 'val', transform=transform_train, seed=args.seed, img_size=args.img_size)
# dataset_train = semantickitti.SemanticKITTI('dataset/SemanticKitti/vallist.txt', data_set= 'val', transform=transform_train, seed=args.seed, img_size=args.img_size)
# dataset_train = cadedgetune.CADEdgeTune('dataset/CADEdgeTune/all.txt', data_set= 'val', transform=transform_train, seed=args.seed, img_size=args.img_size)
dataset_train = cityscapes.CityScapes(
    'dataset/CityScapes/vallist.txt',
    data_set='val',
    transform=transform_train,
    seed=args.seed,
    img_size=args.img_size,
)

# One replica (rank 0) without shuffling: every index is visited exactly
# once, in dataset order.
sampler_train = torch.utils.data.DistributedSampler(
    dataset_train,
    num_replicas=1,
    rank=0,
    shuffle=False,
)

# drop_last=True discards a trailing partial batch; with batch_size 1 nothing
# is dropped.
data_loader_train = torch.utils.data.DataLoader(
    dataset_train,
    sampler=sampler_train,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    pin_memory=args.pin_mem,
    drop_last=True,
)

In [None]:
# VAE that encodes images / colour-coded label maps into latents and decodes
# latents back; in eval mode and fully frozen because it is only used for inference.
vae = AutoencoderKL(
    ddconfig=args.ddconfig,
    embed_dim=args.vae_embed_dim,
    ckpt_path=args.vae_path,
).to(device).eval()
vae.requires_grad_(False)  # freeze every VAE parameter in one call

# MAR (base size) model, configured entirely from the parsed arguments.
model = mar.mar_base(
    img_size=args.img_size,
    vae_stride=args.vae_stride,
    patch_size=args.patch_size,
    vae_embed_dim=args.vae_embed_dim,
    mask_ratio_min=args.mask_ratio_min,
    label_drop_prob=args.label_drop_prob,
    attn_dropout=args.attn_dropout,
    proj_dropout=args.proj_dropout,
    buffer_size=args.buffer_size,
    diffloss_d=args.diffloss_d,
    diffloss_w=args.diffloss_w,
    num_sampling_steps=args.num_sampling_steps,
    diffusion_batch_mul=args.diffusion_batch_mul,
    grad_checkpointing=args.grad_checkpointing,
)

n_params = sum(
    param.numel()
    for param in filter(lambda param: param.requires_grad, model.parameters())
)
print("Number of trainable parameters: {}M".format(n_params / 1e6))

# Load the weights on CPU first, then move the model to the target device.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
checkpoint = torch.load(args.ckpt_path, map_location='cpu')
model.load_state_dict(checkpoint['model'])
model.to(device)

eff_batch_size = args.batch_size * misc.get_world_size()
print("effective batch size: %d" % eff_batch_size)

In [8]:
def decode_color_map(image_tensor):
    """Map a colour-coded image tensor to per-pixel class ids.

    ``image_tensor`` is channel-first with 3 channels (it is permuted to
    (H, W, 3)); its values are presumably in [-1, 1], the inverse of the
    loader's ``Normalize(0.5, 0.5)`` -- confirm for VAE outputs. Each pixel is
    assigned the index of the nearest ``color_pallete`` row (Euclidean RGB).

    Returns:
        ``np.ndarray`` of shape (H, W) with values in ``[0, len(color_pallete))``.
    """
    rgb = (image_tensor * 0.5 + 0.5).permute(1, 2, 0).cpu().numpy()
    height, width, _ = rgb.shape
    pixels = rgb.reshape(-1, 3)
    # (H*W, 1, 3) - (1, K, 3) broadcasts to (H*W, K, 3); the norm over the last
    # axis is each pixel's distance to every palette colour.
    distances = np.linalg.norm(pixels[:, None, :] - color_pallete[None, :, :], axis=2)
    return np.argmin(distances, axis=1).reshape(height, width)


# Scale applied to VAE latents before the MAR model and undone before decoding.
LATENT_SCALE = 0.2325
# Number of palette colours == confusion-matrix size (index 0 is void).
N_CLASSES = len(color_pallete)
# Sampling settings passed to model.diffloss.sample; constant for every batch.
cfg_iter = 1.0
temperature = 1.0

hist = []
model.eval()
for samples, labels, path in tqdm(data_loader_train, desc="Training Progress"):
    samples = samples.to(device, non_blocking=True)
    labels = labels.to(device, non_blocking=True)

    # Inference only: one no_grad scope covers encode -> predict -> decode.
    with torch.no_grad():
        posterior_x = vae.encode(samples)
        posterior_y = vae.encode(labels)
        x = posterior_x.sample().mul_(LATENT_SCALE)
        # NOTE(review): the label latent is never used for prediction. The draw
        # is kept because sample() may consume the global RNG, and removing it
        # could change the sampled predictions (and therefore the logged scores).
        posterior_y.sample()
        x = model.patchify(x)

        # Presumably the sequence is [image tokens | label tokens], with 0
        # marking visible tokens and 1 tokens to predict -- confirm against mar.
        mask_actual = torch.cat(
            [
                torch.zeros(args.batch_size, model.seq_len),
                torch.ones(args.batch_size, model.seq_len),
            ],
            dim=1,
        ).to(device)
        tokens = torch.zeros(args.batch_size, model.seq_len, model.token_embed_dim, device=device)

        x1 = model.forward_mae_encoder(x, mask_actual, tokens)
        z = model.forward_mae_decoder(x1, mask_actual)
        z = z[0]
        sampled_token_latent = model.diffloss.sample(z, temperature, cfg_iter)

        # Keep the second half of the sampled sequence (the label part). Only
        # sample 0 is written, so this loop assumes args.batch_size == 1.
        tokens[0] = sampled_token_latent[model.seq_len:]
        tokens = model.unpatchify(tokens)
        sampled_images = vae.decode(tokens / LATENT_SCALE)

    gt = decode_color_map(labels[0])
    pred = decode_color_map(sampled_images[0])
    hist.append(fast_hist(pred.reshape(-1), gt.reshape(-1), N_CLASSES))

# Sum the per-image histograms; an empty loader yields an all-zero matrix
# instead of crashing on np.sum([]).
if hist:
    cm = np.sum(hist, axis=0)
else:
    cm = np.zeros((N_CLASSES, N_CLASSES), dtype=np.int64)

epsilon = 1e-10
# Drop index 0 (void) on both axes. Rows are ground truth and columns are
# predictions (fast_hist indexes n * label + pred), so diag / column-sum is
# per-class precision; epsilon turns never-predicted classes into 0, not NaN.
class_precision = np.diag(cm[1:, 1:]) / (np.sum(cm[1:, 1:], axis=0) + epsilon)
class_names = ['road', 'sidewalk', 'building', 'wall', 'fence', 'pole', 'tlight', 'tsign',
               'vtation', 'terrain', 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train',
               'motorcycle', 'bicycle']
assert len(class_names) == len(class_precision), "class_names must match the non-void palette entries"

for name, precision in zip(class_names, class_precision):
    print(f"{name:<12}: {precision*100:6.2f}")
average_precision = np.mean(class_precision)
print(f"{'Avg Pre':<12}: {average_precision*100:6.2f}")

Training Progress: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [13:11<00:00, 1.58s/it]

road : 98.06
sidewalk : 86.32
building : 89.23
wall : 47.44
fence : 43.78
pole : 60.14
tlight : 63.16
tsign : 82.48
vtation : 92.72
terrain : 80.45
sky : 95.99
person : 70.83
rider : 64.25
car : 94.06
truck : 44.90
bus : 66.81
train : 44.04
motorcycle : 47.34
bicycle : 62.50
Avg Pre : 70.24



