""" PaveCLIP: Complete CLIP Training Framework for Pavement Data Supports ViT/ResNet encoders, BERT/custom text encoders, SigLIP, Multi-GPU training """ import os import json import torch import torch.nn as nn import torch.nn.functional as F import torch.distributed as dist from torch.utils.data import Dataset, DataLoader from torch.utils.data.distributed import DistributedSampler from torch.nn.parallel import DistributedDataParallel as DDP import torchvision.transforms as transforms from torchvision.models import resnet50, resnet101 import timm from transformers import AutoTokenizer, AutoModel, BertModel, RobertaModel from PIL import Image import numpy as np from pathlib import Path import matplotlib.pyplot as plt from sklearn.metrics.pairwise import cosine_similarity import logging from typing import Dict, List, Tuple, Optional, Union import argparse import time import wandb from tqdm import tqdm import warnings warnings.filterwarnings("ignore") # Setup logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) class PavementDataset(Dataset): """ Dataset loader for pavement pretraining data with complex folder structure """ def __init__(self, data_dir: str, transform=None, tokenizer=None, max_length=77): self.data_dir = Path(data_dir) self.transform = transform self.tokenizer = tokenizer self.max_length = max_length self.samples = [] logger.info(f"Loading dataset from {data_dir}") self._load_dataset() logger.info(f"Loaded {len(self.samples)} samples from {self._get_unique_images()} unique images") def _load_dataset(self): """Load all JSON files and collect image-text pairs""" json_files = list(self.data_dir.rglob("*.json")) for json_file in json_files: try: with open(json_file, 'r') as f: data = json.load(f) # Handle different JSON structures if isinstance(data, list): # List of samples for item in data: self._process_sample(item, json_file.parent) elif isinstance(data, dict): # Single sample or nested structure if "conversations" in data: self._process_sample(data, json_file.parent) else: # Check if it's a collection for key, value in data.items(): if isinstance(value, dict) and "conversations" in value: self._process_sample(value, json_file.parent) elif isinstance(value, list): for item in value: if isinstance(item, dict) and "conversations" in item: self._process_sample(item, json_file.parent) except Exception as e: logger.warning(f"Error loading {json_file}: {e}") def _process_sample(self, sample: dict, base_path: Path): """Process individual sample and extract image-text pair""" try: image_path = sample.get("image", "") conversations = sample.get("conversations", []) if not image_path or not conversations: return # Find text response from GPT text = "" for conv in conversations: if conv.get("from") == "gpt": text = conv.get("value", "") break if not text: return # Resolve image path (relative to base_path) full_image_path = base_path / image_path if not full_image_path.exists(): # Try different relative paths for possible_base in [base_path, base_path.parent, base_path.parent.parent]: test_path = possible_base / image_path if test_path.exists(): full_image_path = test_path break if full_image_path.exists(): self.samples.append({ "image_path": str(full_image_path), "text": text.strip(), "id": sample.get("id", f"sample_{len(self.samples)}") }) except Exception as e: logger.warning(f"Error processing sample: {e}") def _get_unique_images(self): """Get count of unique images""" return len(set(sample["image_path"] for sample in self.samples)) def __len__(self): return len(self.samples) def __getitem__(self, idx): sample = self.samples[idx] # Load and transform image try: image = Image.open(sample["image_path"]).convert("RGB") if self.transform: image = self.transform(image) except Exception as e: logger.warning(f"Error loading image {sample['image_path']}: {e}") # Return a black image as fallback image = torch.zeros(3, 224, 224) # Tokenize text text = sample["text"] if self.tokenizer: tokens = self.tokenizer( text, max_length=self.max_length, padding='max_length', truncation=True, return_tensors='pt' ) return { "image": image, "input_ids": tokens["input_ids"].squeeze(), "attention_mask": tokens["attention_mask"].squeeze(), "text": text } else: return { "image": image, "text": text } class VisionEncoder(nn.Module): """Flexible vision encoder supporting ViT and ResNet architectures""" def __init__(self, model_name: str, embed_dim: int = 512, pretrained: bool = True): super().__init__() self.model_name = model_name self.embed_dim = embed_dim self.expected_image_size = 224 # Default # Try to determine architecture type if any(arch in model_name.lower() for arch in ["vit", "deit", "swin", "beit", "cait"]): self._setup_vit(model_name, pretrained) elif "resnet" in model_name.lower(): self._setup_resnet(model_name, pretrained) else: # 🔧 GENERIC TIMM MODEL LOADING self._setup_generic_timm(model_name, pretrained) # Projection head self.projection = nn.Linear(self.feature_dim, embed_dim) def _setup_generic_timm(self, model_name: str, pretrained: bool): """Setup any TIMM model generically""" try: self.backbone = timm.create_model( model_name, pretrained=pretrained, num_classes=0 # Remove classification head ) # Auto-detect input size and feature dimension self.feature_dim = None test_sizes = [224, 288, 336, 384, 448, 512] for test_size in test_sizes: try: with torch.no_grad(): dummy_input = torch.randn(1, 3, test_size, test_size) features = self.backbone(dummy_input) # Handle different output formats if len(features.shape) > 2: features = features.view(features.size(0), -1) self.feature_dim = features.shape[1] self.expected_image_size = test_size logger.info(f"Generic model {model_name} expects {test_size}x{test_size} → {self.feature_dim}D") break except Exception: continue if self.feature_dim is None: raise Exception("Could not determine model specifications") except Exception as e: logger.error(f"Failed to load {model_name}: {e}") raise def _setup_vit(self, model_name: str, pretrained: bool): """Setup Vision Transformer - works with any TIMM ViT model""" # Known mappings for convenience vit_mapping = { "vit-b/16": "vit_base_patch16_224", "vit-b/32": "vit_base_patch32_224", "vit-l/14": "vit_large_patch14_224", "vit-l/14@336": "vit_large_patch14_clip_336", "vit-h/14": "vit_huge_patch14_clip_378" } # Use mapping if available, otherwise use model name directly timm_name = vit_mapping.get(model_name.lower(), model_name) try: self.backbone = timm.create_model( timm_name, pretrained=pretrained, num_classes=0 ) # 🔧 AUTO-DETECT input size by trying common sizes self.feature_dim = None test_sizes = [224, 336, 378, 384, 512] # Common ViT sizes for test_size in test_sizes: try: with torch.no_grad(): dummy_input = torch.randn(1, 3, test_size, test_size) features = self.backbone(dummy_input) self.feature_dim = features.shape[1] self.expected_image_size = test_size logger.info(f"Model {timm_name} expects {test_size}x{test_size} input") break except Exception: continue if self.feature_dim is None: raise Exception("Could not determine input size for model") except Exception as e: logger.warning(f"Failed to load {timm_name}: {e}") logger.warning("Falling back to basic ViT") self.backbone = timm.create_model("vit_base_patch16_224", pretrained=pretrained, num_classes=0) self.feature_dim = 768 self.expected_image_size = 224 def _setup_resnet(self, model_name: str, pretrained: bool): """Setup ResNet""" if "resnet50" in model_name.lower(): self.backbone = resnet50(pretrained=pretrained) elif "resnet101" in model_name.lower(): self.backbone = resnet101(pretrained=pretrained) else: self.backbone = resnet50(pretrained=pretrained) # Remove classification head self.backbone = nn.Sequential(*list(self.backbone.children())[:-1]) self.feature_dim = 2048 # ResNet feature dimension def forward(self, x): features = self.backbone(x) if len(features.shape) > 2: features = features.view(features.size(0), -1) return self.projection(features) class TextEncoder(nn.Module): """Flexible text encoder supporting various transformer models""" def __init__(self, model_name: str = "bert-base-uncased", embed_dim: int = 512, max_length: int = 77, pretrained: bool = True): super().__init__() self.model_name = model_name self.embed_dim = embed_dim self.max_length = max_length if not pretrained: # Initialize from scratch if "bert" in model_name: from transformers import BertConfig config = BertConfig(vocab_size=30522, max_position_embeddings=max_length) self.transformer = BertModel(config) else: self.transformer = AutoModel.from_pretrained(model_name, ignore_mismatched_sizes=True) else: self.transformer = AutoModel.from_pretrained(model_name) # Get hidden dimension self.hidden_dim = self.transformer.config.hidden_size # Projection head self.projection = nn.Linear(self.hidden_dim, embed_dim) def forward(self, input_ids, attention_mask=None): outputs = self.transformer(input_ids=input_ids, attention_mask=attention_mask) # Use [CLS] token or mean pooling if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None: features = outputs.pooler_output else: # Mean pooling over sequence length features = outputs.last_hidden_state.mean(dim=1) return self.projection(features) class CLIPModel(nn.Module): """CLIP model with contrastive learning""" def __init__(self, vision_model: str, text_model: str, embed_dim: int = 512, temperature: float = 0.07, vision_pretrained: bool = True, text_pretrained: bool = True): super().__init__() self.vision_encoder = VisionEncoder(vision_model, embed_dim, vision_pretrained) self.text_encoder = TextEncoder(text_model, embed_dim, pretrained=text_pretrained) # Temperature parameter for contrastive loss self.temperature = nn.Parameter(torch.tensor(temperature)) def forward(self, images, input_ids, attention_mask=None): # Encode images and text image_features = self.vision_encoder(images) text_features = self.text_encoder(input_ids, attention_mask) # Normalize features image_features = F.normalize(image_features, p=2, dim=1) text_features = F.normalize(text_features, p=2, dim=1) return image_features, text_features def compute_loss(self, image_features, text_features): """Compute contrastive loss""" batch_size = image_features.shape[0] # Compute similarity matrix logits = torch.matmul(image_features, text_features.T) / self.temperature # Labels are diagonal (each image matches its corresponding text) labels = torch.arange(batch_size, device=logits.device) # Compute cross-entropy loss for both directions loss_img = F.cross_entropy(logits, labels) loss_txt = F.cross_entropy(logits.T, labels) return (loss_img + loss_txt) / 2 class SigLIPModel(nn.Module): """SigLIP model with sigmoid loss instead of contrastive loss""" def __init__(self, vision_model: str, text_model: str, embed_dim: int = 512, temperature: float = 0.07, vision_pretrained: bool = True, text_pretrained: bool = True): super().__init__() self.vision_encoder = VisionEncoder(vision_model, embed_dim, vision_pretrained) self.text_encoder = TextEncoder(text_model, embed_dim, pretrained=text_pretrained) # Temperature parameter self.temperature = nn.Parameter(torch.tensor(temperature)) def forward(self, images, input_ids, attention_mask=None): # Encode images and text image_features = self.vision_encoder(images) text_features = self.text_encoder(input_ids, attention_mask) # Normalize features image_features = F.normalize(image_features, p=2, dim=1) text_features = F.normalize(text_features, p=2, dim=1) return image_features, text_features def compute_loss(self, image_features, text_features): """Compute SigLIP loss""" batch_size = image_features.shape[0] # Compute similarity matrix logits = torch.matmul(image_features, text_features.T) / self.temperature # Create positive and negative labels labels = torch.eye(batch_size, device=logits.device) labels = labels * 2 - 1 # Convert to -1/1 labels # SigLIP loss: -log(sigmoid(z_i * y_i)) loss = -F.logsigmoid(logits * labels).mean() return loss class PaveCLIPTrainer: """Complete training framework for PaveCLIP""" def __init__(self, config: Dict): self.config = config self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.distributed = False self.rank = 0 # Setup distributed training if specified if config.get("distributed", False): self._setup_distributed() # Initialize model self._setup_model() # Setup data self._setup_data() # Setup optimization self._setup_optimization() # Setup logging if config.get("wandb", False) and (not self.distributed or self.rank == 0): wandb.init(project="paveclip", config=config) def _setup_distributed(self): """Setup distributed training""" self.distributed = True self.rank = int(os.environ.get("LOCAL_RANK", 0)) self.world_size = int(os.environ.get("WORLD_SIZE", 1)) dist.init_process_group(backend="nccl") torch.cuda.set_device(self.rank) self.device = torch.device(f"cuda:{self.rank}") logger.info(f"Initialized distributed training: rank {self.rank}/{self.world_size}") def _setup_model(self): """Initialize the model""" model_type = self.config.get("model_type", "clip").lower() if model_type == "clip": self.model = CLIPModel( vision_model=self.config["vision_model"], text_model=self.config["text_model"], embed_dim=self.config.get("embed_dim", 512), temperature=self.config.get("temperature", 0.07), vision_pretrained=self.config.get("vision_pretrained", True), text_pretrained=self.config.get("text_pretrained", True) ) elif model_type == "siglip": self.model = SigLIPModel( vision_model=self.config["vision_model"], text_model=self.config["text_model"], embed_dim=self.config.get("embed_dim", 512), temperature=self.config.get("temperature", 0.07), vision_pretrained=self.config.get("vision_pretrained", True), text_pretrained=self.config.get("text_pretrained", True) ) else: raise ValueError(f"Unsupported model type: {model_type}") self.model = self.model.to(self.device) # Wrap with DDP for distributed training if hasattr(self, 'distributed') and self.distributed: self.model = DDP(self.model, device_ids=[self.rank]) def _setup_data(self): """Setup data loaders""" # Image transforms if "vit" in self.config["vision_model"].lower(): image_size = 336 if "@336" in self.config["vision_model"] else 224 else: image_size = 224 # Pavement-specific augmentations for robustness train_transform = transforms.Compose([ transforms.Resize((image_size, image_size)), transforms.RandomHorizontalFlip(p=0.5), transforms.RandomRotation(degrees=15), # Roads can be at angles transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1, hue=0.05), transforms.RandomGrayscale(p=0.1), # Some pavement images are grayscale transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) val_transform = transforms.Compose([ transforms.Resize((image_size, image_size)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) # Tokenizer from transformers import AutoTokenizer self.tokenizer = AutoTokenizer.from_pretrained(self.config["text_model"]) if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token # Dataset train_dataset = PavementDataset( self.config["data_dir"], transform=train_transform, tokenizer=self.tokenizer, max_length=self.config.get("max_length", 77) ) # Split for validation if specified if self.config.get("val_split", 0.1) > 0: val_size = int(len(train_dataset) * self.config["val_split"]) train_size = len(train_dataset) - val_size train_dataset, val_dataset = torch.utils.data.random_split( train_dataset, [train_size, val_size] ) val_dataset.dataset.transform = val_transform else: val_dataset = None # Data loaders train_sampler = DistributedSampler(train_dataset) if hasattr(self, 'distributed') and self.distributed else None self.train_loader = DataLoader( train_dataset, batch_size=self.config["batch_size"], shuffle=(train_sampler is None), sampler=train_sampler, num_workers=self.config.get("num_workers", 4), pin_memory=True, drop_last=True ) if val_dataset: val_sampler = DistributedSampler(val_dataset) if hasattr(self, 'distributed') and self.distributed else None self.val_loader = DataLoader( val_dataset, batch_size=self.config["batch_size"], shuffle=False, sampler=val_sampler, num_workers=self.config.get("num_workers", 4), pin_memory=True ) else: self.val_loader = None logger.info(f"Training samples: {len(train_dataset)}") if val_dataset: logger.info(f"Validation samples: {len(val_dataset)}") def _setup_optimization(self): """Setup optimizer and scheduler""" # Pavement-specific optimization strategy # Different learning rates for vision and text encoders vision_params = [] text_params = [] other_params = [] model = self.model.module if hasattr(self.model, 'module') else self.model for name, param in model.named_parameters(): if 'vision_encoder' in name: vision_params.append(param) elif 'text_encoder' in name: text_params.append(param) else: other_params.append(param) # Different learning rates for different components param_groups = [ {'params': vision_params, 'lr': self.config["learning_rate"] * 0.1}, # Lower LR for vision {'params': text_params, 'lr': self.config["learning_rate"]}, # Standard LR for text {'params': other_params, 'lr': self.config["learning_rate"]} # Standard LR for others ] self.optimizer = torch.optim.AdamW( param_groups, weight_decay=self.config.get("weight_decay", 0.01) ) # Learning rate scheduler total_steps = len(self.train_loader) * self.config["epochs"] warmup_steps = int(total_steps * self.config.get("warmup_ratio", 0.1)) self.scheduler = torch.optim.lr_scheduler.OneCycleLR( self.optimizer, max_lr=[group['lr'] for group in param_groups], total_steps=total_steps, pct_start=warmup_steps / total_steps, anneal_strategy='cos' ) def train_epoch(self, epoch: int): """Train for one epoch""" self.model.train() if hasattr(self, 'distributed') and self.distributed: self.train_loader.sampler.set_epoch(epoch) total_loss = 0 num_batches = len(self.train_loader) pbar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}") if (not hasattr(self, 'distributed') or self.rank == 0) else self.train_loader for batch_idx, batch in enumerate(pbar): images = batch["image"].to(self.device, non_blocking=True) input_ids = batch["input_ids"].to(self.device, non_blocking=True) attention_mask = batch["attention_mask"].to(self.device, non_blocking=True) # Forward pass image_features, text_features = self.model(images, input_ids, attention_mask) # Compute loss loss = self.model.module.compute_loss(image_features, text_features) if hasattr(self.model, 'module') else self.model.compute_loss(image_features, text_features) # Backward pass self.optimizer.zero_grad() loss.backward() # Gradient clipping for stability torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) self.optimizer.step() self.scheduler.step() total_loss += loss.item() # Update progress bar if hasattr(pbar, 'set_postfix'): pbar.set_postfix({ 'loss': f'{loss.item():.4f}', 'avg_loss': f'{total_loss/(batch_idx+1):.4f}', 'lr': f'{self.scheduler.get_last_lr()[0]:.2e}' }) # Log to wandb if self.config.get("wandb", False) and (not hasattr(self, 'distributed') or self.rank == 0): wandb.log({ "train_loss": loss.item(), "learning_rate": self.scheduler.get_last_lr()[0], "epoch": epoch, "step": epoch * num_batches + batch_idx }) return total_loss / num_batches def validate(self, epoch: int): """Validate the model""" if self.val_loader is None: return None self.model.eval() total_loss = 0 with torch.no_grad(): for batch in self.val_loader: images = batch["image"].to(self.device, non_blocking=True) input_ids = batch["input_ids"].to(self.device, non_blocking=True) attention_mask = batch["attention_mask"].to(self.device, non_blocking=True) # Forward pass image_features, text_features = self.model(images, input_ids, attention_mask) # Compute loss loss = self.model.module.compute_loss(image_features, text_features) if hasattr(self.model, 'module') else self.model.compute_loss(image_features, text_features) total_loss += loss.item() avg_loss = total_loss / len(self.val_loader) if self.config.get("wandb", False) and (not hasattr(self, 'distributed') or self.rank == 0): wandb.log({ "val_loss": avg_loss, "epoch": epoch }) return avg_loss def train(self): """Main training loop""" logger.info("Starting training...") best_val_loss = float('inf') for epoch in range(self.config["epochs"]): # Train train_loss = self.train_epoch(epoch) # Validate val_loss = self.validate(epoch) # Log epoch results if not hasattr(self, 'distributed') or self.rank == 0: logger.info(f"Epoch {epoch+1}/{self.config['epochs']}") logger.info(f"Train Loss: {train_loss:.4f}") if val_loss is not None: logger.info(f"Val Loss: {val_loss:.4f}") # Save checkpoint if (not hasattr(self, 'distributed') or self.rank == 0) and val_loss is not None and val_loss < best_val_loss: best_val_loss = val_loss self.save_checkpoint(epoch, is_best=True) # Regular checkpoint if (epoch + 1) % self.config.get("save_every", 10) == 0: if not hasattr(self, 'distributed') or self.rank == 0: self.save_checkpoint(epoch, is_best=False) def save_checkpoint(self, epoch: int, is_best: bool = False): """Save model checkpoint""" model_state = self.model.module.state_dict() if hasattr(self.model, 'module') else self.model.state_dict() checkpoint = { 'epoch': epoch, 'model_state_dict': model_state, 'optimizer_state_dict': self.optimizer.state_dict(), 'config': self.config } filename = f"paveclip_epoch_{epoch+1}.pt" if is_best: filename = "paveclip_best.pt" save_path = Path(self.config["output_dir"]) / filename save_path.parent.mkdir(parents=True, exist_ok=True) torch.save(checkpoint, save_path) logger.info(f"Saved checkpoint: {save_path}") class PaveCLIPEvaluator: """Evaluation utilities for PaveCLIP""" def __init__(self, model_path: str, config: Dict): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.config = config # Load model checkpoint = torch.load(model_path, map_location=self.device) model_config = checkpoint['config'] # Initialize model if model_config.get("model_type", "clip").lower() == "clip": self.model = CLIPModel( vision_model=model_config["vision_model"], text_model=model_config["text_model"], embed_dim=model_config.get("embed_dim", 512) ) else: self.model = SigLIPModel( vision_model=model_config["vision_model"], text_model=model_config["text_model"], embed_dim=model_config.get("embed_dim", 512) ) self.model.load_state_dict(checkpoint['model_state_dict']) self.model = self.model.to(self.device) self.model.eval() # Setup tokenizer and transforms from transformers import AutoTokenizer self.tokenizer = AutoTokenizer.from_pretrained(model_config["text_model"]) if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token # Image transforms #image_size = 336 if "@336" in model_config["vision_model"] else 224 expected = getattr(self.model.vision_encoder, "expected_image_size", 224) self.transform = transforms.Compose([ transforms.Resize((expected, expected)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) self.image_size = expected # keep for later use def encode_images(self, image_paths: List[str]) -> torch.Tensor: """Encode list of images""" features = [] with torch.no_grad(): for img_path in image_paths: image = Image.open(img_path).convert("RGB") image = self.transform(image).unsqueeze(0).to(self.device) img_features, _ = self.model(image, torch.zeros(1, 1).long().to(self.device)) features.append(img_features.cpu()) return torch.cat(features, dim=0) def encode_texts(self, texts: List[str]) -> torch.Tensor: """Encode list of texts""" tokens = self.tokenizer( texts, max_length=77, padding='max_length', truncation=True, return_tensors='pt' ) # with torch.no_grad(): # tokens = {k: v.to(self.device) for k, v in tokens.items()} # dummy_images = torch.zeros(len(texts), 3, 224, 224).to(self.device) # _, text_features = self.model(dummy_images, tokens["input_ids"], tokens["attention_mask"]) # In PaveCLIPEvaluator.encode_texts with torch.no_grad(): tokens = {k: v.to(self.device) for k, v in tokens.items()} text_features = self.model.text_encoder(tokens["input_ids"], tokens["attention_mask"]) text_features = F.normalize(text_features, p=2, dim=1) return text_features.cpu() def zero_shot_classification(self, image_paths: List[str], class_texts: List[str]) -> Dict: """Perform zero-shot classification""" logger.info("Performing zero-shot classification...") # Encode images and texts image_features = self.encode_images(image_paths) text_features = self.encode_texts(class_texts) # Compute similarities similarities = torch.matmul(image_features, text_features.T) predictions = similarities.argmax(dim=1) # Compute accuracy if ground truth is available results = { "predictions": predictions.tolist(), "similarities": similarities.tolist(), "class_texts": class_texts } return results def image_retrieval(self, query_text: str, image_paths: List[str], top_k: int = 5) -> List[Tuple[str, float]]: """Retrieve top-k images for a text query""" logger.info(f"Retrieving top-{top_k} images for query: '{query_text}'") # Encode query and images text_features = self.encode_texts([query_text]) image_features = self.encode_images(image_paths) # Compute similarities similarities = torch.matmul(text_features, image_features.T).squeeze() # Get top-k results top_k_indices = similarities.argsort(descending=True)[:top_k] results = [] for idx in top_k_indices: results.append((image_paths[idx.item()], similarities[idx.item()].item())) return results def main(): """Main training script""" parser = argparse.ArgumentParser(description="Train PaveCLIP model") # Model arguments parser.add_argument("--model_type", default="clip", choices=["clip", "siglip"], help="Model type to train") parser.add_argument("--vision_model", default="vit-b/16", help="Vision encoder (e.g., vit-b/16, vit-l/14@336, resnet50)") parser.add_argument("--text_model", default="bert-base-uncased", help="Text encoder (e.g., bert-base-uncased, roberta-base)") parser.add_argument("--embed_dim", type=int, default=512, help="Embedding dimension") parser.add_argument("--vision_pretrained", action="store_true", help="Use pretrained vision encoder") parser.add_argument("--text_pretrained", action="store_true", help="Use pretrained text encoder") # Data arguments parser.add_argument("--data_dir", required=True, help="Path to Pavement_Pretraining_Data directory") parser.add_argument("--val_split", type=float, default=0.1, help="Validation split ratio") parser.add_argument("--max_length", type=int, default=77, help="Maximum text length") # Training arguments parser.add_argument("--batch_size", type=int, default=64, help="Batch size") parser.add_argument("--epochs", type=int, default=50, help="Number of epochs") parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate") parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay") parser.add_argument("--temperature", type=float, default=0.07, help="Temperature parameter") parser.add_argument("--warmup_ratio", type=float, default=0.1, help="Warmup ratio") # System arguments parser.add_argument("--num_workers", type=int, default=4, help="Number of data loader workers") parser.add_argument("--output_dir", default="./checkpoints", help="Output directory for checkpoints") parser.add_argument("--save_every", type=int, default=10, help="Save checkpoint every N epochs") parser.add_argument("--wandb", action="store_true", help="Use Weights & Biases logging") parser.add_argument("--distributed", action="store_true", help="Enable distributed training") args = parser.parse_args() # Convert args to config dict config = vars(args) # Initialize trainer trainer = PaveCLIPTrainer(config) # Start training trainer.train() # Cleanup distributed training if config.get("distributed", False): dist.destroy_process_group() if __name__ == "__main__": main() # python paveclip_training.py \ # --vision_model vit-b/16 \ # --text_model distilbert-base-uncased \ # --vision_pretrained \ # --text_pretrained \ # --data_dir ./Pavement_Pretraining_Data \ # --batch_size 64 \ # --epochs 100 \ # --wandb # torchrun --nproc_per_node=4 paveclip_training.py \ # --distributed \ # [other args]