import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from PIL import Image
import os

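# Check whether an image file is corrupt or unreadable
def is_valid_image(filepath):
    try:
        # verify() checks file integrity but leaves the image object unusable
        with Image.open(filepath) as img:
            img.verify()
        # Reopen and fully decode as RGB to catch truncated pixel data
        with Image.open(filepath) as img:
            img.convert('RGB')
        return True
    except Exception:
        print(f"[!] Corrupt or invalid image: {filepath}")
        return False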

# Delete corrupt images from each class subdirectory
def clean_dataset(directory):
    for class_dir in os.listdir(directory):
        class_path = os.path.join(directory, class_dir)
        if os.path.isdir(class_path):
            for img_name in os.listdir(class_path):
                img_path = os.path.join(class_path, img_name)
                # Skip nested directories; os.remove() cannot delete them
                if os.path.isfile(img_path) and not is_valid_image(img_path):
                    os.remove(img_path)

# Clean out corrupt images before building the datasets
def get_data_loaders(data_dir='./data', batch_size=32):
    print("🧹 Checking for and removing corrupt images...")
    clean_dataset(os.path.join(data_dir, 'train'))
    clean_dataset(os.path.join(data_dir, 'val'))
    clean_dataset(os.path.join(data_dir, 'test'))

    # Standard ImageNet-style preprocessing for ResNet
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    train_dataset = datasets.ImageFolder(root=os.path.join(data_dir, 'train'), transform=train_transform)
    val_dataset = datasets.ImageFolder(root=os.path.join(data_dir, 'val'), transform=val_transform)
    test_dataset = datasets.ImageFolder(root=os.path.join(data_dir, 'test'), transform=val_transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    print("📂 Class labels:", train_dataset.classes)
    print(f"🖼️ Image counts: train = {len(train_dataset)}, val = {len(val_dataset)}, test = {len(test_dataset)}")

    return train_loader, val_loader, test_loader
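

# Optional, non-destructive counterpart to clean_dataset(): a minimal sketch,
# not part of the original pipeline. `find_invalid_images` is a hypothetical
# helper name. It assumes the same <directory>/<class>/<image> layout and only
# reports unreadable files instead of deleting them, so the list can be
# reviewed before any file is removed.
import os
from PIL import Image


def find_invalid_images(directory):
    invalid = []
    for class_dir in sorted(os.listdir(directory)):
        class_path = os.path.join(directory, class_dir)
        if not os.path.isdir(class_path):
            continue
        for img_name in sorted(os.listdir(class_path)):
            img_path = os.path.join(class_path, img_name)
            if not os.path.isfile(img_path):
                continue
            try:
                # Integrity check first, then a full decode on a fresh handle
                with Image.open(img_path) as img:
                    img.verify()
                with Image.open(img_path) as img:
                    img.convert('RGB')
            except Exception:
                invalid.append(img_path)
    return invalid

# Possible usage (the path is an assumption, matching the default data_dir):
#     for path in find_invalid_images('./data/train'):
#         print(path)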