Spaces:
Sleeping
Sleeping
update_ver2
Browse files- app.py +44 -0
- cat_dog_resnet18.pth +3 -0
- dataset_prep_resnet18.py +65 -0
- model_resnet18.py +21 -0
- predict_resnet18.py +59 -0
- train_resnet18.py +119 -0
app.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Ứng dụng web phân loại ảnh Chó và Mèo sử dụng Streamlit
|
3 |
+
--------------------------------------------------
|
4 |
+
Ứng dụng này cung cấp giao diện web thân thiện để phân loại ảnh chó và mèo
|
5 |
+
sử dụng mô hình ResNet18 đã được huấn luyện. Người dùng có thể tải ảnh lên,
|
6 |
+
và ứng dụng sẽ đưa ra dự đoán kèm theo độ tin cậy của kết quả.
|
7 |
+
|
8 |
+
Chức năng chính:
|
9 |
+
- Tải lên ảnh từ máy của người dùng
|
10 |
+
- Hiển thị ảnh đã tải lên
|
11 |
+
- Sử dụng mô hình ResNet18 để phân loại ảnh
|
12 |
+
- Hiển thị kết quả phân loại (Chó/Mèo) và độ tin cậy của dự đoán
|
13 |
+
"""
|
14 |
+
import streamlit as st
|
15 |
+
from predict_resnet18 import predict_image
|
16 |
+
import tempfile
|
17 |
+
from PIL import Image
|
18 |
+
|
19 |
+
st.set_page_config(page_title="Phân loại Chó/Mèo", layout="centered")
|
20 |
+
st.title("🐾 Ứng dụng phân loại ảnh Chó và Mèo")
|
21 |
+
st.write("Tải ảnh lên để xem dự đoán mô hình phân loại là **Chó** hay **Mèo** 🐶🐱")
|
22 |
+
|
23 |
+
# Upload ảnh
|
24 |
+
uploaded_file = st.file_uploader("📷 Tải ảnh lên", type=["jpg", "jpeg", "png"])
|
25 |
+
|
26 |
+
if uploaded_file is not None:
|
27 |
+
try:
|
28 |
+
# Hiển thị ảnh
|
29 |
+
image = Image.open(uploaded_file).convert("RGB")
|
30 |
+
st.image(image, caption="Ảnh đã tải lên", use_container_width=True)
|
31 |
+
|
32 |
+
# Lưu tạm ảnh để truyền đường dẫn vào hàm predict_image
|
33 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp_file:
|
34 |
+
image.save(tmp_file.name)
|
35 |
+
tmp_path = tmp_file.name
|
36 |
+
|
37 |
+
# Dự đoán
|
38 |
+
with st.spinner("🔍 Đang phân tích..."):
|
39 |
+
result, confidence = predict_image(tmp_path)
|
40 |
+
st.success(f"✅ Kết quả: **{result}**")
|
41 |
+
st.info(f"🔒 Độ tin cậy: **{confidence:.2f}%**")
|
42 |
+
|
43 |
+
except Exception as e:
|
44 |
+
st.error(f"❌ Lỗi: {e}")
|
cat_dog_resnet18.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7d1a6e50bc25ade87fcd24fafcf3d5980069ae27747cd31e289740665ecd469b
|
3 |
+
size 44792236
|
dataset_prep_resnet18.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torchvision import transforms, datasets
|
3 |
+
from torch.utils.data import DataLoader
|
4 |
+
from PIL import Image
|
5 |
+
import os
|
6 |
+
|
7 |
+
# Hàm kiểm tra ảnh lỗi
|
8 |
+
def is_valid_image(filepath):
|
9 |
+
try:
|
10 |
+
with Image.open(filepath) as img:
|
11 |
+
img.verify()
|
12 |
+
img = Image.open(filepath).convert('RGB') # thử load RGB luôn
|
13 |
+
return True
|
14 |
+
except:
|
15 |
+
print(f"[!] Ảnh lỗi hoặc không hợp lệ: {filepath}")
|
16 |
+
return False
|
17 |
+
|
18 |
+
# Hàm dọn dữ liệu lỗi trong thư mục
|
19 |
+
def clean_dataset(directory):
|
20 |
+
for class_dir in os.listdir(directory):
|
21 |
+
class_path = os.path.join(directory, class_dir)
|
22 |
+
if os.path.isdir(class_path):
|
23 |
+
for img_name in os.listdir(class_path):
|
24 |
+
img_path = os.path.join(class_path, img_name)
|
25 |
+
if not is_valid_image(img_path):
|
26 |
+
os.remove(img_path)
|
27 |
+
|
28 |
+
# Gọi dọn ảnh lỗi trước khi tạo dataset
|
29 |
+
def get_data_loaders(data_dir='./data', batch_size=32):
|
30 |
+
print("🧹 Đang kiểm tra và loại bỏ ảnh lỗi...")
|
31 |
+
clean_dataset(os.path.join(data_dir, 'train'))
|
32 |
+
clean_dataset(os.path.join(data_dir, 'val'))
|
33 |
+
clean_dataset(os.path.join(data_dir, 'test'))
|
34 |
+
|
35 |
+
# Transform đúng chuẩn cho ResNet
|
36 |
+
train_transform = transforms.Compose([
|
37 |
+
transforms.RandomResizedCrop(224),
|
38 |
+
transforms.RandomHorizontalFlip(p=0.5),
|
39 |
+
transforms.RandomRotation(15),
|
40 |
+
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
|
41 |
+
transforms.ToTensor(),
|
42 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
43 |
+
std=[0.229, 0.224, 0.225])
|
44 |
+
])
|
45 |
+
|
46 |
+
val_transform = transforms.Compose([
|
47 |
+
transforms.Resize(256),
|
48 |
+
transforms.CenterCrop(224),
|
49 |
+
transforms.ToTensor(),
|
50 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
51 |
+
std=[0.229, 0.224, 0.225])
|
52 |
+
])
|
53 |
+
|
54 |
+
train_dataset = datasets.ImageFolder(root=os.path.join(data_dir, 'train'), transform=train_transform)
|
55 |
+
val_dataset = datasets.ImageFolder(root=os.path.join(data_dir, 'val'), transform=val_transform)
|
56 |
+
test_dataset = datasets.ImageFolder(root=os.path.join(data_dir, 'test'), transform=val_transform)
|
57 |
+
|
58 |
+
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
|
59 |
+
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
|
60 |
+
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
|
61 |
+
|
62 |
+
print("📂 Nhãn lớp:", train_dataset.classes)
|
63 |
+
print(f"🖼️ Số lượng ảnh: train = {len(train_dataset)}, val = {len(val_dataset)}, test = {len(test_dataset)}")
|
64 |
+
|
65 |
+
return train_loader, val_loader, test_loader
|
model_resnet18.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
from torchvision.models import resnet18, ResNet18_Weights
|
3 |
+
|
4 |
+
class CatDogClassifier(nn.Module):
|
5 |
+
def __init__(self):
|
6 |
+
super(CatDogClassifier, self).__init__()
|
7 |
+
|
8 |
+
# Sử dụng pretrained weights chuẩn (ImageNet)
|
9 |
+
weights = ResNet18_Weights.DEFAULT
|
10 |
+
self.base_model = resnet18(weights=weights)
|
11 |
+
|
12 |
+
# Đóng băng toàn bộ layer (chỉ fine-tune fc layer)
|
13 |
+
for param in self.base_model.parameters():
|
14 |
+
param.requires_grad = False
|
15 |
+
|
16 |
+
# Thay thế lớp fully connected cuối bằng lớp phân loại 2 lớp
|
17 |
+
num_ftrs = self.base_model.fc.in_features
|
18 |
+
self.base_model.fc = nn.Linear(num_ftrs, 2)
|
19 |
+
|
20 |
+
def forward(self, x):
|
21 |
+
return self.base_model(x)
|
predict_resnet18.py
ADDED
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torchvision import transforms
|
3 |
+
from PIL import Image
|
4 |
+
from model_resnet18 import CatDogClassifier # dùng ResNet18
|
5 |
+
import json
|
6 |
+
import os
|
7 |
+
|
8 |
+
# Định nghĩa transform giống như lúc huấn luyện với ResNet18
|
9 |
+
transform = transforms.Compose([
|
10 |
+
transforms.Resize(256),
|
11 |
+
transforms.CenterCrop(224),
|
12 |
+
transforms.ToTensor(),
|
13 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
14 |
+
std=[0.229, 0.224, 0.225])
|
15 |
+
])
|
16 |
+
|
17 |
+
# Hàm dự đoán ảnh
|
18 |
+
def predict_image(image_path):
|
19 |
+
try:
|
20 |
+
# Kiểm tra tệp tồn tại
|
21 |
+
if not os.path.exists(image_path):
|
22 |
+
raise FileNotFoundError(f"Không tìm thấy ảnh: {image_path}")
|
23 |
+
|
24 |
+
# Tải model ResNet18
|
25 |
+
model = CatDogClassifier()
|
26 |
+
model.load_state_dict(torch.load("cat_dog_resnet18.pth", map_location=torch.device("cpu")))
|
27 |
+
model.eval()
|
28 |
+
|
29 |
+
# Load class_to_idx
|
30 |
+
with open("class_to_idx.json", "r") as f:
|
31 |
+
class_to_idx = json.load(f)
|
32 |
+
idx_to_class = {v: k for k, v in class_to_idx.items()}
|
33 |
+
|
34 |
+
# Xử lý ảnh
|
35 |
+
image = Image.open(image_path).convert("RGB")
|
36 |
+
image_tensor = transform(image).unsqueeze(0) # [1, 3, 224, 224]
|
37 |
+
|
38 |
+
# Dự đoán
|
39 |
+
with torch.no_grad():
|
40 |
+
outputs = model(image_tensor)
|
41 |
+
probs = torch.nn.functional.softmax(outputs, dim=1)
|
42 |
+
_, pred = torch.max(probs, 1)
|
43 |
+
confidence = probs[0][pred.item()].item()
|
44 |
+
|
45 |
+
label = idx_to_class[pred.item()]
|
46 |
+
emoji = "🐱" if "cat" in label.lower() else "🐶"
|
47 |
+
return f"{label.capitalize()} {emoji}", confidence * 100
|
48 |
+
|
49 |
+
except Exception as e:
|
50 |
+
raise RuntimeError(f"Lỗi khi dự đoán: {str(e)}")
|
51 |
+
|
52 |
+
# Chạy thử khi chạy trực tiếp
|
53 |
+
if __name__ == "__main__":
|
54 |
+
image_path = r"C:\Users\ADMIN\Desktop\Xulyanh2\data\test\cat\1359.jpg"
|
55 |
+
try:
|
56 |
+
result, confidence = predict_image(image_path)
|
57 |
+
print(f"Kết quả: {result} (độ tin cậy: {confidence:.2f}%)")
|
58 |
+
except Exception as e:
|
59 |
+
print(f"Lỗi: {e}")
|
train_resnet18.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.optim as optim
|
4 |
+
from tqdm import tqdm
|
5 |
+
from model_resnet18 import CatDogClassifier
|
6 |
+
from dataset_prep_resnet18 import get_data_loaders
|
7 |
+
import copy
|
8 |
+
import os
|
9 |
+
import json
|
10 |
+
|
11 |
+
# --- Cấu hình ---
|
12 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
13 |
+
EPOCHS = 20
|
14 |
+
BATCH_SIZE = 32
|
15 |
+
LEARNING_RATE = 0.001
|
16 |
+
DATA_DIR = './data'
|
17 |
+
PATIENCE = 5 # cho EarlyStopping
|
18 |
+
|
19 |
+
def train():
|
20 |
+
# Tải dữ liệu
|
21 |
+
train_loader, val_loader, test_loader = get_data_loaders(data_dir=DATA_DIR, batch_size=BATCH_SIZE)
|
22 |
+
|
23 |
+
# Khởi tạo mô hình (ResNet18)
|
24 |
+
model = CatDogClassifier().to(DEVICE)
|
25 |
+
|
26 |
+
# Loss và Optimizer
|
27 |
+
criterion = nn.CrossEntropyLoss()
|
28 |
+
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
|
29 |
+
|
30 |
+
# Lưu class_to_idx để dùng khi predict
|
31 |
+
class_to_idx = train_loader.dataset.class_to_idx # lấy từ ImageFolder
|
32 |
+
with open("class_to_idx.json", "w") as f:
|
33 |
+
json.dump(class_to_idx, f)
|
34 |
+
|
35 |
+
best_model_wts = copy.deepcopy(model.state_dict())
|
36 |
+
best_val_acc = 0.0
|
37 |
+
epochs_no_improve = 0
|
38 |
+
|
39 |
+
for epoch in range(EPOCHS):
|
40 |
+
print(f"\n📘 Epoch [{epoch+1}/{EPOCHS}]")
|
41 |
+
model.train()
|
42 |
+
running_loss = 0.0
|
43 |
+
correct, total = 0, 0
|
44 |
+
|
45 |
+
for images, labels in tqdm(train_loader, desc="Training"):
|
46 |
+
images, labels = images.to(DEVICE), labels.to(DEVICE)
|
47 |
+
|
48 |
+
optimizer.zero_grad()
|
49 |
+
outputs = model(images)
|
50 |
+
loss = criterion(outputs, labels)
|
51 |
+
loss.backward()
|
52 |
+
optimizer.step()
|
53 |
+
|
54 |
+
running_loss += loss.item()
|
55 |
+
_, predicted = torch.max(outputs, 1)
|
56 |
+
total += labels.size(0)
|
57 |
+
correct += (predicted == labels).sum().item()
|
58 |
+
|
59 |
+
train_acc = 100. * correct / total
|
60 |
+
avg_train_loss = running_loss / len(train_loader)
|
61 |
+
|
62 |
+
# --- Validation ---
|
63 |
+
model.eval()
|
64 |
+
val_loss = 0.0
|
65 |
+
val_correct = 0
|
66 |
+
val_total = 0
|
67 |
+
with torch.no_grad():
|
68 |
+
for images, labels in val_loader:
|
69 |
+
images, labels = images.to(DEVICE), labels.to(DEVICE)
|
70 |
+
outputs = model(images)
|
71 |
+
loss = criterion(outputs, labels)
|
72 |
+
val_loss += loss.item()
|
73 |
+
_, predicted = torch.max(outputs, 1)
|
74 |
+
val_total += labels.size(0)
|
75 |
+
val_correct += (predicted == labels).sum().item()
|
76 |
+
|
77 |
+
val_acc = 100. * val_correct / val_total
|
78 |
+
avg_val_loss = val_loss / len(val_loader)
|
79 |
+
|
80 |
+
print(f"✅ Train Acc: {train_acc:.2f}% | Loss: {avg_train_loss:.4f}")
|
81 |
+
print(f"🧪 Val Acc: {val_acc:.2f}% | Loss: {avg_val_loss:.4f}")
|
82 |
+
|
83 |
+
# --- ModelCheckpoint ---
|
84 |
+
if val_acc > best_val_acc:
|
85 |
+
best_val_acc = val_acc
|
86 |
+
best_model_wts = copy.deepcopy(model.state_dict())
|
87 |
+
torch.save(model.state_dict(), 'cat_dog_resnet18_ver2.pth')
|
88 |
+
print("💾 Đã lưu mô hình tốt nhất!")
|
89 |
+
epochs_no_improve = 0
|
90 |
+
else:
|
91 |
+
epochs_no_improve += 1
|
92 |
+
print(f"📌 Không cải thiện ({epochs_no_improve}/{PATIENCE})")
|
93 |
+
|
94 |
+
# --- EarlyStopping ---
|
95 |
+
if epochs_no_improve >= PATIENCE:
|
96 |
+
print("⏹️ Dừng sớm do không cải thiện validation accuracy.")
|
97 |
+
break
|
98 |
+
|
99 |
+
print(f"\n🎯 Huấn luyện hoàn tất. Val Acc tốt nhất: {best_val_acc:.2f}%")
|
100 |
+
|
101 |
+
# --- Test ---
|
102 |
+
model.load_state_dict(best_model_wts)
|
103 |
+
test_acc = evaluate(model, test_loader)
|
104 |
+
print(f"📊 Test Accuracy: {test_acc:.2f}%")
|
105 |
+
|
106 |
+
def evaluate(model, loader):
|
107 |
+
model.eval()
|
108 |
+
correct, total = 0, 0
|
109 |
+
with torch.no_grad():
|
110 |
+
for images, labels in loader:
|
111 |
+
images, labels = images.to(DEVICE), labels.to(DEVICE)
|
112 |
+
outputs = model(images)
|
113 |
+
_, predicted = torch.max(outputs, 1)
|
114 |
+
total += labels.size(0)
|
115 |
+
correct += (predicted == labels).sum().item()
|
116 |
+
return 100. * correct / total
|
117 |
+
|
118 |
+
if __name__ == '__main__':
|
119 |
+
train()
|