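"""resnet-snake-clef inference script for SnakeCLEF 2024.

Loads a ResNet50-based classifier that fuses image features with tabular
metadata, runs it over the private test set, and writes a submission CSV with
one (observation_id, class_id) row per observation.
"""
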
import pandas as pd
import numpy as np
import os
from tqdm import tqdm
from PIL import Image
import torch
import torch.nn as nn
import torchvision.transforms as T
from torchvision.models import resnet50


def is_gpu_available():
    return torch.cuda.is_available()


class ResNetClassifier(nn.Module):
    """ResNet50 image backbone combined with a linear projection of tabular metadata."""

    def __init__(self, num_classes, metadata_size):
        super(ResNetClassifier, self).__init__()
        self.resnet = resnet50(pretrained=True)
        self.resnet.fc = nn.Identity()  # Remove the final fully connected layer; keep the 2048-d features
        self.metadata_fc = nn.Linear(metadata_size, 128)
        self.classifier = nn.Linear(2048 + 128, num_classes)  # 2048 is the ResNet50 feature size

    def forward(self, x, metadata_features):
        resnet_features = self.resnet(x)                          # (batch, 2048)
        metadata_features = self.metadata_fc(metadata_features)   # (batch, 128)
        combined_features = torch.cat((resnet_features, metadata_features), dim=1)
        logits = self.classifier(combined_features)
        return logits
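
# Shape sketch (illustrative only; metadata_size=10 is an arbitrary example value,
# num_classes=1784 matches the SnakeCLEF 2024 setting used below):
#   model = ResNetClassifier(num_classes=1784, metadata_size=10)
#   logits = model(torch.randn(4, 3, 224, 224), torch.randn(4, 10))  # -> shape (4, 1784)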


class PytorchWorker:
    """Loads the trained classifier and runs single-image inference on the selected device."""

    def __init__(self, model_path: str, num_classes: int, metadata_size: int):
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        self.model = self._load_model(model_path, num_classes, metadata_size)
        self.transforms = T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),  # ImageNet statistics
        ])

    def _load_model(self, model_path, num_classes, metadata_size):
        model = ResNetClassifier(num_classes, metadata_size)
        model.load_state_dict(torch.load(model_path, map_location=self.device))
        return model.to(self.device).eval()

    def predict_image(self, image: Image.Image, metadata_features: np.ndarray) -> list:
        """Return raw logits for one image/metadata pair as a nested list of shape (1, num_classes)."""
        input_tensor = self.transforms(image).unsqueeze(0).to(self.device)
        metadata_tensor = torch.tensor(metadata_features, dtype=torch.float32).unsqueeze(0).to(self.device)
        with torch.no_grad():
            logits = self.model(input_tensor, metadata_tensor)
        return logits.tolist()
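
# Usage sketch (illustrative; "test.jpg" and the 10-value metadata vector are
# placeholders, not files or columns from this repository):
#   worker = PytorchWorker("pytorch_model.pth", num_classes=1784, metadata_size=10)
#   logits = worker.predict_image(Image.open("test.jpg").convert("RGB"),
#                                 np.zeros(10, dtype=np.float32))
#   predicted_class = int(np.argmax(logits))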


def make_submission(test_metadata, model_path, num_classes, metadata_size,
                    output_csv_path="./submission.csv",
                    images_root_path="/tmp/data/private_testset"):
    """Run inference over every test image and write an (observation_id, class_id) submission CSV."""
    model = PytorchWorker(model_path, num_classes, metadata_size)

    predictions = []
    for _, row in tqdm(test_metadata.iterrows(), total=len(test_metadata)):
        image_path = os.path.join(images_root_path, row['image_path'])
        test_image = Image.open(image_path).convert("RGB")
        metadata_features = row.drop(['image_path', 'class_id']).values.astype(np.float32)
        logits = model.predict_image(test_image, metadata_features)
        predictions.append(np.argmax(logits))

    # Keep the first image's prediction for each observation.
    test_metadata["class_id"] = predictions
    user_pred_df = test_metadata.drop_duplicates("observation_id", keep="first")
    user_pred_df[["observation_id", "class_id"]].to_csv(output_csv_path, index=False)


if __name__ == "__main__":
    import zipfile

    # Unpack the private test set images.
    with zipfile.ZipFile("/tmp/data/private_testset.zip", 'r') as zip_ref:
        zip_ref.extractall("/tmp/data")

    MODEL_PATH = "pytorch_model.pth"

    metadata_file_path = "./SnakeCLEF2024-TestMetadata.csv"
    test_metadata = pd.read_csv(metadata_file_path)

    num_classes = 1784
    metadata_size = len(test_metadata.columns) - 2  # Excluding 'image_path' and 'class_id'

    make_submission(
        test_metadata=test_metadata,
        model_path=MODEL_PATH,
        num_classes=num_classes,
        metadata_size=metadata_size,
    )
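
# Running this file directly (e.g. `python script.py`) is expected to write
# ./submission.csv; the /tmp/data paths above assume the competition's
# evaluation environment and will need adjusting for local runs.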