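"""Fine-tune the CLIP ViT-L/14 vision encoder on the FairFace dataset.

A single CLIP vision backbone feeds three classification heads (age, gender,
race). Most of the backbone is frozen; only the last three transformer layers
and the heads are trained. Training runs through the Hugging Face Trainer with
gradient checkpointing, gradient accumulation and fp16 mixed precision.
"""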
import pandas as pd
import torch
import torch.nn as nn
from PIL import Image
from sklearn.metrics import accuracy_score
from transformers import (
    Trainer,
    TrainingArguments,
    CLIPVisionModel,
    CLIPImageProcessor,
)
from torch.utils.data import Dataset
import os
os.environ["WANDB_DISABLED"] = "true"  # Disable Weights & Biases logging (also enforced via report_to="none" below)
# --- 1. Configuration ---
# Define paths and model name
BASE_PATH = './'  # Assumes the script is run from the 'fairface' directory
TRAIN_CSV = os.path.join(BASE_PATH, 'fairface_label_train.csv')
VAL_CSV = os.path.join(BASE_PATH, 'fairface_label_val.csv')
MODEL_NAME = "openai/clip-vit-large-patch14"
OUTPUT_DIR = "./clip-fairface-finetuned"

# --- 2. Load and Prepare Label Mappings ---
# Load training data to create consistent label-to-ID mappings
train_df = pd.read_csv(TRAIN_CSV)

# Create sorted unique label lists to ensure consistent mapping
age_labels = sorted(train_df['age'].unique())
gender_labels = sorted(train_df['gender'].unique())
race_labels = sorted(train_df['race'].unique())

# Create label-to-ID mappings for each task
label_mappings = {
    'age': {label: i for i, label in enumerate(age_labels)},
    'gender': {label: i for i, label in enumerate(gender_labels)},
    'race': {label: i for i, label in enumerate(race_labels)},
}

NUM_LABELS = {
    'age': len(age_labels),
    'gender': len(gender_labels),
    'race': len(race_labels),
}
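# For the standard FairFace label set this typically works out to Age=9,
# Gender=2, Race=7 (an assumption about the CSVs in use; the print below
# shows the actual counts).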

print(f"Number of labels: Age={NUM_LABELS['age']}, Gender={NUM_LABELS['gender']}, Race={NUM_LABELS['race']}")

# --- 3. Custom Dataset ---
class FairFaceDataset(Dataset):
    def __init__(self, csv_file, image_processor, label_maps, base_path):
        self.df = pd.read_csv(csv_file)
        self.image_processor = image_processor
        self.label_maps = label_maps
        self.base_path = base_path

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        # Construct the full path to the image
        image_path = os.path.join(self.base_path, row['file'])
        image = Image.open(image_path).convert("RGB")

        # Process the image
        inputs = {}
        inputs['pixel_values'] = self.image_processor(images=image, return_tensors="pt").pixel_values.squeeze(0)

        # Process labels into a dictionary of tensors
        inputs['labels'] = {
            'age': torch.tensor(self.label_maps['age'][row['age']], dtype=torch.long),
            'gender': torch.tensor(self.label_maps['gender'][row['gender']], dtype=torch.long),
            'race': torch.tensor(self.label_maps['race'][row['race']], dtype=torch.long),
        }
        return inputs
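# Quick sanity check (sketch; assumes the 'file' column holds paths such as
# 'train/1.jpg' relative to BASE_PATH):
#   processor = CLIPImageProcessor.from_pretrained(MODEL_NAME)
#   sample = FairFaceDataset(TRAIN_CSV, processor, label_mappings, BASE_PATH)[0]
#   sample['pixel_values'].shape   # torch.Size([3, 224, 224]) for ViT-L/14
#   sample['labels']               # dict of three scalar LongTensors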

# --- 4. Custom Model Definition ---
class MultiTaskClipVisionModel(nn.Module):
    # Class attribute signaling to the Trainer that this model supports gradient checkpointing
    supports_gradient_checkpointing = True

    def __init__(self, num_labels):
        super().__init__()
        self.vision_model = CLIPVisionModel.from_pretrained(MODEL_NAME)

        # Freeze all parameters of the vision model first
        for param in self.vision_model.parameters():
            param.requires_grad = False

        # Unfreeze the last 3 transformer layers for fine-tuning
        for layer in self.vision_model.vision_model.encoder.layers[-3:]:
            for param in layer.parameters():
                param.requires_grad = True

        # Define classification heads for each task
        hidden_size = self.vision_model.config.hidden_size
        self.age_head = nn.Linear(hidden_size, num_labels['age'])
        self.gender_head = nn.Linear(hidden_size, num_labels['gender'])
        self.race_head = nn.Linear(hidden_size, num_labels['race'])

    # Called by the Trainer when gradient_checkpointing=True is set in TrainingArguments
    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        """Activates gradient checkpointing for the underlying vision model."""
        self.vision_model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)

    def forward(self, pixel_values, labels=None):
        # Run the (partially frozen) vision backbone; compatible with gradient checkpointing
        outputs = self.vision_model(pixel_values=pixel_values)
        pooled_output = outputs.pooler_output

        age_logits = self.age_head(pooled_output)
        gender_logits = self.gender_head(pooled_output)
        race_logits = self.race_head(pooled_output)

        loss = None
        # If labels are provided, calculate the combined loss
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            age_loss = loss_fct(age_logits, labels['age'])
            gender_loss = loss_fct(gender_logits, labels['gender'])
            race_loss = loss_fct(race_logits, labels['race'])
            # Total loss is the sum of individual task losses
            loss = age_loss + gender_loss + race_loss

        return {
            'loss': loss,
            'logits': {
                'age': age_logits,
                'gender': gender_logits,
                'race': race_logits,
            },
        }
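# Rough check of what actually trains (sketch): only the three heads and the
# last three encoder layers should require gradients.
#   model = MultiTaskClipVisionModel(num_labels=NUM_LABELS)
#   trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
#   total = sum(p.numel() for p in model.parameters())
#   print(f"trainable: {trainable:,} / {total:,}")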

# --- 5. Data Collator and Metrics ---
def collate_fn(batch):
    # Stacks pixel values and organizes labels into a dictionary of tensors
    pixel_values = torch.stack([item['pixel_values'] for item in batch])
    labels = {
        'age': torch.stack([item['labels']['age'] for item in batch]),
        'gender': torch.stack([item['labels']['gender'] for item in batch]),
        'race': torch.stack([item['labels']['race'] for item in batch]),
    }
    return {'pixel_values': pixel_values, 'labels': labels}
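# A collated batch therefore looks like:
#   {'pixel_values': FloatTensor [B, 3, 224, 224],
#    'labels': {'age': LongTensor [B], 'gender': LongTensor [B], 'race': LongTensor [B]}}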

def compute_metrics(p):
    # p is an EvalPrediction object containing predictions and label_ids
    logits = p.predictions
    labels = p.label_ids
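    # With recent transformers versions, nested dict model outputs are numpified
    # recursively, so `logits` and `labels` should arrive here as dicts of numpy
    # arrays (an assumption; older versions may not preserve dict-structured
    # predictions).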

    # Extract predictions and labels for each task
    age_preds = logits['age'].argmax(-1)
    gender_preds = logits['gender'].argmax(-1)
    race_preds = logits['race'].argmax(-1)

    age_labels = labels['age']
    gender_labels = labels['gender']
    race_labels = labels['race']

    # Calculate accuracy for each task
    return {
        'age_accuracy': accuracy_score(age_labels, age_preds),
        'gender_accuracy': accuracy_score(gender_labels, gender_preds),
        'race_accuracy': accuracy_score(race_labels, race_preds),
    }

# --- 6. Trainer Setup and Execution ---
def main():
    # Initialize the image processor and our custom model
    image_processor = CLIPImageProcessor.from_pretrained(MODEL_NAME)
    model = MultiTaskClipVisionModel(num_labels=NUM_LABELS)

    # Initialize the training and validation datasets
    train_dataset = FairFaceDataset(
        csv_file=TRAIN_CSV, image_processor=image_processor, label_maps=label_mappings, base_path=BASE_PATH
    )
    val_dataset = FairFaceDataset(
        csv_file=VAL_CSV, image_processor=image_processor, label_maps=label_mappings, base_path=BASE_PATH
    )

    # Define the training arguments
    training_args = TrainingArguments(
        output_dir=OUTPUT_DIR,
        num_train_epochs=5,
        # Set a batch size that fits in memory
        per_device_train_batch_size=24,
        per_device_eval_batch_size=32,  # Evaluation does not need accumulation and can use a larger batch size
        # Set accumulation steps to reach the desired effective batch size (24 * 22 = 528)
        gradient_accumulation_steps=22,
        # Enable gradient checkpointing to save more memory
        gradient_checkpointing=True,
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=10,  # Log more frequently to see progress within a large effective batch
        evaluation_strategy="steps",
        eval_steps=250, # You might want to evaluate less frequently with larger batches
        save_strategy="steps",
        save_steps=250,
        load_best_model_at_end=True,
        metric_for_best_model='gender_accuracy',
        save_total_limit=3,
        fp16=True,  # Mixed-precision training substantially reduces memory use for this large model
        remove_unused_columns=False,
        report_to="none", # Disables wandb logging
    )
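    # Note: newer transformers releases rename `evaluation_strategy` to
    # `eval_strategy`; adjust the argument above to match the installed version
    # (the exact cutoff version is an assumption, so check the release notes).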

    # Initialize the Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=collate_fn,
        compute_metrics=compute_metrics,
    )

    # Start training
    print("Starting model training...")
    trainer.train()

    # Save the final model and processor
    print("Saving the best model...")
    trainer.save_model(os.path.join(OUTPUT_DIR, "best_model"))
    image_processor.save_pretrained(os.path.join(OUTPUT_DIR, "best_model"))

    print("Training complete!")

if __name__ == "__main__":
    main()
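# Reloading the fine-tuned weights later (sketch; the exact file name written by
# trainer.save_model for a plain nn.Module depends on the transformers version
# and the save_safetensors setting, so treat the path as an assumption):
#   model = MultiTaskClipVisionModel(num_labels=NUM_LABELS)
#   state = torch.load(os.path.join(OUTPUT_DIR, "best_model", "pytorch_model.bin"))
#   model.load_state_dict(state)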