Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
from torchvision.models import resnet18, ResNet18_Weights | |
from torch.utils.data import DataLoader | |
import numpy as np | |
from core_dataset import CoreDataset | |
from config import BATCH_SIZE | |
class FeatureExtractor:
    """Embed images as fixed-length vectors with a pretrained ResNet18 backbone.

    The ImageNet classification head is removed, so each image is mapped to
    the activations that would have been fed to that head, flattened to 1-D.
    """

    def __init__(self, device=None):
        """Build the extractor and load the backbone onto a device.

        Args:
            device: Torch device (or device string) to run on. When falsy,
                CUDA is used if available, otherwise the CPU.
        """
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # _load_model also sets self.feature_dim (width of one feature row).
        self.model = self._load_model()

    def _load_model(self):
        """Load pretrained ResNet18 and remove the classification layer.

        Side effect: sets ``self.feature_dim`` to the input width of the
        removed ``fc`` layer, i.e. the length of each extracted feature row.

        Returns:
            An ``nn.Sequential`` of every ResNet18 child except the last
            (``fc``), moved to ``self.device`` and switched to eval mode.
        """
        weights = ResNet18_Weights.DEFAULT
        backbone = resnet18(weights=weights)
        # The classifier's input width is exactly the size of the vector the
        # truncated network produces per image; remember it so an empty
        # result can still be given the right number of columns.
        self.feature_dim = backbone.fc.in_features
        # Remove the final classification layer (the last child module).
        model = nn.Sequential(*list(backbone.children())[:-1])
        model = model.to(self.device)
        # eval() freezes BatchNorm/Dropout behaviour so features don't depend
        # on how the images happen to be grouped into batches.
        model.eval()
        return model

    def extract_features(self, image_dir):
        """Extract one feature vector per image found in ``image_dir``.

        Batches are read with ``shuffle=False``, so row ``i`` of the returned
        array corresponds to ``image_paths[i]``.

        Args:
            image_dir: Directory handed to ``CoreDataset``.

        Returns:
            A tuple ``(features, image_paths)``: ``features`` is a NumPy array
            of shape ``(num_images, self.feature_dim)`` and ``image_paths`` is
            a list with one entry per row. If the dataset is empty, this is a
            ``(0, self.feature_dim)`` float32 array and an empty list rather
            than an exception.
        """
        # NOTE(review): assumes CoreDataset yields (image_tensor, path) pairs
        # already resized/normalized as ResNet18_Weights expects -- the
        # weights' own transforms are not applied here. TODO confirm.
        dataset = CoreDataset(image_dir)
        dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)
        feature_chunks = []
        image_paths = []
        print("Extracting features from images...")
        with torch.no_grad():
            for batch, paths in dataloader:
                batch = batch.to(self.device)
                # Collapse everything after the batch dimension into one row
                # per image.
                batch_features = torch.flatten(self.model(batch), start_dim=1)
                feature_chunks.append(batch_features.cpu().numpy())
                image_paths.extend(paths)
        if feature_chunks:
            features = np.vstack(feature_chunks)
        else:
            # np.vstack raises ValueError on an empty sequence; an empty
            # directory should give an empty, correctly-shaped result instead.
            features = np.empty((0, self.feature_dim), dtype=np.float32)
        print(f"Extracted features shape: {features.shape}")
        return features, image_paths
if __name__ == "__main__":
    # Script entry point: embed every image under the configured directory
    # and report how many were processed. The config import is kept local so
    # that importing this module does not require IMAGE_DIR to be defined.
    from config import IMAGE_DIR

    feature_extractor = FeatureExtractor()
    _, image_paths = feature_extractor.extract_features(IMAGE_DIR)
    print(f"Extracted features for {len(image_paths)} images")