|
import pandas as pd |
|
import numpy as np |
|
import os |
|
from tqdm import tqdm |
|
from PIL import Image |
|
import torch |
|
import torch.nn as nn |
|
import torchvision.transforms as T |
|
from torchvision.models import resnet50 |
|
|
|
def is_gpu_available():
    """Return whether PyTorch reports CUDA as available on this machine."""
    cuda_ready = torch.cuda.is_available()
    return cuda_ready
|
|
|
class ResNetClassifier(nn.Module):
    """ResNet-50 image backbone fused with a linear projection of metadata.

    The backbone's final ``fc`` layer is replaced by an identity so it emits
    its pooled feature vector. The metadata vector is projected to 128 dims by
    ``metadata_fc``. The two are concatenated and mapped to ``num_classes``
    logits by ``classifier``.

    Args:
        num_classes: Number of output logits.
        metadata_size: Length of the per-sample metadata feature vector.
        pretrained: If True (the default, matching the original behaviour),
            start the backbone from torchvision's ImageNet weights. Pass False
            when a full state dict will be loaded right afterwards; this skips
            the network download of weights that would be overwritten anyway.
    """

    def __init__(self, num_classes, metadata_size, pretrained: bool = True):
        super().__init__()
        # resnet50() with no arguments builds a randomly initialised backbone
        # on every torchvision version, without downloading or warning.
        self.resnet = resnet50(pretrained=True) if pretrained else resnet50()
        # Read the backbone feature width before discarding its head instead of
        # hard-coding it (it is 2048 for ResNet-50).
        backbone_dim = self.resnet.fc.in_features
        self.resnet.fc = nn.Identity()
        self.metadata_fc = nn.Linear(metadata_size, 128)
        self.classifier = nn.Linear(backbone_dim + 128, num_classes)

    def forward(self, x, metadata_features):
        """Return class logits for a batch of images and matching metadata.

        Args:
            x: Image batch fed to the ResNet-50 backbone; presumably
                (batch, 3, H, W) normalised floats — confirm with callers.
            metadata_features: Metadata batch fed to ``metadata_fc``. Its last
                dimension must equal ``metadata_size`` and its dtype must match
                the layer weights.

        Returns:
            Logits produced by the final linear ``classifier`` layer.
        """
        image_features = self.resnet(x)
        meta_embedding = self.metadata_fc(metadata_features)
        # Concatenating on dim=1 assumes both branches return (batch, features).
        fused = torch.cat((image_features, meta_embedding), dim=1)
        return self.classifier(fused)


class PytorchWorker:
    """Load a trained ``ResNetClassifier`` checkpoint and score single images."""

    def __init__(self, model_path: str, num_classes: int, metadata_size: int):
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        self.model = self._load_model(model_path, num_classes, metadata_size)
        # Resize to 224x224, convert to a tensor, then normalise with the usual
        # ImageNet channel mean/std.
        self.transforms = T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def _load_model(self, model_path, num_classes, metadata_size):
        """Build the network, load the state dict at ``model_path`` and return it.

        The returned model is moved to ``self.device`` and set to eval mode.
        """
        # pretrained=False: every backbone weight is replaced by the checkpoint
        # below, so fetching ImageNet weights first is wasted work and fails on
        # machines without network access.
        model = ResNetClassifier(num_classes, metadata_size, pretrained=False)
        state_dict = torch.load(model_path, map_location=self.device)
        model.load_state_dict(state_dict)
        return model.to(self.device).eval()

    def predict_image(self, image: Image.Image, metadata_features: np.ndarray) -> list:
        """Return the model's logits for one image and its metadata vector.

        Args:
            image: PIL image; it is passed through ``self.transforms``.
            metadata_features: 1-D metadata vector for this image. It is cast
                to float32 to match the model weights.

        Returns:
            A nested list of the form ``[[logit_0, ..., logit_{C-1}]]``, holding
            one row because the inputs are batched with a size of one.
        """
        input_tensor = self.transforms(image).unsqueeze(0).to(self.device)
        # Force float32: a float64 array (numpy's default) would otherwise give
        # a double tensor that nn.Linear's float32 weights reject.
        metadata_tensor = (
            torch.tensor(metadata_features, dtype=torch.float32)
            .unsqueeze(0)
            .to(self.device)
        )
        with torch.no_grad():
            logits = self.model(input_tensor, metadata_tensor)
        return logits.tolist()
|
|
|
def make_submission(test_metadata, model_path, num_classes, metadata_size, output_csv_path="./submission.csv", images_root_path="/tmp/data/private_testset"):
    """Predict a class for every row of ``test_metadata`` and write a submission CSV.

    Each row's image is loaded from ``images_root_path`` joined with its
    ``image_path`` column. All columns other than ``image_path`` and
    ``class_id`` (if present) are used as the row's metadata features. The
    argmax of the model logits becomes that row's ``class_id``. Only the first
    row per ``observation_id`` is written, with columns
    ``observation_id,class_id``.

    Args:
        test_metadata: DataFrame with at least ``image_path`` and
            ``observation_id`` columns, plus ``metadata_size`` numeric feature
            columns. It is not modified.
        model_path: Path to the saved ``state_dict`` loaded by ``PytorchWorker``.
        num_classes: Number of model output classes.
        metadata_size: Expected number of metadata feature columns.
        output_csv_path: Destination of the submission CSV.
        images_root_path: Directory that the ``image_path`` values are
            relative to.

    Raises:
        ValueError: If the number of feature columns differs from
            ``metadata_size``, or if a feature column cannot be converted to
            float32.
    """
    model = PytorchWorker(model_path, num_classes, metadata_size)

    # Select and convert the feature columns once instead of per row.
    # errors="ignore" because test metadata may not carry a class_id column.
    feature_frame = test_metadata.drop(columns=["image_path", "class_id"], errors="ignore")
    if feature_frame.shape[1] != metadata_size:
        raise ValueError(
            f"Expected {metadata_size} metadata feature columns, found "
            f"{feature_frame.shape[1]}: {list(feature_frame.columns)}"
        )
    features = feature_frame.to_numpy(dtype=np.float32)

    predictions = []
    rows = zip(test_metadata["image_path"], features)
    for relative_path, row_features in tqdm(rows, total=len(test_metadata)):
        image_path = os.path.join(images_root_path, relative_path)
        # Close the underlying file handle; convert() returns an independent
        # in-memory image.
        with Image.open(image_path) as raw_image:
            test_image = raw_image.convert("RGB")
        logits = model.predict_image(test_image, row_features)
        predictions.append(int(np.argmax(logits)))

    # assign() returns a new DataFrame, so the caller's frame is left untouched.
    results = test_metadata.assign(class_id=predictions)
    user_pred_df = results.drop_duplicates("observation_id", keep="first")
    user_pred_df[["observation_id", "class_id"]].to_csv(output_csv_path, index=False)
|
|
|
if __name__ == "__main__":
    import zipfile

    # Unpack the private test images under /tmp/data. The archive presumably
    # contains a private_testset/ directory, matching make_submission's
    # default images_root_path — confirm.
    with zipfile.ZipFile("/tmp/data/private_testset.zip", 'r') as zip_ref:
        zip_ref.extractall("/tmp/data")

    MODEL_PATH = "pytorch_model.pth"
    metadata_file_path = "./SnakeCLEF2024-TestMetadata.csv"
    test_metadata = pd.read_csv(metadata_file_path)

    # Must equal the classifier output size stored in the checkpoint, or
    # load_state_dict will reject it.
    num_classes = 1784

    # Feature columns are every column except image_path and class_id. Count
    # them directly rather than assuming both columns exist in the CSV.
    metadata_size = len(test_metadata.columns.drop(["image_path", "class_id"], errors="ignore"))

    make_submission(
        test_metadata=test_metadata,
        model_path=MODEL_PATH,
        num_classes=num_classes,
        metadata_size=metadata_size,
    )
|
|