"""
PaveCLIP: Complete CLIP Training Framework for Pavement Data
Supports ViT/ResNet encoders, BERT/custom text encoders, SigLIP, Multi-GPU training
"""
import os
import json
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
import torchvision.transforms as transforms
from torchvision.models import resnet50, resnet101
import timm
from transformers import AutoTokenizer, AutoModel, BertModel
from PIL import Image
from pathlib import Path
import logging
from typing import Dict, List, Tuple
import argparse
import wandb
from tqdm import tqdm
import warnings

warnings.filterwarnings("ignore")

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class PavementDataset(Dataset):
    """
    Dataset loader for pavement pretraining data with a complex folder structure
    """
    def __init__(self, data_dir: str, transform=None, tokenizer=None, max_length=77):
        self.data_dir = Path(data_dir)
        self.transform = transform
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.samples = []
        logger.info(f"Loading dataset from {data_dir}")
        self._load_dataset()
        logger.info(f"Loaded {len(self.samples)} samples from {self._get_unique_images()} unique images")

    def _load_dataset(self):
        """Load all JSON files and collect image-text pairs"""
        json_files = list(self.data_dir.rglob("*.json"))
        for json_file in json_files:
            try:
                with open(json_file, 'r') as f:
                    data = json.load(f)
                # Handle different JSON structures
                if isinstance(data, list):
                    # List of samples
                    for item in data:
                        self._process_sample(item, json_file.parent)
                elif isinstance(data, dict):
                    # Single sample or nested structure
                    if "conversations" in data:
                        self._process_sample(data, json_file.parent)
                    else:
                        # Check whether it is a collection of samples
                        for key, value in data.items():
                            if isinstance(value, dict) and "conversations" in value:
                                self._process_sample(value, json_file.parent)
                            elif isinstance(value, list):
                                for item in value:
                                    if isinstance(item, dict) and "conversations" in item:
                                        self._process_sample(item, json_file.parent)
            except Exception as e:
                logger.warning(f"Error loading {json_file}: {e}")

    def _process_sample(self, sample: dict, base_path: Path):
        """Process an individual sample and extract its image-text pair"""
        try:
            image_path = sample.get("image", "")
            conversations = sample.get("conversations", [])
            if not image_path or not conversations:
                return
            # Find the text response from GPT
            text = ""
            for conv in conversations:
                if conv.get("from") == "gpt":
                    text = conv.get("value", "")
                    break
            if not text:
                return
            # Resolve the image path (relative to base_path)
            full_image_path = base_path / image_path
            if not full_image_path.exists():
                # Try different relative bases
                for possible_base in [base_path, base_path.parent, base_path.parent.parent]:
                    test_path = possible_base / image_path
                    if test_path.exists():
                        full_image_path = test_path
                        break
            if full_image_path.exists():
                self.samples.append({
                    "image_path": str(full_image_path),
                    "text": text.strip(),
                    "id": sample.get("id", f"sample_{len(self.samples)}")
                })
        except Exception as e:
            logger.warning(f"Error processing sample: {e}")

    def _get_unique_images(self):
        """Count unique images across all samples"""
        return len(set(sample["image_path"] for sample in self.samples))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        # Load and transform the image
        try:
            image = Image.open(sample["image_path"]).convert("RGB")
            if self.transform:
                image = self.transform(image)
        except Exception as e:
            logger.warning(f"Error loading image {sample['image_path']}: {e}")
            # Return a black image as fallback
            image = torch.zeros(3, 224, 224)
        # Tokenize text
        text = sample["text"]
        if self.tokenizer:
            tokens = self.tokenizer(
                text,
                max_length=self.max_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )
            return {
                "image": image,
                "input_ids": tokens["input_ids"].squeeze(0),
                "attention_mask": tokens["attention_mask"].squeeze(0),
                "text": text
            }
        else:
            return {
                "image": image,
                "text": text
            }

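# Illustrative record format accepted by PavementDataset (a sketch inferred from
# _process_sample; the field values below are hypothetical, not taken from the real data):
#
#   {
#     "id": "crack_0001",                  # optional; auto-generated if missing
#     "image": "images/crack_0001.jpg",    # resolved relative to the JSON file's folder
#     "conversations": [
#       {"from": "human", "value": "Describe the pavement condition."},
#       {"from": "gpt", "value": "Longitudinal crack with moderate severity..."}
#     ]
#   }
#
# Only the first "gpt" turn is used as the caption; records without an image path or a
# "gpt" response are skipped.
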
class VisionEncoder(nn.Module):
    """Flexible vision encoder supporting ViT and ResNet architectures"""
    def __init__(self, model_name: str, embed_dim: int = 512, pretrained: bool = True):
        super().__init__()
        self.model_name = model_name
        self.embed_dim = embed_dim
        self.expected_image_size = 224  # Default

        # Try to determine the architecture type
        if any(arch in model_name.lower() for arch in ["vit", "deit", "swin", "beit", "cait"]):
            self._setup_vit(model_name, pretrained)
        elif "resnet" in model_name.lower():
            self._setup_resnet(model_name, pretrained)
        else:
            # Generic TIMM model loading
            self._setup_generic_timm(model_name, pretrained)

        # Projection head
        self.projection = nn.Linear(self.feature_dim, embed_dim)

    def _setup_generic_timm(self, model_name: str, pretrained: bool):
        """Setup any TIMM model generically"""
        try:
            self.backbone = timm.create_model(
                model_name,
                pretrained=pretrained,
                num_classes=0  # Remove classification head
            )
            # Auto-detect input size and feature dimension
            self.feature_dim = None
            test_sizes = [224, 288, 336, 384, 448, 512]
            for test_size in test_sizes:
                try:
                    with torch.no_grad():
                        dummy_input = torch.randn(1, 3, test_size, test_size)
                        features = self.backbone(dummy_input)
                    # Handle different output formats
                    if len(features.shape) > 2:
                        features = features.view(features.size(0), -1)
                    self.feature_dim = features.shape[1]
                    self.expected_image_size = test_size
                    logger.info(f"Generic model {model_name} expects {test_size}x{test_size} → {self.feature_dim}D")
                    break
                except Exception:
                    continue
            if self.feature_dim is None:
                raise RuntimeError("Could not determine model specifications")
        except Exception as e:
            logger.error(f"Failed to load {model_name}: {e}")
            raise

    def _setup_vit(self, model_name: str, pretrained: bool):
        """Setup a Vision Transformer -- works with any TIMM ViT model"""
        # Known mappings for convenience
        vit_mapping = {
            "vit-b/16": "vit_base_patch16_224",
            "vit-b/32": "vit_base_patch32_224",
            "vit-l/14": "vit_large_patch14_224",
            "vit-l/14@336": "vit_large_patch14_clip_336",
            "vit-h/14": "vit_huge_patch14_clip_378"
        }
        # Use the mapping if available, otherwise use the model name directly
        timm_name = vit_mapping.get(model_name.lower(), model_name)
        try:
            self.backbone = timm.create_model(
                timm_name,
                pretrained=pretrained,
                num_classes=0
            )
            # Auto-detect input size by trying common resolutions
            self.feature_dim = None
            test_sizes = [224, 336, 378, 384, 512]  # Common ViT sizes
            for test_size in test_sizes:
                try:
                    with torch.no_grad():
                        dummy_input = torch.randn(1, 3, test_size, test_size)
                        features = self.backbone(dummy_input)
                    self.feature_dim = features.shape[1]
                    self.expected_image_size = test_size
                    logger.info(f"Model {timm_name} expects {test_size}x{test_size} input")
                    break
                except Exception:
                    continue
            if self.feature_dim is None:
                raise RuntimeError("Could not determine input size for model")
        except Exception as e:
            logger.warning(f"Failed to load {timm_name}: {e}")
            logger.warning("Falling back to basic ViT")
            self.backbone = timm.create_model("vit_base_patch16_224", pretrained=pretrained, num_classes=0)
            self.feature_dim = 768
            self.expected_image_size = 224

    def _setup_resnet(self, model_name: str, pretrained: bool):
        """Setup a ResNet backbone"""
        if "resnet101" in model_name.lower():
            self.backbone = resnet101(pretrained=pretrained)
        else:
            # Default to ResNet-50 for "resnet50" and any other ResNet variant
            self.backbone = resnet50(pretrained=pretrained)
        # Remove the classification head
        self.backbone = nn.Sequential(*list(self.backbone.children())[:-1])
        self.feature_dim = 2048  # ResNet-50/101 feature dimension

    def forward(self, x):
        features = self.backbone(x)
        if len(features.shape) > 2:
            features = features.view(features.size(0), -1)
        return self.projection(features)

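# Minimal usage sketch for VisionEncoder (kept as a comment so the module has no
# import-time side effects; "vit-b/16" is one of the aliases mapped above):
#
#   enc = VisionEncoder("vit-b/16", embed_dim=512, pretrained=False)
#   x = torch.randn(2, 3, enc.expected_image_size, enc.expected_image_size)
#   emb = enc(x)    # -> shape (2, 512); unnormalized projection output
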
class TextEncoder(nn.Module):
    """Flexible text encoder supporting various transformer models"""
    def __init__(self, model_name: str = "bert-base-uncased", embed_dim: int = 512,
                 max_length: int = 77, pretrained: bool = True):
        super().__init__()
        self.model_name = model_name
        self.embed_dim = embed_dim
        self.max_length = max_length

        if not pretrained:
            # Initialize from scratch (architecture only, no pretrained weights)
            if "bert" in model_name:
                from transformers import BertConfig
                config = BertConfig(vocab_size=30522, max_position_embeddings=max_length)
                self.transformer = BertModel(config)
            else:
                from transformers import AutoConfig
                config = AutoConfig.from_pretrained(model_name)
                self.transformer = AutoModel.from_config(config)
        else:
            self.transformer = AutoModel.from_pretrained(model_name)

        # Hidden dimension of the backbone
        self.hidden_dim = self.transformer.config.hidden_size
        # Projection head
        self.projection = nn.Linear(self.hidden_dim, embed_dim)

    def forward(self, input_ids, attention_mask=None):
        outputs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)
        # Use the pooled [CLS] representation if available, otherwise mean pooling
        if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
            features = outputs.pooler_output
        else:
            # Mean pooling over the sequence length
            features = outputs.last_hidden_state.mean(dim=1)
        return self.projection(features)

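# Minimal usage sketch for TextEncoder (commented; assumes the Hugging Face checkpoint
# named below is available locally or downloadable):
#
#   tok = AutoTokenizer.from_pretrained("bert-base-uncased")
#   enc = TextEncoder("bert-base-uncased", embed_dim=512)
#   batch = tok(["transverse crack", "pothole"], padding='max_length',
#               max_length=77, truncation=True, return_tensors='pt')
#   emb = enc(batch["input_ids"], batch["attention_mask"])    # -> shape (2, 512)
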
class CLIPModel(nn.Module):
    """CLIP model with contrastive learning"""
    def __init__(self, vision_model: str, text_model: str, embed_dim: int = 512,
                 temperature: float = 0.07, vision_pretrained: bool = True,
                 text_pretrained: bool = True):
        super().__init__()
        self.vision_encoder = VisionEncoder(vision_model, embed_dim, vision_pretrained)
        self.text_encoder = TextEncoder(text_model, embed_dim, pretrained=text_pretrained)
        # Learnable temperature for the contrastive loss
        self.temperature = nn.Parameter(torch.tensor(temperature))

    def forward(self, images, input_ids, attention_mask=None):
        # Encode images and text
        image_features = self.vision_encoder(images)
        text_features = self.text_encoder(input_ids, attention_mask)
        # L2-normalize features
        image_features = F.normalize(image_features, p=2, dim=1)
        text_features = F.normalize(text_features, p=2, dim=1)
        return image_features, text_features

    def compute_loss(self, image_features, text_features):
        """Compute the symmetric contrastive loss"""
        batch_size = image_features.shape[0]
        # Similarity matrix scaled by temperature
        logits = torch.matmul(image_features, text_features.T) / self.temperature
        # Matching pairs lie on the diagonal (each image matches its corresponding text)
        labels = torch.arange(batch_size, device=logits.device)
        # Cross-entropy in both directions (image-to-text and text-to-image)
        loss_img = F.cross_entropy(logits, labels)
        loss_txt = F.cross_entropy(logits.T, labels)
        return (loss_img + loss_txt) / 2

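# Sketch of what CLIPModel.compute_loss expects (commented illustration, not executed):
# both inputs are L2-normalized (batch, embed_dim) tensors produced by forward(), e.g.
#
#   model = CLIPModel("vit-b/16", "bert-base-uncased", embed_dim=512)
#   img_f, txt_f = model(images, input_ids, attention_mask)   # each (B, 512), unit norm
#   loss = model.compute_loss(img_f, txt_f)                   # scalar
#
# Because the labels are simply arange(B), the batch must contain matched image/text
# pairs in the same order, which is what PavementDataset + DataLoader provide.
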
class SigLIPModel(nn.Module):
    """SigLIP model with a sigmoid loss instead of the contrastive loss"""
    def __init__(self, vision_model: str, text_model: str, embed_dim: int = 512,
                 temperature: float = 0.07, vision_pretrained: bool = True,
                 text_pretrained: bool = True):
        super().__init__()
        self.vision_encoder = VisionEncoder(vision_model, embed_dim, vision_pretrained)
        self.text_encoder = TextEncoder(text_model, embed_dim, pretrained=text_pretrained)
        # Learnable temperature parameter
        self.temperature = nn.Parameter(torch.tensor(temperature))

    def forward(self, images, input_ids, attention_mask=None):
        # Encode images and text
        image_features = self.vision_encoder(images)
        text_features = self.text_encoder(input_ids, attention_mask)
        # L2-normalize features
        image_features = F.normalize(image_features, p=2, dim=1)
        text_features = F.normalize(text_features, p=2, dim=1)
        return image_features, text_features

    def compute_loss(self, image_features, text_features):
        """Compute the SigLIP-style sigmoid loss"""
        batch_size = image_features.shape[0]
        # Similarity matrix scaled by temperature
        logits = torch.matmul(image_features, text_features.T) / self.temperature
        # +1 on the diagonal (matched pairs), -1 everywhere else
        labels = torch.eye(batch_size, device=logits.device)
        labels = labels * 2 - 1
        # SigLIP loss: -log(sigmoid(y_ij * z_ij)) averaged over all pairs
        loss = -F.logsigmoid(logits * labels).mean()
        return loss

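# Note on the loss above: unlike the contrastive loss in CLIPModel, every (image, text)
# pair in the batch contributes an independent binary term, so no softmax over the batch
# is involved. A tiny commented sanity check of the label construction:
#
#   labels = torch.eye(3) * 2 - 1
#   # tensor([[ 1., -1., -1.],
#   #         [-1.,  1., -1.],
#   #         [-1., -1.,  1.]])
#
# (The original SigLIP formulation also adds a learnable bias to the logits; this
# implementation keeps only the temperature term.)
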
class PaveCLIPTrainer:
    """Complete training framework for PaveCLIP"""
    def __init__(self, config: Dict):
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.distributed = False
        self.rank = 0
        # Setup distributed training if specified
        if config.get("distributed", False):
            self._setup_distributed()
        # Initialize model
        self._setup_model()
        # Setup data
        self._setup_data()
        # Setup optimization
        self._setup_optimization()
        # Setup logging
        if config.get("wandb", False) and (not self.distributed or self.rank == 0):
            wandb.init(project="paveclip", config=config)

    def _setup_distributed(self):
        """Setup distributed training"""
        self.distributed = True
        self.rank = int(os.environ.get("LOCAL_RANK", 0))
        self.world_size = int(os.environ.get("WORLD_SIZE", 1))
        dist.init_process_group(backend="nccl")
        torch.cuda.set_device(self.rank)
        self.device = torch.device(f"cuda:{self.rank}")
        logger.info(f"Initialized distributed training: rank {self.rank}/{self.world_size}")

    def _setup_model(self):
        """Initialize the model"""
        model_type = self.config.get("model_type", "clip").lower()
        if model_type == "clip":
            self.model = CLIPModel(
                vision_model=self.config["vision_model"],
                text_model=self.config["text_model"],
                embed_dim=self.config.get("embed_dim", 512),
                temperature=self.config.get("temperature", 0.07),
                vision_pretrained=self.config.get("vision_pretrained", True),
                text_pretrained=self.config.get("text_pretrained", True)
            )
        elif model_type == "siglip":
            self.model = SigLIPModel(
                vision_model=self.config["vision_model"],
                text_model=self.config["text_model"],
                embed_dim=self.config.get("embed_dim", 512),
                temperature=self.config.get("temperature", 0.07),
                vision_pretrained=self.config.get("vision_pretrained", True),
                text_pretrained=self.config.get("text_pretrained", True)
            )
        else:
            raise ValueError(f"Unsupported model type: {model_type}")
        self.model = self.model.to(self.device)
        # Wrap with DDP for distributed training
        if self.distributed:
            self.model = DDP(self.model, device_ids=[self.rank])

    def _setup_data(self):
        """Setup data loaders"""
        # Match the input resolution detected by the vision encoder
        model_ref = self.model.module if hasattr(self.model, 'module') else self.model
        image_size = getattr(model_ref.vision_encoder, "expected_image_size", 224)

        # Pavement-specific augmentations for robustness
        train_transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(degrees=15),  # Roads can be at angles
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1, hue=0.05),
            transforms.RandomGrayscale(p=0.1),  # Some pavement images are grayscale
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        val_transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        # Tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(self.config["text_model"])
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # Dataset
        train_dataset = PavementDataset(
            self.config["data_dir"],
            transform=train_transform,
            tokenizer=self.tokenizer,
            max_length=self.config.get("max_length", 77)
        )

        # Split off a validation set if specified
        if self.config.get("val_split", 0.1) > 0:
            val_size = int(len(train_dataset) * self.config["val_split"])
            train_size = len(train_dataset) - val_size
            # random_split subsets share the underlying dataset, so overriding its
            # .transform would also disable the training augmentations. A shallow
            # copy shares the loaded samples but carries its own transform.
            val_base = copy.copy(train_dataset)
            val_base.transform = val_transform
            # Fixed seed keeps the split identical across runs and ranks
            generator = torch.Generator().manual_seed(42)
            train_subset, val_subset = torch.utils.data.random_split(
                train_dataset, [train_size, val_size], generator=generator
            )
            train_dataset = train_subset
            val_dataset = torch.utils.data.Subset(val_base, val_subset.indices)
        else:
            val_dataset = None

        # Data loaders
        train_sampler = DistributedSampler(train_dataset) if self.distributed else None
        self.train_loader = DataLoader(
            train_dataset,
            batch_size=self.config["batch_size"],
            shuffle=(train_sampler is None),
            sampler=train_sampler,
            num_workers=self.config.get("num_workers", 4),
            pin_memory=True,
            drop_last=True
        )
        if val_dataset:
            val_sampler = DistributedSampler(val_dataset) if self.distributed else None
            self.val_loader = DataLoader(
                val_dataset,
                batch_size=self.config["batch_size"],
                shuffle=False,
                sampler=val_sampler,
                num_workers=self.config.get("num_workers", 4),
                pin_memory=True
            )
        else:
            self.val_loader = None
        logger.info(f"Training samples: {len(train_dataset)}")
        if val_dataset:
            logger.info(f"Validation samples: {len(val_dataset)}")

    def _setup_optimization(self):
        """Setup optimizer and scheduler"""
        # Pavement-specific optimization strategy:
        # use different learning rates for the vision and text encoders
        vision_params = []
        text_params = []
        other_params = []
        model = self.model.module if hasattr(self.model, 'module') else self.model
        for name, param in model.named_parameters():
            if 'vision_encoder' in name:
                vision_params.append(param)
            elif 'text_encoder' in name:
                text_params.append(param)
            else:
                other_params.append(param)
        # Different learning rates for different components
        param_groups = [
            {'params': vision_params, 'lr': self.config["learning_rate"] * 0.1},  # Lower LR for vision
            {'params': text_params, 'lr': self.config["learning_rate"]},          # Standard LR for text
            {'params': other_params, 'lr': self.config["learning_rate"]}          # Standard LR for the rest
        ]
        self.optimizer = torch.optim.AdamW(
            param_groups,
            weight_decay=self.config.get("weight_decay", 0.01)
        )
        # Learning rate scheduler (stepped once per batch in train_epoch)
        total_steps = len(self.train_loader) * self.config["epochs"]
        warmup_steps = int(total_steps * self.config.get("warmup_ratio", 0.1))
        self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
            self.optimizer,
            max_lr=[group['lr'] for group in param_groups],
            total_steps=total_steps,
            pct_start=warmup_steps / total_steps,
            anneal_strategy='cos'
        )

    def train_epoch(self, epoch: int):
        """Train for one epoch"""
        self.model.train()
        if self.distributed:
            self.train_loader.sampler.set_epoch(epoch)
        total_loss = 0
        num_batches = len(self.train_loader)
        is_main = not self.distributed or self.rank == 0
        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}") if is_main else self.train_loader
        for batch_idx, batch in enumerate(pbar):
            images = batch["image"].to(self.device, non_blocking=True)
            input_ids = batch["input_ids"].to(self.device, non_blocking=True)
            attention_mask = batch["attention_mask"].to(self.device, non_blocking=True)
            # Forward pass
            image_features, text_features = self.model(images, input_ids, attention_mask)
            # Compute loss (unwrap DDP if necessary)
            model_ref = self.model.module if hasattr(self.model, 'module') else self.model
            loss = model_ref.compute_loss(image_features, text_features)
            # Backward pass
            self.optimizer.zero_grad()
            loss.backward()
            # Gradient clipping for stability
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()
            self.scheduler.step()
            total_loss += loss.item()
            # Update the progress bar
            if hasattr(pbar, 'set_postfix'):
                pbar.set_postfix({
                    'loss': f'{loss.item():.4f}',
                    'avg_loss': f'{total_loss/(batch_idx+1):.4f}',
                    'lr': f'{self.scheduler.get_last_lr()[0]:.2e}'
                })
            # Log to wandb
            if self.config.get("wandb", False) and is_main:
                wandb.log({
                    "train_loss": loss.item(),
                    "learning_rate": self.scheduler.get_last_lr()[0],
                    "epoch": epoch,
                    "step": epoch * num_batches + batch_idx
                })
        return total_loss / num_batches

    def validate(self, epoch: int):
        """Validate the model"""
        if self.val_loader is None:
            return None
        self.model.eval()
        total_loss = 0
        with torch.no_grad():
            for batch in self.val_loader:
                images = batch["image"].to(self.device, non_blocking=True)
                input_ids = batch["input_ids"].to(self.device, non_blocking=True)
                attention_mask = batch["attention_mask"].to(self.device, non_blocking=True)
                # Forward pass
                image_features, text_features = self.model(images, input_ids, attention_mask)
                # Compute loss (unwrap DDP if necessary)
                model_ref = self.model.module if hasattr(self.model, 'module') else self.model
                loss = model_ref.compute_loss(image_features, text_features)
                total_loss += loss.item()
        avg_loss = total_loss / len(self.val_loader)
        if self.config.get("wandb", False) and (not self.distributed or self.rank == 0):
            wandb.log({
                "val_loss": avg_loss,
                "epoch": epoch
            })
        return avg_loss

    def train(self):
        """Main training loop"""
        logger.info("Starting training...")
        best_val_loss = float('inf')
        is_main = not self.distributed or self.rank == 0
        for epoch in range(self.config["epochs"]):
            # Train
            train_loss = self.train_epoch(epoch)
            # Validate
            val_loss = self.validate(epoch)
            # Log epoch results
            if is_main:
                logger.info(f"Epoch {epoch+1}/{self.config['epochs']}")
                logger.info(f"Train Loss: {train_loss:.4f}")
                if val_loss is not None:
                    logger.info(f"Val Loss: {val_loss:.4f}")
            # Save the best checkpoint
            if is_main and val_loss is not None and val_loss < best_val_loss:
                best_val_loss = val_loss
                self.save_checkpoint(epoch, is_best=True)
            # Regular checkpoint
            if is_main and (epoch + 1) % self.config.get("save_every", 10) == 0:
                self.save_checkpoint(epoch, is_best=False)

    def save_checkpoint(self, epoch: int, is_best: bool = False):
        """Save a model checkpoint"""
        model_state = self.model.module.state_dict() if hasattr(self.model, 'module') else self.model.state_dict()
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model_state,
            'optimizer_state_dict': self.optimizer.state_dict(),
            'config': self.config
        }
        filename = "paveclip_best.pt" if is_best else f"paveclip_epoch_{epoch+1}.pt"
        save_path = Path(self.config["output_dir"]) / filename
        save_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(checkpoint, save_path)
        logger.info(f"Saved checkpoint: {save_path}")

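# Minimal config sketch accepted by PaveCLIPTrainer (commented; values are illustrative
# and the data_dir/output_dir paths below are hypothetical):
#
#   config = {
#       "model_type": "clip",               # or "siglip"
#       "vision_model": "vit-b/16",
#       "text_model": "bert-base-uncased",
#       "embed_dim": 512,
#       "data_dir": "./Pavement_Pretraining_Data",
#       "output_dir": "./checkpoints",
#       "batch_size": 64, "epochs": 50, "learning_rate": 1e-4,
#       "val_split": 0.1, "max_length": 77,
#       "wandb": False, "distributed": False,
#   }
#   trainer = PaveCLIPTrainer(config)
#   trainer.train()
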
class PaveCLIPEvaluator:
    """Evaluation utilities for PaveCLIP"""
    def __init__(self, model_path: str, config: Dict):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.config = config
        # Load checkpoint
        checkpoint = torch.load(model_path, map_location=self.device)
        model_config = checkpoint['config']
        # Initialize model
        if model_config.get("model_type", "clip").lower() == "clip":
            self.model = CLIPModel(
                vision_model=model_config["vision_model"],
                text_model=model_config["text_model"],
                embed_dim=model_config.get("embed_dim", 512)
            )
        else:
            self.model = SigLIPModel(
                vision_model=model_config["vision_model"],
                text_model=model_config["text_model"],
                embed_dim=model_config.get("embed_dim", 512)
            )
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model = self.model.to(self.device)
        self.model.eval()
        # Setup tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_config["text_model"])
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        # Image transforms, matched to the encoder's detected input size
        expected = getattr(self.model.vision_encoder, "expected_image_size", 224)
        self.transform = transforms.Compose([
            transforms.Resize((expected, expected)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        self.image_size = expected  # keep for later use

    def encode_images(self, image_paths: List[str]) -> torch.Tensor:
        """Encode a list of images into normalized embeddings"""
        features = []
        with torch.no_grad():
            for img_path in image_paths:
                image = Image.open(img_path).convert("RGB")
                image = self.transform(image).unsqueeze(0).to(self.device)
                # Call the vision tower directly so no dummy text batch is needed
                img_features = self.model.vision_encoder(image)
                img_features = F.normalize(img_features, p=2, dim=1)
                features.append(img_features.cpu())
        return torch.cat(features, dim=0)

    def encode_texts(self, texts: List[str]) -> torch.Tensor:
        """Encode a list of texts into normalized embeddings"""
        tokens = self.tokenizer(
            texts,
            max_length=77,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        with torch.no_grad():
            tokens = {k: v.to(self.device) for k, v in tokens.items()}
            # Call the text tower directly so no dummy image batch is needed
            text_features = self.model.text_encoder(tokens["input_ids"], tokens["attention_mask"])
            text_features = F.normalize(text_features, p=2, dim=1)
        return text_features.cpu()

    def zero_shot_classification(self, image_paths: List[str], class_texts: List[str]) -> Dict:
        """Perform zero-shot classification"""
        logger.info("Performing zero-shot classification...")
        # Encode images and texts
        image_features = self.encode_images(image_paths)
        text_features = self.encode_texts(class_texts)
        # Compute similarities
        similarities = torch.matmul(image_features, text_features.T)
        predictions = similarities.argmax(dim=1)
        # Accuracy can be computed externally if ground truth is available
        results = {
            "predictions": predictions.tolist(),
            "similarities": similarities.tolist(),
            "class_texts": class_texts
        }
        return results

    def image_retrieval(self, query_text: str, image_paths: List[str], top_k: int = 5) -> List[Tuple[str, float]]:
        """Retrieve the top-k images for a text query"""
        logger.info(f"Retrieving top-{top_k} images for query: '{query_text}'")
        # Encode the query and images
        text_features = self.encode_texts([query_text])
        image_features = self.encode_images(image_paths)
        # Compute similarities
        similarities = torch.matmul(text_features, image_features.T).squeeze(0)
        # Get the top-k results
        top_k_indices = similarities.argsort(descending=True)[:top_k]
        results = []
        for idx in top_k_indices:
            results.append((image_paths[idx.item()], similarities[idx.item()].item()))
        return results

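# Usage sketch for PaveCLIPEvaluator (commented; the checkpoint path, image file names,
# and prompt texts below are hypothetical examples):
#
#   evaluator = PaveCLIPEvaluator("./checkpoints/paveclip_best.pt", config={})
#   classes = ["a photo of asphalt pavement with a crack",
#              "a photo of intact asphalt pavement"]
#   results = evaluator.zero_shot_classification(["img_001.jpg", "img_002.jpg"], classes)
#   hits = evaluator.image_retrieval("pothole on asphalt", ["img_001.jpg", "img_002.jpg"], top_k=2)
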
def main():
    """Main training script"""
    parser = argparse.ArgumentParser(description="Train PaveCLIP model")
    # Model arguments
    parser.add_argument("--model_type", default="clip", choices=["clip", "siglip"],
                        help="Model type to train")
    parser.add_argument("--vision_model", default="vit-b/16",
                        help="Vision encoder (e.g., vit-b/16, vit-l/14@336, resnet50)")
    parser.add_argument("--text_model", default="bert-base-uncased",
                        help="Text encoder (e.g., bert-base-uncased, roberta-base)")
    parser.add_argument("--embed_dim", type=int, default=512,
                        help="Embedding dimension")
    parser.add_argument("--vision_pretrained", action="store_true",
                        help="Use a pretrained vision encoder")
    parser.add_argument("--text_pretrained", action="store_true",
                        help="Use a pretrained text encoder")
    # Data arguments
    parser.add_argument("--data_dir", required=True,
                        help="Path to the Pavement_Pretraining_Data directory")
    parser.add_argument("--val_split", type=float, default=0.1,
                        help="Validation split ratio")
    parser.add_argument("--max_length", type=int, default=77,
                        help="Maximum text length")
    # Training arguments
    parser.add_argument("--batch_size", type=int, default=64,
                        help="Batch size")
    parser.add_argument("--epochs", type=int, default=50,
                        help="Number of epochs")
    parser.add_argument("--learning_rate", type=float, default=1e-4,
                        help="Learning rate")
    parser.add_argument("--weight_decay", type=float, default=0.01,
                        help="Weight decay")
    parser.add_argument("--temperature", type=float, default=0.07,
                        help="Temperature parameter")
    parser.add_argument("--warmup_ratio", type=float, default=0.1,
                        help="Warmup ratio")
    # System arguments
    parser.add_argument("--num_workers", type=int, default=4,
                        help="Number of data loader workers")
    parser.add_argument("--output_dir", default="./checkpoints",
                        help="Output directory for checkpoints")
    parser.add_argument("--save_every", type=int, default=10,
                        help="Save a checkpoint every N epochs")
    parser.add_argument("--wandb", action="store_true",
                        help="Use Weights & Biases logging")
    parser.add_argument("--distributed", action="store_true",
                        help="Enable distributed training")
    args = parser.parse_args()
    # Convert args to a config dict
    config = vars(args)
    # Initialize the trainer
    trainer = PaveCLIPTrainer(config)
    # Start training
    trainer.train()
    # Cleanup distributed training
    if config.get("distributed", False):
        dist.destroy_process_group()


if __name__ == "__main__":
    main()

# Single-GPU example:
# python paveclip_training.py \
#   --vision_model vit-b/16 \
#   --text_model distilbert-base-uncased \
#   --vision_pretrained \
#   --text_pretrained \
#   --data_dir ./Pavement_Pretraining_Data \
#   --batch_size 64 \
#   --epochs 100 \
#   --wandb

# Multi-GPU example (4 GPUs on one node):
# torchrun --nproc_per_node=4 paveclip_training.py \
#   --distributed \
#   [other args]