# Geologist_AI / feature_extractor.py
# solfedge's picture
# Upload 9 files
# 71c32d5 verified
import torch
import torch.nn as nn
from torchvision.models import resnet18, ResNet18_Weights
from torch.utils.data import DataLoader
import numpy as np
from core_dataset import CoreDataset
from config import BATCH_SIZE
class FeatureExtractor:
    """Turn images into feature vectors using a pretrained ResNet18.

    The final classification layer of the network is removed, so every image
    is represented by the flattened output of the remaining layers.
    """

    def __init__(self, device=None):
        """Pick the compute device and load the truncated network.

        Args:
            device: Device to run the model on. When omitted (or falsy), CUDA
                is used if ``torch.cuda.is_available()`` and the CPU otherwise.
        """
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self._load_model()

    def _load_model(self):
        """Load pretrained ResNet18 and remove classification layer.

        Returns:
            An ``nn.Sequential`` holding every child module of ResNet18 except
            the last one, moved to ``self.device`` and set to evaluation mode.
        """
        backbone = resnet18(weights=ResNet18_Weights.DEFAULT)
        # Keep every child module except the last one (the classifier), so the
        # network outputs features instead of class scores.
        model = nn.Sequential(*list(backbone.children())[:-1])
        model = model.to(self.device)
        # The model is only used for inference here, never trained.
        model.eval()
        return model

    def extract_features(self, image_dir):
        """Extract features from all images in directory.

        Args:
            image_dir: Location handed to ``CoreDataset`` to build the dataset.

        Returns:
            A ``(features, image_paths)`` tuple. ``features`` is a 2-D NumPy
            array with one row per image (the model output flattened per
            sample); ``image_paths`` is a list of the path entries yielded by
            the data loader, in the same order as the rows.

        Raises:
            ValueError: If the data loader yields no batches, i.e. there is
                nothing to extract features from.
        """
        dataset = CoreDataset(image_dir)
        # shuffle=False keeps feature rows aligned with the order of paths.
        dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)
        features = []
        image_paths = []
        print("Extracting features from images...")
        # Gradients are never needed for feature extraction.
        with torch.no_grad():
            for batch, paths in dataloader:
                batch = batch.to(self.device)
                batch_features = self.model(batch)
                # Collapse everything after the batch dimension into a single
                # feature axis; flatten does not require a contiguous tensor.
                batch_features = torch.flatten(batch_features, 1)
                features.append(batch_features.cpu().numpy())
                image_paths.extend(paths)
        if not features:
            # np.vstack([]) would fail with an opaque "need at least one array
            # to concatenate" message; report the actual problem instead.
            raise ValueError(
                f"No images were loaded from {image_dir!r}; cannot extract features"
            )
        features = np.vstack(features)
        print(f"Extracted features shape: {features.shape}")
        return features, image_paths
if __name__ == "__main__":
    # Script entry point: extract features for the configured image directory
    # and report how many images were processed.
    from config import IMAGE_DIR

    features, paths = FeatureExtractor().extract_features(IMAGE_DIR)
    print(f"Extracted features for {len(paths)} images")