# carDamageDetection/model.py
import os

import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle
from torch.utils.data import DataLoader
from torchvision import transforms

class HybridCNNViT(nn.Module):
    """Convolutional classifier (despite the name, only CNN layers are used here)."""

    def __init__(self, in_channels: int, num_classes: int):
        super(HybridCNNViT, self).__init__()
        self.conv1 = nn.Conv2d(
            in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 128, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=3,
                               stride=2, padding=1, bias=False)
        self.bn4 = nn.BatchNorm2d(256)
        self.conv5 = nn.Conv2d(256, 256, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn5 = nn.BatchNorm2d(256)
        self.conv6 = nn.Conv2d(256, 512, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn6 = nn.BatchNorm2d(512)
        self.conv7 = nn.Conv2d(512, 512, kernel_size=3,
                               stride=2, padding=1, bias=False)
        self.bn7 = nn.BatchNorm2d(512)
        # Optional max pooling (can be removed if strictly no max pooling is wanted)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        # 1x1 convolution maps the 512 feature channels to num_classes score maps
        self.classifier_conv = nn.Conv2d(
            512, num_classes, kernel_size=1, stride=1, padding=0, bias=False)
        # Global average pooling + flatten + dropout produce (batch, num_classes) logits
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
            nn.Dropout(0.5)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.bn3(self.conv3(x)))
        x = self.relu(self.bn4(self.conv4(x)))
        x = self.relu(self.bn5(self.conv5(x)))
        x = self.relu(self.bn6(self.conv6(x)))
        x = self.relu(self.bn7(self.conv7(x)))
        x = self.maxpool(x)  # Comment this line out if no max pooling is needed
        x = self.classifier_conv(x)
        x = self.classifier(x)
        return x
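

# A minimal sanity-check sketch (not part of the original pipeline): it only
# uses the class above and confirms that a 224x224 RGB batch comes out as
# (batch, num_classes) logits. The helper name `_demo_forward_shape` is
# illustrative and is not called anywhere else in this module.
def _demo_forward_shape():
    model = HybridCNNViT(in_channels=3, num_classes=2)
    model.eval()
    dummy = torch.randn(1, 3, 224, 224)  # one fake RGB image
    with torch.no_grad():
        logits = model(dummy)
    return logits.shape  # expected: torch.Size([1, 2])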


def load_and_pad_single_image(image_path, img_size=(224, 224)):
    """Read an image from disk and resize it to img_size (BGR, HWC, uint8).

    Note: despite the name, no padding is applied; the image is simply resized.
    """
    img = cv2.imread(image_path)
    if img is None:
        raise ValueError(f"Could not read image: {image_path}")
    img = cv2.resize(img, img_size)
    return np.array(img)
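

# The loader above resizes without padding despite its name. Below is a hedged
# sketch of aspect-preserving "letterbox" padding, in case that was the intent;
# the helper name `_load_and_letterbox_image` is illustrative and is not used
# by the rest of this module.
def _load_and_letterbox_image(image_path, img_size=(224, 224)):
    img = cv2.imread(image_path)
    if img is None:
        raise ValueError(f"Could not read image: {image_path}")
    target_w, target_h = img_size
    h, w = img.shape[:2]
    scale = min(target_w / w, target_h / h)
    resized = cv2.resize(img, (int(w * scale), int(h * scale)))
    pad_w = target_w - resized.shape[1]
    pad_h = target_h - resized.shape[0]
    # Split the padding evenly between both sides, filling with black pixels
    return cv2.copyMakeBorder(
        resized,
        pad_h // 2, pad_h - pad_h // 2,
        pad_w // 2, pad_w - pad_w // 2,
        cv2.BORDER_CONSTANT, value=(0, 0, 0))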


def check_file(image_path):
    # image_path = "d/Control-Axial/C-A (2).png"
    # Load and preprocess the single image
    image = load_and_pad_single_image(image_path)
    image = np.expand_dims(image, axis=0)  # Add a batch dimension
    # Duplicate the image 10 times so the split below has something to work with
    data = np.repeat(image, 10, axis=0)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Note: this transform is defined but never applied to `data` below;
    # the tensors are fed to the model unnormalized.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    # NHWC uint8 -> NCHW float32 on the target device
    data = torch.tensor(data, dtype=torch.float32).permute(
        0, 3, 1, 2).to(device)
    # Placeholder labels for the 10 duplicated images
    labels = torch.tensor([0] * 10, dtype=torch.long).to(device)

    data, labels = shuffle(data, labels, random_state=42)
    train_data, test_data, train_labels, test_labels = train_test_split(
        data, labels, test_size=0.2, random_state=42
    )
    train_labels = train_labels.long()
    test_labels = test_labels.long()

    batch_size = 1  # Since we are working with a single image
    train_dataset = list(zip(train_data, train_labels))  # Unused; only the test split is evaluated
    test_dataset = list(zip(test_data, test_labels))
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False)

    def test_model(model, test_loader, device):
        """Run the model over the test loader and return the last batch's predicted class indices."""
        model.to(device)
        model.eval()
        predicted = None
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
        return predicted

    def remove_module_from_checkpoint(checkpoint):
        """Strip the 'module.' prefix that nn.DataParallel adds to state-dict keys."""
        new_state_dict = {}
        for key, value in checkpoint["model_state_dict"].items():
            new_key = key.replace("module.", "")
            new_state_dict[new_key] = value
        checkpoint["model_state_dict"] = new_state_dict
        return checkpoint

    model = HybridCNNViT(3, 2)
    checkpoint = torch.load(
        "/home/user/app/checkpoint32.pth", weights_only=False)
    checkpoint = remove_module_from_checkpoint(checkpoint)
    model.load_state_dict(checkpoint['model_state_dict'])
    # optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    model.eval()
    model.to(device)
    model = nn.DataParallel(model)

    # Collect the predictions for the (duplicated) input image and return them
    output = test_model(model, test_loader, device)
    return output
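

# A hedged usage sketch: "sample_car.jpg" is a placeholder path, and the
# checkpoint path inside check_file is assumed to exist in the deployment
# environment; neither is part of the original module.
if __name__ == "__main__":
    example_image = "sample_car.jpg"  # replace with a real image path
    predictions = check_file(example_image)
    print("Predicted class indices:", predictions)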