Phuneil commited on
Commit
0b3fbd2
·
verified ·
1 Parent(s): 2f8e4e9

update_ver2

Browse files
app.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Ứng dụng web phân loại ảnh Chó và Mèo sử dụng Streamlit
3
+ --------------------------------------------------
4
+ Ứng dụng này cung cấp giao diện web thân thiện để phân loại ảnh chó và mèo
5
+ sử dụng mô hình ResNet18 đã được huấn luyện. Người dùng có thể tải ảnh lên,
6
+ và ứng dụng sẽ đưa ra dự đoán kèm theo độ tin cậy của kết quả.
7
+
8
+ Chức năng chính:
9
+ - Tải lên ảnh từ máy của người dùng
10
+ - Hiển thị ảnh đã tải lên
11
+ - Sử dụng mô hình ResNet18 để phân loại ảnh
12
+ - Hiển thị kết quả phân loại (Chó/Mèo) và độ tin cậy của dự đoán
13
+ """
14
+ import streamlit as st
15
+ from predict_resnet18 import predict_image
16
+ import tempfile
17
+ from PIL import Image
18
+
19
+ st.set_page_config(page_title="Phân loại Chó/Mèo", layout="centered")
20
+ st.title("🐾 Ứng dụng phân loại ảnh Chó và Mèo")
21
+ st.write("Tải ảnh lên để xem dự đoán mô hình phân loại là **Chó** hay **Mèo** 🐶🐱")
22
+
23
+ # Upload ảnh
24
+ uploaded_file = st.file_uploader("📷 Tải ảnh lên", type=["jpg", "jpeg", "png"])
25
+
26
+ if uploaded_file is not None:
27
+ try:
28
+ # Hiển thị ảnh
29
+ image = Image.open(uploaded_file).convert("RGB")
30
+ st.image(image, caption="Ảnh đã tải lên", use_container_width=True)
31
+
32
+ # Lưu tạm ảnh để truyền đường dẫn vào hàm predict_image
33
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp_file:
34
+ image.save(tmp_file.name)
35
+ tmp_path = tmp_file.name
36
+
37
+ # Dự đoán
38
+ with st.spinner("🔍 Đang phân tích..."):
39
+ result, confidence = predict_image(tmp_path)
40
+ st.success(f"✅ Kết quả: **{result}**")
41
+ st.info(f"🔒 Độ tin cậy: **{confidence:.2f}%**")
42
+
43
+ except Exception as e:
44
+ st.error(f"❌ Lỗi: {e}")
cat_dog_resnet18.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7d1a6e50bc25ade87fcd24fafcf3d5980069ae27747cd31e289740665ecd469b
3
+ size 44792236
dataset_prep_resnet18.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision import transforms, datasets
3
+ from torch.utils.data import DataLoader
4
+ from PIL import Image
5
+ import os
6
+
7
+ # Hàm kiểm tra ảnh lỗi
8
+ def is_valid_image(filepath):
9
+ try:
10
+ with Image.open(filepath) as img:
11
+ img.verify()
12
+ img = Image.open(filepath).convert('RGB') # thử load RGB luôn
13
+ return True
14
+ except:
15
+ print(f"[!] Ảnh lỗi hoặc không hợp lệ: {filepath}")
16
+ return False
17
+
18
+ # Hàm dọn dữ liệu lỗi trong thư mục
19
+ def clean_dataset(directory):
20
+ for class_dir in os.listdir(directory):
21
+ class_path = os.path.join(directory, class_dir)
22
+ if os.path.isdir(class_path):
23
+ for img_name in os.listdir(class_path):
24
+ img_path = os.path.join(class_path, img_name)
25
+ if not is_valid_image(img_path):
26
+ os.remove(img_path)
27
+
28
+ # Gọi dọn ảnh lỗi trước khi tạo dataset
29
+ def get_data_loaders(data_dir='./data', batch_size=32):
30
+ print("🧹 Đang kiểm tra và loại bỏ ảnh lỗi...")
31
+ clean_dataset(os.path.join(data_dir, 'train'))
32
+ clean_dataset(os.path.join(data_dir, 'val'))
33
+ clean_dataset(os.path.join(data_dir, 'test'))
34
+
35
+ # Transform đúng chuẩn cho ResNet
36
+ train_transform = transforms.Compose([
37
+ transforms.RandomResizedCrop(224),
38
+ transforms.RandomHorizontalFlip(p=0.5),
39
+ transforms.RandomRotation(15),
40
+ transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
41
+ transforms.ToTensor(),
42
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
43
+ std=[0.229, 0.224, 0.225])
44
+ ])
45
+
46
+ val_transform = transforms.Compose([
47
+ transforms.Resize(256),
48
+ transforms.CenterCrop(224),
49
+ transforms.ToTensor(),
50
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
51
+ std=[0.229, 0.224, 0.225])
52
+ ])
53
+
54
+ train_dataset = datasets.ImageFolder(root=os.path.join(data_dir, 'train'), transform=train_transform)
55
+ val_dataset = datasets.ImageFolder(root=os.path.join(data_dir, 'val'), transform=val_transform)
56
+ test_dataset = datasets.ImageFolder(root=os.path.join(data_dir, 'test'), transform=val_transform)
57
+
58
+ train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
59
+ val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
60
+ test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
61
+
62
+ print("📂 Nhãn lớp:", train_dataset.classes)
63
+ print(f"🖼️ Số lượng ảnh: train = {len(train_dataset)}, val = {len(val_dataset)}, test = {len(test_dataset)}")
64
+
65
+ return train_loader, val_loader, test_loader
model_resnet18.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from torchvision.models import resnet18, ResNet18_Weights
3
+
4
+ class CatDogClassifier(nn.Module):
5
+ def __init__(self):
6
+ super(CatDogClassifier, self).__init__()
7
+
8
+ # Sử dụng pretrained weights chuẩn (ImageNet)
9
+ weights = ResNet18_Weights.DEFAULT
10
+ self.base_model = resnet18(weights=weights)
11
+
12
+ # Đóng băng toàn bộ layer (chỉ fine-tune fc layer)
13
+ for param in self.base_model.parameters():
14
+ param.requires_grad = False
15
+
16
+ # Thay thế lớp fully connected cuối bằng lớp phân loại 2 lớp
17
+ num_ftrs = self.base_model.fc.in_features
18
+ self.base_model.fc = nn.Linear(num_ftrs, 2)
19
+
20
+ def forward(self, x):
21
+ return self.base_model(x)
predict_resnet18.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision import transforms
3
+ from PIL import Image
4
+ from model_resnet18 import CatDogClassifier # dùng ResNet18
5
+ import json
6
+ import os
7
+
8
+ # Định nghĩa transform giống như lúc huấn luyện với ResNet18
9
+ transform = transforms.Compose([
10
+ transforms.Resize(256),
11
+ transforms.CenterCrop(224),
12
+ transforms.ToTensor(),
13
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
14
+ std=[0.229, 0.224, 0.225])
15
+ ])
16
+
17
+ # Hàm dự đoán ảnh
18
+ def predict_image(image_path):
19
+ try:
20
+ # Kiểm tra tệp tồn tại
21
+ if not os.path.exists(image_path):
22
+ raise FileNotFoundError(f"Không tìm thấy ảnh: {image_path}")
23
+
24
+ # Tải model ResNet18
25
+ model = CatDogClassifier()
26
+ model.load_state_dict(torch.load("cat_dog_resnet18.pth", map_location=torch.device("cpu")))
27
+ model.eval()
28
+
29
+ # Load class_to_idx
30
+ with open("class_to_idx.json", "r") as f:
31
+ class_to_idx = json.load(f)
32
+ idx_to_class = {v: k for k, v in class_to_idx.items()}
33
+
34
+ # Xử lý ảnh
35
+ image = Image.open(image_path).convert("RGB")
36
+ image_tensor = transform(image).unsqueeze(0) # [1, 3, 224, 224]
37
+
38
+ # Dự đoán
39
+ with torch.no_grad():
40
+ outputs = model(image_tensor)
41
+ probs = torch.nn.functional.softmax(outputs, dim=1)
42
+ _, pred = torch.max(probs, 1)
43
+ confidence = probs[0][pred.item()].item()
44
+
45
+ label = idx_to_class[pred.item()]
46
+ emoji = "🐱" if "cat" in label.lower() else "🐶"
47
+ return f"{label.capitalize()} {emoji}", confidence * 100
48
+
49
+ except Exception as e:
50
+ raise RuntimeError(f"Lỗi khi dự đoán: {str(e)}")
51
+
52
+ # Chạy thử khi chạy trực tiếp
53
+ if __name__ == "__main__":
54
+ image_path = r"C:\Users\ADMIN\Desktop\Xulyanh2\data\test\cat\1359.jpg"
55
+ try:
56
+ result, confidence = predict_image(image_path)
57
+ print(f"Kết quả: {result} (độ tin cậy: {confidence:.2f}%)")
58
+ except Exception as e:
59
+ print(f"Lỗi: {e}")
train_resnet18.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.optim as optim
4
+ from tqdm import tqdm
5
+ from model_resnet18 import CatDogClassifier
6
+ from dataset_prep_resnet18 import get_data_loaders
7
+ import copy
8
+ import os
9
+ import json
10
+
11
+ # --- Cấu hình ---
12
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+ EPOCHS = 20
14
+ BATCH_SIZE = 32
15
+ LEARNING_RATE = 0.001
16
+ DATA_DIR = './data'
17
+ PATIENCE = 5 # cho EarlyStopping
18
+
19
+ def train():
20
+ # Tải dữ liệu
21
+ train_loader, val_loader, test_loader = get_data_loaders(data_dir=DATA_DIR, batch_size=BATCH_SIZE)
22
+
23
+ # Khởi tạo mô hình (ResNet18)
24
+ model = CatDogClassifier().to(DEVICE)
25
+
26
+ # Loss và Optimizer
27
+ criterion = nn.CrossEntropyLoss()
28
+ optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
29
+
30
+ # Lưu class_to_idx để dùng khi predict
31
+ class_to_idx = train_loader.dataset.class_to_idx # lấy từ ImageFolder
32
+ with open("class_to_idx.json", "w") as f:
33
+ json.dump(class_to_idx, f)
34
+
35
+ best_model_wts = copy.deepcopy(model.state_dict())
36
+ best_val_acc = 0.0
37
+ epochs_no_improve = 0
38
+
39
+ for epoch in range(EPOCHS):
40
+ print(f"\n📘 Epoch [{epoch+1}/{EPOCHS}]")
41
+ model.train()
42
+ running_loss = 0.0
43
+ correct, total = 0, 0
44
+
45
+ for images, labels in tqdm(train_loader, desc="Training"):
46
+ images, labels = images.to(DEVICE), labels.to(DEVICE)
47
+
48
+ optimizer.zero_grad()
49
+ outputs = model(images)
50
+ loss = criterion(outputs, labels)
51
+ loss.backward()
52
+ optimizer.step()
53
+
54
+ running_loss += loss.item()
55
+ _, predicted = torch.max(outputs, 1)
56
+ total += labels.size(0)
57
+ correct += (predicted == labels).sum().item()
58
+
59
+ train_acc = 100. * correct / total
60
+ avg_train_loss = running_loss / len(train_loader)
61
+
62
+ # --- Validation ---
63
+ model.eval()
64
+ val_loss = 0.0
65
+ val_correct = 0
66
+ val_total = 0
67
+ with torch.no_grad():
68
+ for images, labels in val_loader:
69
+ images, labels = images.to(DEVICE), labels.to(DEVICE)
70
+ outputs = model(images)
71
+ loss = criterion(outputs, labels)
72
+ val_loss += loss.item()
73
+ _, predicted = torch.max(outputs, 1)
74
+ val_total += labels.size(0)
75
+ val_correct += (predicted == labels).sum().item()
76
+
77
+ val_acc = 100. * val_correct / val_total
78
+ avg_val_loss = val_loss / len(val_loader)
79
+
80
+ print(f"✅ Train Acc: {train_acc:.2f}% | Loss: {avg_train_loss:.4f}")
81
+ print(f"🧪 Val Acc: {val_acc:.2f}% | Loss: {avg_val_loss:.4f}")
82
+
83
+ # --- ModelCheckpoint ---
84
+ if val_acc > best_val_acc:
85
+ best_val_acc = val_acc
86
+ best_model_wts = copy.deepcopy(model.state_dict())
87
+ torch.save(model.state_dict(), 'cat_dog_resnet18_ver2.pth')
88
+ print("💾 Đã lưu mô hình tốt nhất!")
89
+ epochs_no_improve = 0
90
+ else:
91
+ epochs_no_improve += 1
92
+ print(f"📌 Không cải thiện ({epochs_no_improve}/{PATIENCE})")
93
+
94
+ # --- EarlyStopping ---
95
+ if epochs_no_improve >= PATIENCE:
96
+ print("⏹️ Dừng sớm do không cải thiện validation accuracy.")
97
+ break
98
+
99
+ print(f"\n🎯 Huấn luyện hoàn tất. Val Acc tốt nhất: {best_val_acc:.2f}%")
100
+
101
+ # --- Test ---
102
+ model.load_state_dict(best_model_wts)
103
+ test_acc = evaluate(model, test_loader)
104
+ print(f"📊 Test Accuracy: {test_acc:.2f}%")
105
+
106
+ def evaluate(model, loader):
107
+ model.eval()
108
+ correct, total = 0, 0
109
+ with torch.no_grad():
110
+ for images, labels in loader:
111
+ images, labels = images.to(DEVICE), labels.to(DEVICE)
112
+ outputs = model(images)
113
+ _, predicted = torch.max(outputs, 1)
114
+ total += labels.size(0)
115
+ correct += (predicted == labels).sum().item()
116
+ return 100. * correct / total
117
+
118
+ if __name__ == '__main__':
119
+ train()