File size: 1,819 Bytes
71c32d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51

import torch
import torch.nn as nn
from torchvision.models import resnet18, ResNet18_Weights
from torch.utils.data import DataLoader
import numpy as np
from core_dataset import CoreDataset
from config import BATCH_SIZE

class FeatureExtractor:
    """Extract fixed-length image embeddings with a pretrained ResNet18 backbone.

    torchvision's ResNet18 is loaded with its default pretrained weights and its
    final fully-connected (classification) layer is removed, so the model
    outputs the globally pooled feature map instead of class logits.
    """

    def __init__(self, device=None):
        """Create the extractor and load the backbone onto ``device``.

        Args:
            device: Device to run the model on (anything accepted by
                ``nn.Module.to``). When falsy, CUDA is used if available,
                otherwise the CPU.
        """
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self._load_model()

    def _load_model(self):
        """Load pretrained ResNet18 with its classification layer removed.

        Returns:
            An ``nn.Sequential`` holding every top-level ResNet18 child module
            except the last one (the ``fc`` classifier), moved to
            ``self.device`` and switched to eval mode.
        """
        weights = ResNet18_Weights.DEFAULT
        backbone = resnet18(weights=weights)
        # The last child of torchvision's ResNet is the `fc` classifier; dropping
        # it leaves the global-average-pooled feature map as the model output.
        model = nn.Sequential(*list(backbone.children())[:-1])
        model = model.to(self.device)
        # Inference only: eval() makes BatchNorm use its running statistics.
        model.eval()
        return model

    def extract_features(self, image_dir, batch_size=None):
        """Extract one feature vector per image found in ``image_dir``.

        Args:
            image_dir: Directory passed to ``CoreDataset``. Each item the
                dataset yields is unpacked as an ``(image_tensor, path)`` pair.
                NOTE(review): the tensors are presumably already resized and
                normalized for ResNet18 by CoreDataset; confirm there.
            batch_size: Number of images per forward pass. Defaults to
                ``config.BATCH_SIZE`` when ``None``.

        Returns:
            A tuple ``(features, image_paths)``. ``features`` is a 2-D NumPy
            array with one row per image (512 columns for ResNet18).
            ``image_paths[i]`` is the path of the image in row ``i``.

        Raises:
            ValueError: If the dataset yields no images.
        """
        # Resolve the default here: DataLoader(batch_size=None) would disable
        # automatic batching rather than fall back to a default size.
        if batch_size is None:
            batch_size = BATCH_SIZE

        dataset = CoreDataset(image_dir)
        # shuffle=False keeps feature rows in the same order as `image_paths`.
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        feature_chunks = []
        image_paths = []

        print("Extracting features from images...")
        with torch.no_grad():
            for batch, paths in dataloader:
                batch = batch.to(self.device)
                # Pooled output is (N, C, 1, 1); flatten everything after the
                # batch dimension to get one vector per image.
                batch_features = torch.flatten(self.model(batch), start_dim=1)
                feature_chunks.append(batch_features.cpu().numpy())
                image_paths.extend(paths)

        if not feature_chunks:
            # np.vstack([]) would otherwise fail with an unhelpful numpy error.
            raise ValueError(f"No images found in {image_dir!r}; nothing to extract.")

        features = np.vstack(feature_chunks)
        print(f"Extracted features shape: {features.shape}")
        return features, image_paths

if __name__ == "__main__":
    # Script entry point: extract features for the configured image directory
    # and report how many images were processed.
    from config import IMAGE_DIR

    feature_extractor = FeatureExtractor()
    all_features, image_paths = feature_extractor.extract_features(IMAGE_DIR)
    image_count = len(image_paths)
    print(f"Extracted features for {image_count} images")