"""Utilities for cleaning an ImageFolder-style dataset and building DataLoaders.

Expected layout: ``<data_dir>/{train,val,test}/<class_name>/<image files>``.
"""
import os

import torch
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.datasets.folder import IMG_EXTENSIONS

# Per-channel normalization statistics used by both transform pipelines
# (the standard ImageNet values used with torchvision's pretrained ResNets).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def is_valid_image(filepath):
    """Return True if *filepath* can be opened, verified and decoded as RGB by PIL.

    Prints a warning and returns False for any file PIL fails to handle.
    """
    try:
        # verify() checks integrity without decoding pixels, but leaves the
        # Image object unusable afterwards, so the file is reopened below.
        with Image.open(filepath) as img:
            img.verify()
        # Force a full decode + RGB conversion to catch corrupt/truncated data
        # that verify() alone can miss. The context manager closes the file
        # handle; a leaked handle can make a later os.remove() fail on Windows.
        with Image.open(filepath) as img:
            img.convert('RGB')
        return True
    except Exception:  # PIL raises many exception types for bad files
        print(f"[!] Ảnh lỗi hoặc không hợp lệ: {filepath}")
        return False


def clean_dataset(directory):
    """Delete image files under *directory* that PIL cannot open and decode.

    *directory* is expected to contain one sub-directory per class. Every file
    whose extension ``ImageFolder`` would load (``IMG_EXTENSIONS``), anywhere
    beneath a class directory, is checked with ``is_valid_image``; failing
    files are removed. Other files are left untouched because ImageFolder
    ignores them anyway. Loose files directly inside *directory* are skipped.

    Raises:
        FileNotFoundError: if *directory* does not exist.
    """
    for class_dir in os.listdir(directory):
        class_path = os.path.join(directory, class_dir)
        if not os.path.isdir(class_path):
            continue
        # ImageFolder scans class directories recursively, so clean them the
        # same way. Walking also ensures only files (never sub-directories)
        # are passed to os.remove(). Symlinked directories are not followed,
        # so nothing outside the dataset tree can be deleted.
        for root, _dirs, files in os.walk(class_path):
            for img_name in files:
                if not img_name.lower().endswith(IMG_EXTENSIONS):
                    continue
                img_path = os.path.join(root, img_name)
                if not is_valid_image(img_path):
                    os.remove(img_path)


def _check_class_alignment(reference, other, split):
    """Raise ValueError if *other* maps any class to a different label than *reference*."""
    # Each ImageFolder derives labels from its own sorted class folders, so a
    # class folder missing from one split would silently shift its labels.
    mismatched = sorted(
        cls for cls, idx in other.class_to_idx.items()
        if reference.class_to_idx.get(cls) != idx
    )
    if mismatched:
        raise ValueError(
            f"Class labels of the '{split}' split do not match 'train' for: "
            f"{mismatched}. Make sure every split has the same class sub-directories."
        )


def get_data_loaders(data_dir='./data', batch_size=32, num_workers=4):
    """Clean the train/val/test splits, then build one DataLoader per split.

    Args:
        data_dir: Root directory containing ``train``, ``val`` and ``test``
            sub-directories, each laid out as ``<class_name>/<image files>``.
        batch_size: Batch size used by all three loaders.
        num_workers: Worker processes per DataLoader (default 4).

    Returns:
        ``(train_loader, val_loader, test_loader)``; only the training loader
        shuffles.

    Raises:
        ValueError: if val/test assign a class a different label index than train.

    Side effects:
        Deletes undecodable image files from the split directories (see
        ``clean_dataset``) and prints progress information.
    """
    # Remove broken images before building the datasets.
    print("🧹 Đang kiểm tra và loại bỏ ảnh lỗi...")
    for split in ('train', 'val', 'test'):
        clean_dataset(os.path.join(data_dir, split))

    # Training: random crop/flip/rotation/colour jitter for augmentation.
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])
    # Evaluation: deterministic resize + centre crop (shared by val and test).
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])

    train_dataset = datasets.ImageFolder(root=os.path.join(data_dir, 'train'), transform=train_transform)
    val_dataset = datasets.ImageFolder(root=os.path.join(data_dir, 'val'), transform=val_transform)
    test_dataset = datasets.ImageFolder(root=os.path.join(data_dir, 'test'), transform=val_transform)

    _check_class_alignment(train_dataset, val_dataset, 'val')
    _check_class_alignment(train_dataset, test_dataset, 'test')

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    print("📂 Nhãn lớp:", train_dataset.classes)
    print(f"🖼️ Số lượng ảnh: train = {len(train_dataset)}, val = {len(val_dataset)}, test = {len(test_dataset)}")
    return train_loader, val_loader, test_loader