blessing.agyeikyem committed on
Commit 4dc7e79 · 1 Parent(s): 9cc7dcf

Deploy space without large model file

app.py ADDED
@@ -0,0 +1,216 @@
1
+ import gradio as gr
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from PIL import Image
6
+ import matplotlib.pyplot as plt
7
+ import seaborn as sns
8
+ import pandas as pd
9
+ import io
10
+ import base64
11
+ from sklearn.manifold import TSNE
12
+ from sklearn.decomposition import PCA
13
+ import plotly.express as px
14
+ import plotly.graph_objects as go
15
+ from datetime import datetime
16
+ import json
17
+ import os
18
+ import tempfile
19
+ import zipfile
20
+ import huggingface_hub
21
+ from huggingface_hub import hf_hub_download
22
+
23
+ # Import your PaveCLIP model (adjust import based on your model structure)
24
+ from paveclip_training import PaveCLIPEvaluator
25
+
26
+ # Download model from Hugging Face Hub if needed
27
+ def download_model():
28
+ """Download model from Hugging Face Hub"""
29
+ try:
30
+ # Replace with your actual model repository
31
+ model_path = hf_hub_download(
32
+ repo_id="your-username/paveclip-model", # Update this
33
+ filename="paveclip_best.pt"
34
+ )
35
+ return model_path
36
+ except Exception:
37
+ # Fallback to local path if available
38
+ return "./paveclip_best.pt"
39
+
40
+ def download_model_from_hf():
41
+ """Download model from separate HF model repository"""
42
+ try:
43
+ print("📥 Downloading PaveCLIP model...")
44
+ model_path = hf_hub_download(
45
+ repo_id="Blessing988/paveclip-model", # Your model repo
46
+ filename="paveclip_best.pt",
47
+ cache_dir="./models"
48
+ )
49
+ print("✅ Model downloaded successfully!")
50
+ return model_path
51
+ except Exception as e:
52
+ print(f"❌ Download failed: {e}")
53
+ return None
54
+
55
+ class PavementAnalysisApp:
56
+ def __init__(self, model_path):
57
+ """Initialize the Pavement Analysis App"""
58
+ # Prefer an existing local checkpoint; otherwise pull it from the Hub
+ model_path = model_path if (model_path and os.path.exists(model_path)) else download_model_from_hf()
59
+
60
+ if not model_path:
+ raise RuntimeError("PaveCLIP checkpoint could not be found locally or downloaded from the Hub")
+ self.evaluator = PaveCLIPEvaluator(model_path, {})
62
+
63
+ # Pavement-specific class definitions
64
+ self.distress_classes = [
65
+ "pavement with longitudinal crack",
66
+ "pavement with lateral crack",
67
+ "pavement with alligator crack",
68
+ "pavement with pothole",
69
+ "road with patching"
70
+ ]
71
+
72
+ self.material_classes = [
73
+ "asphalt road surface",
74
+ "wet asphalt surface",
75
+ "wet concrete surface",
76
+ "concrete road surface",
77
+ "gravel road surface",
78
+ "dry and smooth asphalt surface"
79
+ ]
80
+
81
+ self.condition_classes = [
82
+ "smooth road surface",
83
+ "slightly uneven road surface",
84
+ "severely damaged road surface",
85
+ "well-maintained pavement",
86
+ "deteriorated pavement"
87
+ ]
88
+
89
+ # Store embeddings for comparison
90
+ self.image_embeddings = {}
91
+ self.text_embeddings = {}
92
+
93
+ def analyze_single_image(self, image, analysis_type="all"):
94
+ """Analyze a single uploaded image"""
95
+ if image is None:
96
+ return "Please upload an image first.", {}, {}, {}
97
+
98
+ # Save temporary image
99
+ temp_path = "temp_image.jpg"
100
+ image.save(temp_path)
101
+
102
+ results = {}
103
+
104
+ try:
105
+ if analysis_type in ["distress", "all"]:
106
+ distress_result = self.evaluator.zero_shot_classification([temp_path], self.distress_classes)
107
+ results["distress"] = self._format_results(distress_result, self.distress_classes)
108
+
109
+ if analysis_type in ["material", "all"]:
110
+ material_result = self.evaluator.zero_shot_classification([temp_path], self.material_classes)
111
+ results["material"] = self._format_results(material_result, self.material_classes)
112
+
113
+ if analysis_type in ["condition", "all"]:
114
+ condition_result = self.evaluator.zero_shot_classification([temp_path], self.condition_classes)
115
+ results["condition"] = self._format_results(condition_result, self.condition_classes)
116
+
117
+ # Generate summary text
118
+ summary = self._generate_summary(results)
119
+
120
+ # Clean up
121
+ os.remove(temp_path)
122
+
123
+ return summary, results.get("distress", {}), results.get("material", {}), results.get("condition", {})
124
+
125
+ except Exception as e:
126
+ if os.path.exists(temp_path): os.remove(temp_path)
127
+ return f"Error analyzing image: {str(e)}", {}, {}, {}
128
+
129
+ # Helper methods for formatting and summarizing classification results
130
+
131
+ def _format_results(self, result, class_names):
132
+ """Format classification results for display"""
133
+ predictions = result["predictions"]
134
+ similarities = result["similarities"]
135
+
136
+ formatted = {}
137
+ for i, class_name in enumerate(class_names):
138
+ confidence = float(similarities[0][i])
139
+ formatted[class_name] = confidence
140
+
141
+ return formatted
142
+
143
+ def _generate_summary(self, results):
144
+ """Generate text summary of analysis"""
145
+ summary_parts = ["🔍 **Pavement Analysis Results**\n"]
146
+
147
+ for category, result in results.items():
148
+ if result:
149
+ best_match = max(result.items(), key=lambda x: x[1])
150
+ category_name = category.capitalize()
151
+ summary_parts.append(f"**{category_name}:** {best_match[0]} (confidence: {best_match[1]:.3f})")
152
+
153
+ return "\n".join(summary_parts)
154
+
155
+ def create_demo():
156
+ """Create the Gradio demo"""
157
+
158
+ # Download/load model
159
+ model_path = download_model()
160
+ app = PavementAnalysisApp(model_path)
161
+
162
+ # Create interface
163
+ with gr.Blocks(title="🛣️ PaveCLIP: Advanced Pavement Analysis") as demo:
164
+
165
+ gr.Markdown("""
166
+ # 🛣️ PaveCLIP: Advanced Pavement Analysis Platform
167
+
168
+ **Professional pavement condition assessment using state-of-the-art computer vision**
169
+
170
+ Upload pavement images to get comprehensive analysis including distress detection,
171
+ material classification, and condition assessment.
172
+ """)
173
+
174
+ with gr.Tab("🖼️ Single Image Analysis"):
175
+ with gr.Row():
176
+ with gr.Column():
177
+ input_image = gr.Image(type="pil", label="Upload Pavement Image")
178
+ analysis_type = gr.Radio(
179
+ choices=["all", "distress", "material", "condition"],
180
+ value="all",
181
+ label="Analysis Type"
182
+ )
183
+ analyze_btn = gr.Button("🔍 Analyze Image", variant="primary")
184
+
185
+ with gr.Column():
186
+ analysis_summary = gr.Markdown(label="Analysis Summary")
187
+
188
+ with gr.Row():
189
+ distress_output = gr.JSON(label="Distress Classification")
190
+ material_output = gr.JSON(label="Material Classification")
191
+ condition_output = gr.JSON(label="Condition Assessment")
192
+
193
+ analyze_btn.click(
194
+ fn=app.analyze_single_image,
195
+ inputs=[input_image, analysis_type],
196
+ outputs=[analysis_summary, distress_output, material_output, condition_output]
197
+ )
198
+
199
+ # Add examples
200
+ gr.Examples(
201
+ examples=[
202
+ ["examples/202202122309381-dry-asphalt-severe.jpg", "all"],
203
+ ["examples/202202122342019-dry-concrete-slight.jpg", "material"],
204
+ ["examples/202205031731377-wet-concrete-severe.jpg", "condition"]
205
+ ],
206
+ inputs=[input_image, analysis_type],
207
+ outputs=[analysis_summary, distress_output, material_output, condition_output],
208
+ fn=app.analyze_single_image,
209
+ cache_examples=True
210
+ )
211
+
212
+ return demo
213
+
214
+ if __name__ == "__main__":
215
+ demo = create_demo()
216
+ demo.launch()
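Note: the Space ships without the checkpoint itself; download_model_from_hf() above fetches paveclip_best.pt from the separate model repository Blessing988/paveclip-model at startup. Below is a minimal sketch of the matching one-time upload step, assuming the trained checkpoint sits locally at ./checkpoints/paveclip_best.pt and a write token is exported as HF_TOKEN (both are assumptions, not part of this commit):

import os
from huggingface_hub import HfApi, create_repo

# Create the model repo if it does not exist yet, then push the checkpoint into it.
create_repo("Blessing988/paveclip-model", repo_type="model", exist_ok=True, token=os.environ["HF_TOKEN"])
HfApi().upload_file(
    path_or_fileobj="./checkpoints/paveclip_best.pt",  # local checkpoint location (assumed)
    path_in_repo="paveclip_best.pt",                   # must match the filename requested in app.py
    repo_id="Blessing988/paveclip-model",
    repo_type="model",
    token=os.environ["HF_TOKEN"],
)

Keeping the weights in a model repository keeps this Space repository small, which is exactly what the commit message describes.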
examples/202202122309381-dry-asphalt-severe.jpg ADDED
examples/202202122342019-dry-concrete-slight.jpg ADDED
examples/202205031731377-wet-concrete-severe.jpg ADDED
paveclip_training.py ADDED
@@ -0,0 +1,958 @@
1
+ """
2
+ PaveCLIP: Complete CLIP Training Framework for Pavement Data
3
+ Supports ViT/ResNet encoders, BERT/custom text encoders, SigLIP, Multi-GPU training
4
+ """
5
+
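+ # Illustrative config dict consumed by PaveCLIPTrainer below. The values simply mirror the argparse
+ # defaults and the example command at the bottom of this file; a real run builds this via the CLI.
+ # config = {
+ #     "model_type": "clip", "vision_model": "vit-b/16", "text_model": "bert-base-uncased",
+ #     "embed_dim": 512, "temperature": 0.07, "vision_pretrained": True, "text_pretrained": True,
+ #     "data_dir": "./Pavement_Pretraining_Data", "val_split": 0.1, "max_length": 77,
+ #     "batch_size": 64, "epochs": 50, "learning_rate": 1e-4, "weight_decay": 0.01,
+ #     "warmup_ratio": 0.1, "num_workers": 4, "output_dir": "./checkpoints", "save_every": 10,
+ #     "wandb": False, "distributed": False,
+ # }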
6
+ import copy
+ import os
7
+ import json
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import torch.distributed as dist
12
+ from torch.utils.data import Dataset, DataLoader
13
+ from torch.utils.data.distributed import DistributedSampler
14
+ from torch.nn.parallel import DistributedDataParallel as DDP
15
+ import torchvision.transforms as transforms
16
+ from torchvision.models import resnet50, resnet101
17
+ import timm
18
+ from transformers import AutoTokenizer, AutoModel, BertModel, RobertaModel
19
+ from PIL import Image
20
+ import numpy as np
21
+ from pathlib import Path
22
+ import matplotlib.pyplot as plt
23
+ from sklearn.metrics.pairwise import cosine_similarity
24
+ import logging
25
+ from typing import Dict, List, Tuple, Optional, Union
26
+ import argparse
27
+ import time
28
+ import wandb
29
+ from tqdm import tqdm
30
+ import warnings
31
+ warnings.filterwarnings("ignore")
32
+
33
+ # Setup logging
34
+ logging.basicConfig(level=logging.INFO)
35
+ logger = logging.getLogger(__name__)
36
+
37
+ class PavementDataset(Dataset):
38
+ """
39
+ Dataset loader for pavement pretraining data with complex folder structure
40
+ """
41
+
42
+ def __init__(self, data_dir: str, transform=None, tokenizer=None, max_length=77):
43
+ self.data_dir = Path(data_dir)
44
+ self.transform = transform
45
+ self.tokenizer = tokenizer
46
+ self.max_length = max_length
47
+ self.samples = []
48
+
49
+ logger.info(f"Loading dataset from {data_dir}")
50
+ self._load_dataset()
51
+ logger.info(f"Loaded {len(self.samples)} samples from {self._get_unique_images()} unique images")
52
+
53
+ def _load_dataset(self):
54
+ """Load all JSON files and collect image-text pairs"""
55
+ json_files = list(self.data_dir.rglob("*.json"))
56
+
57
+ for json_file in json_files:
58
+ try:
59
+ with open(json_file, 'r') as f:
60
+ data = json.load(f)
61
+
62
+ # Handle different JSON structures
63
+ if isinstance(data, list):
64
+ # List of samples
65
+ for item in data:
66
+ self._process_sample(item, json_file.parent)
67
+ elif isinstance(data, dict):
68
+ # Single sample or nested structure
69
+ if "conversations" in data:
70
+ self._process_sample(data, json_file.parent)
71
+ else:
72
+ # Check if it's a collection
73
+ for key, value in data.items():
74
+ if isinstance(value, dict) and "conversations" in value:
75
+ self._process_sample(value, json_file.parent)
76
+ elif isinstance(value, list):
77
+ for item in value:
78
+ if isinstance(item, dict) and "conversations" in item:
79
+ self._process_sample(item, json_file.parent)
80
+
81
+ except Exception as e:
82
+ logger.warning(f"Error loading {json_file}: {e}")
83
+
84
+ def _process_sample(self, sample: dict, base_path: Path):
85
+ """Process individual sample and extract image-text pair"""
86
+ try:
87
+ image_path = sample.get("image", "")
88
+ conversations = sample.get("conversations", [])
89
+
90
+ if not image_path or not conversations:
91
+ return
92
+
93
+ # Find text response from GPT
94
+ text = ""
95
+ for conv in conversations:
96
+ if conv.get("from") == "gpt":
97
+ text = conv.get("value", "")
98
+ break
99
+
100
+ if not text:
101
+ return
102
+
103
+ # Resolve image path (relative to base_path)
104
+ full_image_path = base_path / image_path
105
+ if not full_image_path.exists():
106
+ # Try different relative paths
107
+ for possible_base in [base_path, base_path.parent, base_path.parent.parent]:
108
+ test_path = possible_base / image_path
109
+ if test_path.exists():
110
+ full_image_path = test_path
111
+ break
112
+
113
+ if full_image_path.exists():
114
+ self.samples.append({
115
+ "image_path": str(full_image_path),
116
+ "text": text.strip(),
117
+ "id": sample.get("id", f"sample_{len(self.samples)}")
118
+ })
119
+
120
+ except Exception as e:
121
+ logger.warning(f"Error processing sample: {e}")
122
+
123
+ def _get_unique_images(self):
124
+ """Get count of unique images"""
125
+ return len(set(sample["image_path"] for sample in self.samples))
126
+
127
+ def __len__(self):
128
+ return len(self.samples)
129
+
130
+ def __getitem__(self, idx):
131
+ sample = self.samples[idx]
132
+
133
+ # Load and transform image
134
+ try:
135
+ image = Image.open(sample["image_path"]).convert("RGB")
136
+ if self.transform:
137
+ image = self.transform(image)
138
+ except Exception as e:
139
+ logger.warning(f"Error loading image {sample['image_path']}: {e}")
140
+ # Return a black image as fallback
141
+ image = torch.zeros(3, 224, 224)
142
+
143
+ # Tokenize text
144
+ text = sample["text"]
145
+ if self.tokenizer:
146
+ tokens = self.tokenizer(
147
+ text,
148
+ max_length=self.max_length,
149
+ padding='max_length',
150
+ truncation=True,
151
+ return_tensors='pt'
152
+ )
153
+ return {
154
+ "image": image,
155
+ "input_ids": tokens["input_ids"].squeeze(),
156
+ "attention_mask": tokens["attention_mask"].squeeze(),
157
+ "text": text
158
+ }
159
+ else:
160
+ return {
161
+ "image": image,
162
+ "text": text
163
+ }
164
+
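+ # Illustrative record shape that _load_dataset()/_process_sample() above can consume. Field names
+ # come from the code; the concrete values are placeholders, not real data:
+ # {
+ #   "id": "sample_0001",
+ #   "image": "images/example_pavement.jpg",
+ #   "conversations": [
+ #     {"from": "human", "value": "Describe the pavement surface."},
+ #     {"from": "gpt", "value": "The asphalt surface shows a longitudinal crack along the wheel path."}
+ #   ]
+ # }
+ # JSON files may also hold a list of such records, or a dict whose values are records or lists of them.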
165
+
166
+ class VisionEncoder(nn.Module):
167
+ """Flexible vision encoder supporting ViT and ResNet architectures"""
168
+
169
+ def __init__(self, model_name: str, embed_dim: int = 512, pretrained: bool = True):
170
+ super().__init__()
171
+ self.model_name = model_name
172
+ self.embed_dim = embed_dim
173
+ self.expected_image_size = 224 # Default
174
+
175
+ # Try to determine architecture type
176
+ if any(arch in model_name.lower() for arch in ["vit", "deit", "swin", "beit", "cait"]):
177
+ self._setup_vit(model_name, pretrained)
178
+ elif "resnet" in model_name.lower():
179
+ self._setup_resnet(model_name, pretrained)
180
+ else:
181
+ # 🔧 GENERIC TIMM MODEL LOADING
182
+ self._setup_generic_timm(model_name, pretrained)
183
+
184
+ # Projection head
185
+ self.projection = nn.Linear(self.feature_dim, embed_dim)
186
+
187
+ def _setup_generic_timm(self, model_name: str, pretrained: bool):
188
+ """Setup any TIMM model generically"""
189
+ try:
190
+ self.backbone = timm.create_model(
191
+ model_name,
192
+ pretrained=pretrained,
193
+ num_classes=0 # Remove classification head
194
+ )
195
+
196
+ # Auto-detect input size and feature dimension
197
+ self.feature_dim = None
198
+ test_sizes = [224, 288, 336, 384, 448, 512]
199
+
200
+ for test_size in test_sizes:
201
+ try:
202
+ with torch.no_grad():
203
+ dummy_input = torch.randn(1, 3, test_size, test_size)
204
+ features = self.backbone(dummy_input)
205
+
206
+ # Handle different output formats
207
+ if len(features.shape) > 2:
208
+ features = features.view(features.size(0), -1)
209
+
210
+ self.feature_dim = features.shape[1]
211
+ self.expected_image_size = test_size
212
+ logger.info(f"Generic model {model_name} expects {test_size}x{test_size} → {self.feature_dim}D")
213
+ break
214
+ except Exception:
215
+ continue
216
+
217
+ if self.feature_dim is None:
218
+ raise Exception("Could not determine model specifications")
219
+
220
+ except Exception as e:
221
+ logger.error(f"Failed to load {model_name}: {e}")
222
+ raise
223
+
224
+
225
+
226
+ def _setup_vit(self, model_name: str, pretrained: bool):
227
+ """Setup Vision Transformer - works with any TIMM ViT model"""
228
+
229
+ # Known mappings for convenience
230
+ vit_mapping = {
231
+ "vit-b/16": "vit_base_patch16_224",
232
+ "vit-b/32": "vit_base_patch32_224",
233
+ "vit-l/14": "vit_large_patch14_224",
234
+ "vit-l/14@336": "vit_large_patch14_clip_336",
235
+ "vit-h/14": "vit_huge_patch14_clip_378"
236
+ }
237
+
238
+ # Use mapping if available, otherwise use model name directly
239
+ timm_name = vit_mapping.get(model_name.lower(), model_name)
240
+
241
+ try:
242
+ self.backbone = timm.create_model(
243
+ timm_name,
244
+ pretrained=pretrained,
245
+ num_classes=0
246
+ )
247
+
248
+ # 🔧 AUTO-DETECT input size by trying common sizes
249
+ self.feature_dim = None
250
+ test_sizes = [224, 336, 378, 384, 512] # Common ViT sizes
251
+
252
+ for test_size in test_sizes:
253
+ try:
254
+ with torch.no_grad():
255
+ dummy_input = torch.randn(1, 3, test_size, test_size)
256
+ features = self.backbone(dummy_input)
257
+ self.feature_dim = features.shape[1]
258
+ self.expected_image_size = test_size
259
+ logger.info(f"Model {timm_name} expects {test_size}x{test_size} input")
260
+ break
261
+ except Exception:
262
+ continue
263
+
264
+ if self.feature_dim is None:
265
+ raise Exception("Could not determine input size for model")
266
+
267
+ except Exception as e:
268
+ logger.warning(f"Failed to load {timm_name}: {e}")
269
+ logger.warning("Falling back to basic ViT")
270
+ self.backbone = timm.create_model("vit_base_patch16_224", pretrained=pretrained, num_classes=0)
271
+ self.feature_dim = 768
272
+ self.expected_image_size = 224
273
+
274
+ def _setup_resnet(self, model_name: str, pretrained: bool):
275
+ """Setup ResNet"""
276
+ if "resnet50" in model_name.lower():
277
+ self.backbone = resnet50(pretrained=pretrained)
278
+ elif "resnet101" in model_name.lower():
279
+ self.backbone = resnet101(pretrained=pretrained)
280
+ else:
281
+ self.backbone = resnet50(pretrained=pretrained)
282
+
283
+ # Remove classification head
284
+ self.backbone = nn.Sequential(*list(self.backbone.children())[:-1])
285
+ self.feature_dim = 2048 # ResNet feature dimension
286
+
287
+ def forward(self, x):
288
+ features = self.backbone(x)
289
+ if len(features.shape) > 2:
290
+ features = features.view(features.size(0), -1)
291
+ return self.projection(features)
292
+
293
+
294
+ class TextEncoder(nn.Module):
295
+ """Flexible text encoder supporting various transformer models"""
296
+
297
+ def __init__(self, model_name: str = "bert-base-uncased", embed_dim: int = 512,
298
+ max_length: int = 77, pretrained: bool = True):
299
+ super().__init__()
300
+ self.model_name = model_name
301
+ self.embed_dim = embed_dim
302
+ self.max_length = max_length
303
+
304
+ if not pretrained:
305
+ # Initialize from scratch
306
+ if "bert" in model_name:
307
+ from transformers import BertConfig
308
+ config = BertConfig(vocab_size=30522, max_position_embeddings=max_length)
309
+ self.transformer = BertModel(config)
310
+ else:
311
+ self.transformer = AutoModel.from_pretrained(model_name,
312
+ ignore_mismatched_sizes=True)
313
+ else:
314
+ self.transformer = AutoModel.from_pretrained(model_name)
315
+
316
+ # Get hidden dimension
317
+ self.hidden_dim = self.transformer.config.hidden_size
318
+
319
+ # Projection head
320
+ self.projection = nn.Linear(self.hidden_dim, embed_dim)
321
+
322
+ def forward(self, input_ids, attention_mask=None):
323
+ outputs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)
324
+
325
+ # Use [CLS] token or mean pooling
326
+ if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
327
+ features = outputs.pooler_output
328
+ else:
329
+ # Mean pooling over sequence length
330
+ features = outputs.last_hidden_state.mean(dim=1)
331
+
332
+ return self.projection(features)
333
+
334
+
335
+ class CLIPModel(nn.Module):
336
+ """CLIP model with contrastive learning"""
337
+
338
+ def __init__(self, vision_model: str, text_model: str, embed_dim: int = 512,
339
+ temperature: float = 0.07, vision_pretrained: bool = True,
340
+ text_pretrained: bool = True):
341
+ super().__init__()
342
+
343
+ self.vision_encoder = VisionEncoder(vision_model, embed_dim, vision_pretrained)
344
+ self.text_encoder = TextEncoder(text_model, embed_dim, pretrained=text_pretrained)
345
+
346
+ # Temperature parameter for contrastive loss
347
+ self.temperature = nn.Parameter(torch.tensor(temperature))
348
+
349
+ def forward(self, images, input_ids, attention_mask=None):
350
+ # Encode images and text
351
+ image_features = self.vision_encoder(images)
352
+ text_features = self.text_encoder(input_ids, attention_mask)
353
+
354
+ # Normalize features
355
+ image_features = F.normalize(image_features, p=2, dim=1)
356
+ text_features = F.normalize(text_features, p=2, dim=1)
357
+
358
+ return image_features, text_features
359
+
360
+ def compute_loss(self, image_features, text_features):
361
+ """Compute contrastive loss"""
362
+ batch_size = image_features.shape[0]
363
+
364
+ # Compute similarity matrix
365
+ logits = torch.matmul(image_features, text_features.T) / self.temperature
366
+
367
+ # Labels are diagonal (each image matches its corresponding text)
368
+ labels = torch.arange(batch_size, device=logits.device)
369
+
370
+ # Compute cross-entropy loss for both directions
371
+ loss_img = F.cross_entropy(logits, labels)
372
+ loss_txt = F.cross_entropy(logits.T, labels)
373
+
374
+ return (loss_img + loss_txt) / 2
375
+
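+ # Note: compute_loss() above is the standard symmetric InfoNCE objective used by CLIP:
+ #   logits = (I @ T.T) / temperature,   y = arange(B)
+ #   L = 0.5 * (CrossEntropy(logits, y) + CrossEntropy(logits.T, y))
+ # where I and T are the L2-normalized image and text embeddings of a batch of B matched pairs.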
376
+
377
+ class SigLIPModel(nn.Module):
378
+ """SigLIP model with sigmoid loss instead of contrastive loss"""
379
+
380
+ def __init__(self, vision_model: str, text_model: str, embed_dim: int = 512,
381
+ temperature: float = 0.07, vision_pretrained: bool = True,
382
+ text_pretrained: bool = True):
383
+ super().__init__()
384
+
385
+ self.vision_encoder = VisionEncoder(vision_model, embed_dim, vision_pretrained)
386
+ self.text_encoder = TextEncoder(text_model, embed_dim, pretrained=text_pretrained)
387
+
388
+ # Temperature parameter
389
+ self.temperature = nn.Parameter(torch.tensor(temperature))
390
+
391
+ def forward(self, images, input_ids, attention_mask=None):
392
+ # Encode images and text
393
+ image_features = self.vision_encoder(images)
394
+ text_features = self.text_encoder(input_ids, attention_mask)
395
+
396
+ # Normalize features
397
+ image_features = F.normalize(image_features, p=2, dim=1)
398
+ text_features = F.normalize(text_features, p=2, dim=1)
399
+
400
+ return image_features, text_features
401
+
402
+ def compute_loss(self, image_features, text_features):
403
+ """Compute SigLIP loss"""
404
+ batch_size = image_features.shape[0]
405
+
406
+ # Compute similarity matrix
407
+ logits = torch.matmul(image_features, text_features.T) / self.temperature
408
+
409
+ # Create positive and negative labels
410
+ labels = torch.eye(batch_size, device=logits.device)
411
+ labels = labels * 2 - 1 # Convert to -1/1 labels
412
+
413
+ # SigLIP loss: -log(sigmoid(z_i * y_i))
414
+ loss = -F.logsigmoid(logits * labels).mean()
415
+
416
+ return loss
417
+
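+ # Note: compute_loss() above is the pairwise sigmoid objective from SigLIP:
+ #   z = (I @ T.T) / temperature,   L = -mean(log sigmoid(y_ij * z_ij))
+ # with y_ij = +1 on the diagonal (matched pairs) and -1 off the diagonal, so every image-text
+ # pair contributes an independent binary term instead of a batch-wide softmax.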
418
+
419
+ class PaveCLIPTrainer:
420
+ """Complete training framework for PaveCLIP"""
421
+
422
+ def __init__(self, config: Dict):
423
+ self.config = config
424
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
425
+
426
+ self.distributed = False
427
+ self.rank = 0
428
+
429
+ # Setup distributed training if specified
430
+ if config.get("distributed", False):
431
+ self._setup_distributed()
432
+
433
+ # Initialize model
434
+ self._setup_model()
435
+
436
+ # Setup data
437
+ self._setup_data()
438
+
439
+ # Setup optimization
440
+ self._setup_optimization()
441
+
442
+ # Setup logging
443
+ if config.get("wandb", False) and (not self.distributed or self.rank == 0):
444
+ wandb.init(project="paveclip", config=config)
445
+
446
+ def _setup_distributed(self):
447
+ """Setup distributed training"""
448
+ self.distributed = True
449
+ self.rank = int(os.environ.get("LOCAL_RANK", 0))
450
+ self.world_size = int(os.environ.get("WORLD_SIZE", 1))
451
+
452
+ dist.init_process_group(backend="nccl")
453
+ torch.cuda.set_device(self.rank)
454
+ self.device = torch.device(f"cuda:{self.rank}")
455
+
456
+ logger.info(f"Initialized distributed training: rank {self.rank}/{self.world_size}")
457
+
458
+ def _setup_model(self):
459
+ """Initialize the model"""
460
+ model_type = self.config.get("model_type", "clip").lower()
461
+
462
+ if model_type == "clip":
463
+ self.model = CLIPModel(
464
+ vision_model=self.config["vision_model"],
465
+ text_model=self.config["text_model"],
466
+ embed_dim=self.config.get("embed_dim", 512),
467
+ temperature=self.config.get("temperature", 0.07),
468
+ vision_pretrained=self.config.get("vision_pretrained", True),
469
+ text_pretrained=self.config.get("text_pretrained", True)
470
+ )
471
+ elif model_type == "siglip":
472
+ self.model = SigLIPModel(
473
+ vision_model=self.config["vision_model"],
474
+ text_model=self.config["text_model"],
475
+ embed_dim=self.config.get("embed_dim", 512),
476
+ temperature=self.config.get("temperature", 0.07),
477
+ vision_pretrained=self.config.get("vision_pretrained", True),
478
+ text_pretrained=self.config.get("text_pretrained", True)
479
+ )
480
+ else:
481
+ raise ValueError(f"Unsupported model type: {model_type}")
482
+
483
+ self.model = self.model.to(self.device)
484
+
485
+ # Wrap with DDP for distributed training
486
+ if hasattr(self, 'distributed') and self.distributed:
487
+ self.model = DDP(self.model, device_ids=[self.rank])
488
+
489
+ def _setup_data(self):
490
+ """Setup data loaders"""
491
+ # Image transforms
492
+ if "vit" in self.config["vision_model"].lower():
493
+ image_size = 336 if "@336" in self.config["vision_model"] else 224
494
+ else:
495
+ image_size = 224
496
+
497
+ # Pavement-specific augmentations for robustness
498
+ train_transform = transforms.Compose([
499
+ transforms.Resize((image_size, image_size)),
500
+ transforms.RandomHorizontalFlip(p=0.5),
501
+ transforms.RandomRotation(degrees=15), # Roads can be at angles
502
+ transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1, hue=0.05),
503
+ transforms.RandomGrayscale(p=0.1), # Some pavement images are grayscale
504
+ transforms.ToTensor(),
505
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
506
+ ])
507
+
508
+ val_transform = transforms.Compose([
509
+ transforms.Resize((image_size, image_size)),
510
+ transforms.ToTensor(),
511
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
512
+ ])
513
+
514
+ # Tokenizer
515
+ from transformers import AutoTokenizer
516
+ self.tokenizer = AutoTokenizer.from_pretrained(self.config["text_model"])
517
+ if self.tokenizer.pad_token is None:
518
+ self.tokenizer.pad_token = self.tokenizer.eos_token
519
+
520
+ # Dataset
521
+ train_dataset = PavementDataset(
522
+ self.config["data_dir"],
523
+ transform=train_transform,
524
+ tokenizer=self.tokenizer,
525
+ max_length=self.config.get("max_length", 77)
526
+ )
527
+
528
+ # Split for validation if specified
529
+ if self.config.get("val_split", 0.1) > 0:
530
+ val_size = int(len(train_dataset) * self.config["val_split"])
531
+ train_size = len(train_dataset) - val_size
532
+ train_dataset, val_dataset = torch.utils.data.random_split(
533
+ train_dataset, [train_size, val_size]
534
+ )
535
+ # Shallow-copy the wrapped dataset so the eval transform does not leak into the training subset
+ val_dataset.dataset = copy.copy(val_dataset.dataset)
+ val_dataset.dataset.transform = val_transform
536
+ else:
537
+ val_dataset = None
538
+
539
+ # Data loaders
540
+ train_sampler = DistributedSampler(train_dataset) if hasattr(self, 'distributed') and self.distributed else None
541
+
542
+ self.train_loader = DataLoader(
543
+ train_dataset,
544
+ batch_size=self.config["batch_size"],
545
+ shuffle=(train_sampler is None),
546
+ sampler=train_sampler,
547
+ num_workers=self.config.get("num_workers", 4),
548
+ pin_memory=True,
549
+ drop_last=True
550
+ )
551
+
552
+ if val_dataset:
553
+ val_sampler = DistributedSampler(val_dataset) if hasattr(self, 'distributed') and self.distributed else None
554
+ self.val_loader = DataLoader(
555
+ val_dataset,
556
+ batch_size=self.config["batch_size"],
557
+ shuffle=False,
558
+ sampler=val_sampler,
559
+ num_workers=self.config.get("num_workers", 4),
560
+ pin_memory=True
561
+ )
562
+ else:
563
+ self.val_loader = None
564
+
565
+ logger.info(f"Training samples: {len(train_dataset)}")
566
+ if val_dataset:
567
+ logger.info(f"Validation samples: {len(val_dataset)}")
568
+
569
+ def _setup_optimization(self):
570
+ """Setup optimizer and scheduler"""
571
+ # Pavement-specific optimization strategy
572
+ # Different learning rates for vision and text encoders
573
+ vision_params = []
574
+ text_params = []
575
+ other_params = []
576
+
577
+ model = self.model.module if hasattr(self.model, 'module') else self.model
578
+
579
+ for name, param in model.named_parameters():
580
+ if 'vision_encoder' in name:
581
+ vision_params.append(param)
582
+ elif 'text_encoder' in name:
583
+ text_params.append(param)
584
+ else:
585
+ other_params.append(param)
586
+
587
+ # Different learning rates for different components
588
+ param_groups = [
589
+ {'params': vision_params, 'lr': self.config["learning_rate"] * 0.1}, # Lower LR for vision
590
+ {'params': text_params, 'lr': self.config["learning_rate"]}, # Standard LR for text
591
+ {'params': other_params, 'lr': self.config["learning_rate"]} # Standard LR for others
592
+ ]
593
+
594
+ self.optimizer = torch.optim.AdamW(
595
+ param_groups,
596
+ weight_decay=self.config.get("weight_decay", 0.01)
597
+ )
598
+
599
+ # Learning rate scheduler
600
+ total_steps = len(self.train_loader) * self.config["epochs"]
601
+ warmup_steps = int(total_steps * self.config.get("warmup_ratio", 0.1))
602
+
603
+ self.scheduler = torch.optim.lr_scheduler.OneCycleLR(
604
+ self.optimizer,
605
+ max_lr=[group['lr'] for group in param_groups],
606
+ total_steps=total_steps,
607
+ pct_start=warmup_steps / total_steps,
608
+ anneal_strategy='cos'
609
+ )
610
+
611
+ def train_epoch(self, epoch: int):
612
+ """Train for one epoch"""
613
+ self.model.train()
614
+
615
+ if hasattr(self, 'distributed') and self.distributed:
616
+ self.train_loader.sampler.set_epoch(epoch)
617
+
618
+ total_loss = 0
619
+ num_batches = len(self.train_loader)
620
+
621
+ pbar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}") if (not hasattr(self, 'distributed') or self.rank == 0) else self.train_loader
622
+
623
+ for batch_idx, batch in enumerate(pbar):
624
+ images = batch["image"].to(self.device, non_blocking=True)
625
+ input_ids = batch["input_ids"].to(self.device, non_blocking=True)
626
+ attention_mask = batch["attention_mask"].to(self.device, non_blocking=True)
627
+
628
+ # Forward pass
629
+ image_features, text_features = self.model(images, input_ids, attention_mask)
630
+
631
+ # Compute loss
632
+ loss = self.model.module.compute_loss(image_features, text_features) if hasattr(self.model, 'module') else self.model.compute_loss(image_features, text_features)
633
+
634
+ # Backward pass
635
+ self.optimizer.zero_grad()
636
+ loss.backward()
637
+
638
+ # Gradient clipping for stability
639
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
640
+
641
+ self.optimizer.step()
642
+ self.scheduler.step()
643
+
644
+ total_loss += loss.item()
645
+
646
+ # Update progress bar
647
+ if hasattr(pbar, 'set_postfix'):
648
+ pbar.set_postfix({
649
+ 'loss': f'{loss.item():.4f}',
650
+ 'avg_loss': f'{total_loss/(batch_idx+1):.4f}',
651
+ 'lr': f'{self.scheduler.get_last_lr()[0]:.2e}'
652
+ })
653
+
654
+ # Log to wandb
655
+ if self.config.get("wandb", False) and (not hasattr(self, 'distributed') or self.rank == 0):
656
+ wandb.log({
657
+ "train_loss": loss.item(),
658
+ "learning_rate": self.scheduler.get_last_lr()[0],
659
+ "epoch": epoch,
660
+ "step": epoch * num_batches + batch_idx
661
+ })
662
+
663
+ return total_loss / num_batches
664
+
665
+ def validate(self, epoch: int):
666
+ """Validate the model"""
667
+ if self.val_loader is None:
668
+ return None
669
+
670
+ self.model.eval()
671
+ total_loss = 0
672
+
673
+ with torch.no_grad():
674
+ for batch in self.val_loader:
675
+ images = batch["image"].to(self.device, non_blocking=True)
676
+ input_ids = batch["input_ids"].to(self.device, non_blocking=True)
677
+ attention_mask = batch["attention_mask"].to(self.device, non_blocking=True)
678
+
679
+ # Forward pass
680
+ image_features, text_features = self.model(images, input_ids, attention_mask)
681
+
682
+ # Compute loss
683
+ loss = self.model.module.compute_loss(image_features, text_features) if hasattr(self.model, 'module') else self.model.compute_loss(image_features, text_features)
684
+ total_loss += loss.item()
685
+
686
+ avg_loss = total_loss / len(self.val_loader)
687
+
688
+ if self.config.get("wandb", False) and (not hasattr(self, 'distributed') or self.rank == 0):
689
+ wandb.log({
690
+ "val_loss": avg_loss,
691
+ "epoch": epoch
692
+ })
693
+
694
+ return avg_loss
695
+
696
+ def train(self):
697
+ """Main training loop"""
698
+ logger.info("Starting training...")
699
+
700
+ best_val_loss = float('inf')
701
+
702
+ for epoch in range(self.config["epochs"]):
703
+ # Train
704
+ train_loss = self.train_epoch(epoch)
705
+
706
+ # Validate
707
+ val_loss = self.validate(epoch)
708
+
709
+ # Log epoch results
710
+ if not hasattr(self, 'distributed') or self.rank == 0:
711
+ logger.info(f"Epoch {epoch+1}/{self.config['epochs']}")
712
+ logger.info(f"Train Loss: {train_loss:.4f}")
713
+ if val_loss is not None:
714
+ logger.info(f"Val Loss: {val_loss:.4f}")
715
+
716
+ # Save checkpoint
717
+ if (not hasattr(self, 'distributed') or self.rank == 0) and val_loss is not None and val_loss < best_val_loss:
718
+ best_val_loss = val_loss
719
+ self.save_checkpoint(epoch, is_best=True)
720
+
721
+ # Regular checkpoint
722
+ if (epoch + 1) % self.config.get("save_every", 10) == 0:
723
+ if not hasattr(self, 'distributed') or self.rank == 0:
724
+ self.save_checkpoint(epoch, is_best=False)
725
+
726
+ def save_checkpoint(self, epoch: int, is_best: bool = False):
727
+ """Save model checkpoint"""
728
+ model_state = self.model.module.state_dict() if hasattr(self.model, 'module') else self.model.state_dict()
729
+
730
+ checkpoint = {
731
+ 'epoch': epoch,
732
+ 'model_state_dict': model_state,
733
+ 'optimizer_state_dict': self.optimizer.state_dict(),
734
+ 'config': self.config
735
+ }
736
+
737
+ filename = f"paveclip_epoch_{epoch+1}.pt"
738
+ if is_best:
739
+ filename = "paveclip_best.pt"
740
+
741
+ save_path = Path(self.config["output_dir"]) / filename
742
+ save_path.parent.mkdir(parents=True, exist_ok=True)
743
+
744
+ torch.save(checkpoint, save_path)
745
+ logger.info(f"Saved checkpoint: {save_path}")
746
+
747
+
748
+ class PaveCLIPEvaluator:
749
+ """Evaluation utilities for PaveCLIP"""
750
+
751
+ def __init__(self, model_path: str, config: Dict):
752
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
753
+ self.config = config
754
+
755
+ # Load model
756
+ checkpoint = torch.load(model_path, map_location=self.device)
757
+ model_config = checkpoint['config']
758
+
759
+ # Initialize model
760
+ if model_config.get("model_type", "clip").lower() == "clip":
761
+ self.model = CLIPModel(
762
+ vision_model=model_config["vision_model"],
763
+ text_model=model_config["text_model"],
764
+ embed_dim=model_config.get("embed_dim", 512)
765
+ )
766
+ else:
767
+ self.model = SigLIPModel(
768
+ vision_model=model_config["vision_model"],
769
+ text_model=model_config["text_model"],
770
+ embed_dim=model_config.get("embed_dim", 512)
771
+ )
772
+
773
+ self.model.load_state_dict(checkpoint['model_state_dict'])
774
+ self.model = self.model.to(self.device)
775
+ self.model.eval()
776
+
777
+ # Setup tokenizer and transforms
778
+ from transformers import AutoTokenizer
779
+ self.tokenizer = AutoTokenizer.from_pretrained(model_config["text_model"])
780
+ if self.tokenizer.pad_token is None:
781
+ self.tokenizer.pad_token = self.tokenizer.eos_token
782
+
783
+ # Image transforms
784
+ #image_size = 336 if "@336" in model_config["vision_model"] else 224
785
+ expected = getattr(self.model.vision_encoder, "expected_image_size", 224)
786
+
787
+ self.transform = transforms.Compose([
788
+ transforms.Resize((expected, expected)),
789
+ transforms.ToTensor(),
790
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
791
+ ])
792
+
793
+ self.image_size = expected # keep for later use
794
+
795
+
796
+ def encode_images(self, image_paths: List[str]) -> torch.Tensor:
797
+ """Encode list of images"""
798
+ features = []
799
+
800
+ with torch.no_grad():
801
+ for img_path in image_paths:
802
+ image = Image.open(img_path).convert("RGB")
803
+ image = self.transform(image).unsqueeze(0).to(self.device)
804
+
805
+ # Encode through the vision tower only (no need to run the text encoder on dummy tokens)
+ img_features = F.normalize(self.model.vision_encoder(image), p=2, dim=1)
806
+ features.append(img_features.cpu())
807
+
808
+ return torch.cat(features, dim=0)
809
+
810
+ def encode_texts(self, texts: List[str]) -> torch.Tensor:
811
+ """Encode list of texts"""
812
+ tokens = self.tokenizer(
813
+ texts,
814
+ max_length=77,
815
+ padding='max_length',
816
+ truncation=True,
817
+ return_tensors='pt'
818
+ )
819
+
820
+ # Encode text directly with the text encoder instead of running the full model on dummy images.
826
+ with torch.no_grad():
827
+ tokens = {k: v.to(self.device) for k, v in tokens.items()}
828
+ text_features = self.model.text_encoder(tokens["input_ids"], tokens["attention_mask"])
829
+ text_features = F.normalize(text_features, p=2, dim=1)
830
+ return text_features.cpu()
831
+
832
+ def zero_shot_classification(self, image_paths: List[str], class_texts: List[str]) -> Dict:
833
+ """Perform zero-shot classification"""
834
+ logger.info("Performing zero-shot classification...")
835
+
836
+ # Encode images and texts
837
+ image_features = self.encode_images(image_paths)
838
+ text_features = self.encode_texts(class_texts)
839
+
840
+ # Compute similarities
841
+ similarities = torch.matmul(image_features, text_features.T)
842
+ predictions = similarities.argmax(dim=1)
843
+
844
+ # Compute accuracy if ground truth is available
845
+ results = {
846
+ "predictions": predictions.tolist(),
847
+ "similarities": similarities.tolist(),
848
+ "class_texts": class_texts
849
+ }
850
+
851
+ return results
852
+
853
+ def image_retrieval(self, query_text: str, image_paths: List[str], top_k: int = 5) -> List[Tuple[str, float]]:
854
+ """Retrieve top-k images for a text query"""
855
+ logger.info(f"Retrieving top-{top_k} images for query: '{query_text}'")
856
+
857
+ # Encode query and images
858
+ text_features = self.encode_texts([query_text])
859
+ image_features = self.encode_images(image_paths)
860
+
861
+ # Compute similarities
862
+ similarities = torch.matmul(text_features, image_features.T).squeeze()
863
+
864
+ # Get top-k results
865
+ top_k_indices = similarities.argsort(descending=True)[:top_k]
866
+
867
+ results = []
868
+ for idx in top_k_indices:
869
+ results.append((image_paths[idx.item()], similarities[idx.item()].item()))
870
+
871
+ return results
872
+
873
+
874
+ def main():
875
+ """Main training script"""
876
+ parser = argparse.ArgumentParser(description="Train PaveCLIP model")
877
+
878
+ # Model arguments
879
+ parser.add_argument("--model_type", default="clip", choices=["clip", "siglip"],
880
+ help="Model type to train")
881
+ parser.add_argument("--vision_model", default="vit-b/16",
882
+ help="Vision encoder (e.g., vit-b/16, vit-l/14@336, resnet50)")
883
+ parser.add_argument("--text_model", default="bert-base-uncased",
884
+ help="Text encoder (e.g., bert-base-uncased, roberta-base)")
885
+ parser.add_argument("--embed_dim", type=int, default=512,
886
+ help="Embedding dimension")
887
+ parser.add_argument("--vision_pretrained", action="store_true",
888
+ help="Use pretrained vision encoder")
889
+ parser.add_argument("--text_pretrained", action="store_true",
890
+ help="Use pretrained text encoder")
891
+
892
+ # Data arguments
893
+ parser.add_argument("--data_dir", required=True,
894
+ help="Path to Pavement_Pretraining_Data directory")
895
+ parser.add_argument("--val_split", type=float, default=0.1,
896
+ help="Validation split ratio")
897
+ parser.add_argument("--max_length", type=int, default=77,
898
+ help="Maximum text length")
899
+
900
+ # Training arguments
901
+ parser.add_argument("--batch_size", type=int, default=64,
902
+ help="Batch size")
903
+ parser.add_argument("--epochs", type=int, default=50,
904
+ help="Number of epochs")
905
+ parser.add_argument("--learning_rate", type=float, default=1e-4,
906
+ help="Learning rate")
907
+ parser.add_argument("--weight_decay", type=float, default=0.01,
908
+ help="Weight decay")
909
+ parser.add_argument("--temperature", type=float, default=0.07,
910
+ help="Temperature parameter")
911
+ parser.add_argument("--warmup_ratio", type=float, default=0.1,
912
+ help="Warmup ratio")
913
+
914
+ # System arguments
915
+ parser.add_argument("--num_workers", type=int, default=4,
916
+ help="Number of data loader workers")
917
+ parser.add_argument("--output_dir", default="./checkpoints",
918
+ help="Output directory for checkpoints")
919
+ parser.add_argument("--save_every", type=int, default=10,
920
+ help="Save checkpoint every N epochs")
921
+ parser.add_argument("--wandb", action="store_true",
922
+ help="Use Weights & Biases logging")
923
+ parser.add_argument("--distributed", action="store_true",
924
+ help="Enable distributed training")
925
+
926
+ args = parser.parse_args()
927
+
928
+ # Convert args to config dict
929
+ config = vars(args)
930
+
931
+ # Initialize trainer
932
+ trainer = PaveCLIPTrainer(config)
933
+
934
+ # Start training
935
+ trainer.train()
936
+
937
+ # Cleanup distributed training
938
+ if config.get("distributed", False):
939
+ dist.destroy_process_group()
940
+
941
+
942
+ if __name__ == "__main__":
943
+ main()
944
+
945
+
946
+ # python paveclip_training.py \
947
+ # --vision_model vit-b/16 \
948
+ # --text_model distilbert-base-uncased \
949
+ # --vision_pretrained \
950
+ # --text_pretrained \
951
+ # --data_dir ./Pavement_Pretraining_Data \
952
+ # --batch_size 64 \
953
+ # --epochs 100 \
954
+ # --wandb
955
+
956
+ # torchrun --nproc_per_node=4 paveclip_training.py \
957
+ # --distributed \
958
+ # [other args]
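A quick way to exercise the evaluator outside the Gradio app, assuming a trained checkpoint is available locally at ./checkpoints/paveclip_best.pt (the path is an assumption; the example image is one of the files added in this commit):

from paveclip_training import PaveCLIPEvaluator

evaluator = PaveCLIPEvaluator("./checkpoints/paveclip_best.pt", {})
prompts = ["pavement with pothole", "pavement with longitudinal crack", "smooth road surface"]
result = evaluator.zero_shot_classification(
    ["examples/202202122309381-dry-asphalt-severe.jpg"], prompts
)
# "predictions" holds the index of the best-matching prompt; "similarities" holds the raw scores.
print(prompts[result["predictions"][0]])
print(result["similarities"][0])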
requirements.txt ADDED
@@ -0,0 +1,13 @@
1
+ gradio>=4.0.0
2
+ torch>=1.9.0
3
+ torchvision>=0.10.0
4
+ Pillow>=8.0.0
5
+ numpy>=1.21.0
6
+ pandas>=1.3.0
7
+ matplotlib>=3.5.0
8
+ seaborn>=0.11.0
9
+ scikit-learn>=1.0.0
10
+ plotly>=5.0.0
11
+ huggingface-hub>=0.16.0
12
+ transformers>=4.20.0
13
+ timm
+ tqdm
+ wandb