# PhanLoaiChoMeo / dataset_prep_resnet18.py
# Author: Phuneil
# Commit: 0b3fbd2 ("update_ver2", verified)
import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from PIL import Image
import os
def is_valid_image(filepath):
    """Return True if PIL can open *filepath* and decode it as an RGB image.

    Two passes are used. ``Image.verify()`` checks the file structure without
    decoding pixel data, and it leaves the image object unusable afterwards.
    The file is therefore reopened and converted to RGB, which forces the
    pixel data to actually be decoded.

    Prints a warning and returns False for any file that fails either pass.
    """
    try:
        with Image.open(filepath) as img:
            img.verify()
        # verify() invalidates the image object; reopen (and close properly)
        # to fully decode the pixels as RGB.
        with Image.open(filepath) as img:
            img.convert('RGB')
    except Exception:
        # Deliberately broad (PIL raises OSError, SyntaxError, ValueError, ...)
        # but, unlike a bare except, lets KeyboardInterrupt/SystemExit through.
        print(f"[!] Ảnh lỗi hoặc không hợp lệ: {filepath}")
        return False
    return True
def clean_dataset(directory):
    """Delete unreadable files from each class sub-folder of *directory*.

    Expects the ImageFolder layout ``directory/<class_name>/<file>``. Only
    regular files directly inside each class folder are checked.
    Top-level files in *directory* are ignored.
    Nested sub-directories are skipped and left untouched.

    WARNING: destructive. Every file that ``is_valid_image`` rejects
    (including non-image files) is permanently removed from disk.

    Returns:
        int: the number of files removed.
    """
    removed = 0
    for class_dir in os.listdir(directory):
        class_path = os.path.join(directory, class_dir)
        if not os.path.isdir(class_path):
            continue
        for img_name in os.listdir(class_path):
            img_path = os.path.join(class_path, img_name)
            # A nested directory would fail the image check. os.remove()
            # cannot delete a directory, so it must be skipped here.
            if not os.path.isfile(img_path):
                continue
            if not is_valid_image(img_path):
                os.remove(img_path)
                removed += 1
    return removed
def get_data_loaders(data_dir='./data', batch_size=32, num_workers=4, clean=True):
    """Build train/val/test DataLoaders from ImageFolder-style split folders.

    ``data_dir`` must contain ``train``, ``val`` and ``test`` folders, each
    with one sub-folder per class (torchvision ``ImageFolder`` layout).

    Args:
        data_dir: root folder holding the three splits.
        batch_size: batch size used by all three loaders.
        num_workers: worker processes per DataLoader (default 4). Set to 0 to
            load in the main process.
        clean: when True (default), first delete unreadable files from every
            split via ``clean_dataset``. This permanently modifies ``data_dir``.

    Returns:
        tuple: ``(train_loader, val_loader, test_loader)``. Only the train
        loader shuffles and applies random augmentation. val and test use the
        same deterministic resize + center-crop.
    """
    split_dirs = {split: os.path.join(data_dir, split)
                  for split in ('train', 'val', 'test')}

    # Remove broken images before building the datasets, so the loaders do
    # not fail mid-epoch.
    if clean:
        print("🧹 Đang kiểm tra và loại bỏ ảnh lỗi...")
        for split in ('train', 'val', 'test'):
            clean_dataset(split_dirs[split])

    # Standard ImageNet mean/std normalization, as used by torchvision's
    # ImageNet-trained ResNets. Normalize is stateless, so one instance is shared.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # Training: random crop/flip/rotation/color jitter for augmentation.
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        normalize,
    ])
    # Evaluation: deterministic resize to 256 then 224x224 center crop.
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = datasets.ImageFolder(root=split_dirs['train'], transform=train_transform)
    val_dataset = datasets.ImageFolder(root=split_dirs['val'], transform=val_transform)
    test_dataset = datasets.ImageFolder(root=split_dirs['test'], transform=val_transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                             num_workers=num_workers)

    print("📂 Nhãn lớp:", train_dataset.classes)
    print(f"🖼️ Số lượng ảnh: train = {len(train_dataset)}, val = {len(val_dataset)}, test = {len(test_dataset)}")
    return train_loader, val_loader, test_loader