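"""Evaluate the multi-task CLIP model fine-tuned on FairFace.

Loads the fine-tuned checkpoint produced by train_clip.py, runs it over the
FairFace validation split with per-task (age / gender / race) classification
reports, and provides a helper for predicting all three attributes of a
single image.
"""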
import pandas as pd
import torch
from PIL import Image
from sklearn.metrics import classification_report, accuracy_score
from transformers import CLIPImageProcessor
import os
from tqdm import tqdm

# IMPORTANT: This line imports your custom model class from the training script.
# Ensure 'train_clip.py' is in the same directory.
from train_clip import MultiTaskClipVisionModel
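
# For reference, this script relies only on the following interface of
# MultiTaskClipVisionModel. The sketch below is a hypothetical reconstruction
# inferred from how the model is used here, NOT the actual training code; the
# real definition lives in train_clip.py and may differ:
#
#     class MultiTaskClipVisionModel(nn.Module):
#         def __init__(self, num_labels):  # num_labels: {'age': int, 'gender': int, 'race': int}
#             ...
#         def forward(self, pixel_values):
#             # returns {'logits': {'age':    Tensor[batch, num_labels['age']],
#             #                     'gender': Tensor[batch, num_labels['gender']],
#             #                     'race':   Tensor[batch, num_labels['race']]}}
#             ...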

# --- 1. Configuration ---
# Verify this path is correct. It should point to the directory where the
# 'pytorch_model.bin' and 'preprocessor_config.json' files for your best model are located.
MODEL_PATH = "./clip-fairface-finetuned/best_model" # Or "./clip-fairface-finetuned/checkpoint-XXXX"

VAL_CSV = './fairface_label_val.csv'
BASE_PATH = './'
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Using device: {DEVICE}")
print(f"Loading model from: {MODEL_PATH}")

# --- 2. Load Label Mappings (must be identical to training) ---
# We load the TRAIN csv to ensure the label mappings are consistent with what the model was trained on.
train_df = pd.read_csv('./fairface_label_train.csv')
age_labels = sorted(train_df['age'].unique())
gender_labels = sorted(train_df['gender'].unique())
race_labels = sorted(train_df['race'].unique())
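
# Note: sorted() orders these label strings lexicographically, so age buckets
# like '10-19' come before '3-9'. That is harmless as long as the ordering
# matches the one used at training time, which is exactly why the mappings are
# rebuilt here from the same training CSV.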

label_mappings = {
    'age': {label: i for i, label in enumerate(age_labels)},
    'gender': {label: i for i, label in enumerate(gender_labels)},
    'race': {label: i for i, label in enumerate(race_labels)},
}

# Create reverse mappings from ID back to human-readable label
id_mappings = {
    'age': {i: label for label, i in label_mappings['age'].items()},
    'gender': {i: label for label, i in label_mappings['gender'].items()},
    'race': {i: label for label, i in label_mappings['race'].items()},
}

NUM_LABELS = {
    'age': len(age_labels),
    'gender': len(gender_labels),
    'race': len(race_labels),
}
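
# Optional sanity check (an added guard, not part of the original training
# pipeline): every label in the validation CSV must also appear in the
# training CSV, otherwise the lookups in label_mappings below would raise a
# KeyError mid-evaluation.
_val_check_df = pd.read_csv(VAL_CSV)
for _task in ['age', 'gender', 'race']:
    _unseen = set(_val_check_df[_task].unique()) - set(label_mappings[_task])
    assert not _unseen, f"Validation set has unseen {_task} labels: {_unseen}"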


# --- 3. Load Model and Processor ---
print("Loading processor and model...")
processor = CLIPImageProcessor.from_pretrained(MODEL_PATH)
model = MultiTaskClipVisionModel(num_labels=NUM_LABELS)

# Load the saved model weights. `map_location` ensures this works even if you
# trained on GPU and are now evaluating on CPU. Note: recent versions of the
# transformers Trainer save 'model.safetensors' instead of 'pytorch_model.bin';
# if this load fails, check which file your checkpoint directory contains.
model.load_state_dict(torch.load(os.path.join(MODEL_PATH, 'pytorch_model.bin'), map_location=torch.device(DEVICE)))
model.to(DEVICE)
model.eval() # Set the model to evaluation mode
print("Model loaded successfully.")


# --- 4. Evaluation on Validation Set ---
def evaluate_on_dataset():
    print(f"\nEvaluating on validation data from: {VAL_CSV}")
    val_df = pd.read_csv(VAL_CSV)

    # Lists to store all predictions and true labels
    all_preds = {'age': [], 'gender': [], 'race': []}
    all_true = {'age': [], 'gender': [], 'race': []}

    # Disable gradient calculations for efficiency
    with torch.no_grad():
        # Use tqdm for a nice progress bar
        for _, row in tqdm(val_df.iterrows(), total=len(val_df), desc="Evaluating"):
            image_path = os.path.join(BASE_PATH, row['file'])
            image = Image.open(image_path).convert("RGB")

            # Process the image and move to the correct device
            inputs = processor(images=image, return_tensors="pt").to(DEVICE)

            # Get model predictions
            outputs = model(pixel_values=inputs['pixel_values'])
            logits = outputs['logits']

            # Process predictions for each task
            for task in ['age', 'gender', 'race']:
                pred_id = torch.argmax(logits[task], dim=-1).item()
                true_label = row[task]
                true_id = label_mappings[task][true_label]

                all_preds[task].append(pred_id)
                all_true[task].append(true_id)

    # --- Print Reports ---
    print("\n--- Evaluation Results ---")
    for task in ['age', 'gender', 'race']:
        task_preds = all_preds[task]
        task_true = all_true[task]
        task_target_names = [id_mappings[task][i] for i in range(NUM_LABELS[task])]

        accuracy = accuracy_score(task_true, task_preds)
        report = classification_report(
            task_true,
            task_preds,
            # Pin the label set so the report always lines up with target_names,
            # even if some class never occurs in the predictions or ground truth.
            labels=list(range(NUM_LABELS[task])),
            target_names=task_target_names,
            zero_division=0
        )

        print(f"\n--- {task.upper()} CLASSIFICATION REPORT ---")
        print(f"Overall Accuracy: {accuracy:.4f}")
        print(report)
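
# Optional: a batched variant of the evaluation loop. The per-image loop above
# is simple but slow on a GPU; this sketch (the batch size of 64 is an
# assumption, tune it to your memory budget) runs one forward pass per batch
# and returns the raw predictions so you can reuse the reporting code above.
def evaluate_on_dataset_batched(batch_size=64):
    val_df = pd.read_csv(VAL_CSV)
    all_preds = {'age': [], 'gender': [], 'race': []}
    all_true = {'age': [], 'gender': [], 'race': []}

    with torch.no_grad():
        for start in tqdm(range(0, len(val_df), batch_size), desc="Evaluating (batched)"):
            batch = val_df.iloc[start:start + batch_size]
            images = [
                Image.open(os.path.join(BASE_PATH, f)).convert("RGB")
                for f in batch['file']
            ]
            inputs = processor(images=images, return_tensors="pt").to(DEVICE)
            logits = model(pixel_values=inputs['pixel_values'])['logits']

            for task in ['age', 'gender', 'race']:
                all_preds[task].extend(torch.argmax(logits[task], dim=-1).cpu().tolist())
                all_true[task].extend(label_mappings[task][label] for label in batch[task])

    return all_preds, all_true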


# --- 5. Function for Single Image Prediction ---
def predict_single_image(image_path):
    print(f"\n--- Predicting for single image: {image_path} ---")
    if not os.path.exists(image_path):
        print(f"Error: Image path not found at '{image_path}'")
        return

    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt").to(DEVICE)

    with torch.no_grad():
        outputs = model(pixel_values=inputs['pixel_values'])
        logits = outputs['logits']

    predictions = {}
    for task in ['age', 'gender', 'race']:
        pred_id = torch.argmax(logits[task], dim=-1).item()
        pred_label = id_mappings[task][pred_id]
        predictions[task] = pred_label

    print("Predictions:")
    for task, label in predictions.items():
        print(f"  - {task.capitalize()}: {label}")
    return predictions


if __name__ == "__main__":
    # Run the full evaluation on the validation dataset
    evaluate_on_dataset()

    # --- Example of single image prediction ---
    # IMPORTANT: Change this path to an image you want to test
    sample_image_path = 'val/1.jpg'
    predict_single_image(sample_image_path)