Commit 54f521b: Initial clean commit for Hugging Face

Files changed:
- .gitattributes +35 -0
- .gitignore +1 -0
- Dockerfile +27 -0
- app.py +239 -0
- readme.md +180 -0
- requirements.txt +11 -0
- resnet18_cat.pth +3 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1 @@
data/
Dockerfile
ADDED
@@ -0,0 +1,27 @@
FROM python:3.9-slim

WORKDIR /app

# Install system dependencies
RUN apt-get update && apt-get install -y \
    build-essential \
    curl \
    git \
    && rm -rf /var/lib/apt/lists/*

# Copy files
COPY requirements.txt .
COPY app.py .
COPY resnet18_cat.pth .

# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Expose the port Hugging Face Spaces expects for the Streamlit app
EXPOSE 7860

# Healthcheck (optional)
HEALTHCHECK CMD curl --fail http://localhost:7860/_stcore/health || exit 1

# Run the Streamlit app using app.py
CMD ["streamlit", "run", "app.py", "--server.port=7860", "--server.address=0.0.0.0"]
app.py
ADDED
@@ -0,0 +1,239 @@
import sys, types, os
sys.modules['torch.classes'] = types.SimpleNamespace()
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
os.environ["STREAMLIT_WATCH_SYSTEM_FOLDERS"] = "false"

import streamlit as st
import torch
import numpy as np
import pickle
from torchvision import models, transforms
from torchcam.methods import GradCAM
from torchcam.utils import overlay_mask
from torchvision.transforms.functional import to_pil_image
from sklearn.metrics.pairwise import cosine_distances
import matplotlib.pyplot as plt
import gzip
from huggingface_hub import hf_hub_download

st.set_page_config(layout="wide")

st.markdown("""
<style>
/* Widen content area */
.block-container {
    padding: 3rem 5rem;
    max-width: 95%;
}

/* Headings */
h1 {
    font-size: 2.5rem !important;
    margin-bottom: 0.75rem;
}
h2 {
    font-size: 2rem !important;
    margin-top: 2rem;
    margin-bottom: 0.5rem;
}
h3 {
    font-size: 1.5rem !important;
    margin-bottom: 0.5rem;
}

/* Paragraphs */
p, li {
    font-size: 1.15rem !important;
    line-height: 1.7;
    margin-bottom: 1rem;
    text-align: left;
}

/* Sidebar tweaks */
section[data-testid="stSidebar"] {
    font-size: 1rem !important;
}

/* Metric values */
[data-testid="stMetricValue"] {
    font-size: 1.8rem !important;
    color: #21c6b6;
}
</style>
""", unsafe_allow_html=True)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
CLASS_NAMES = ["fake", "real"]

@st.cache_resource
def load_model():
    model = models.resnet18(weights=None)
    model.fc = torch.nn.Linear(model.fc.in_features, 2)
    model.load_state_dict(torch.load("resnet18_cat.pth", map_location=device))
    model.to(device).eval()
    return model


@st.cache_data
def load_embeddings():
    # Download the file from your dataset repo
    path = hf_hub_download(
        repo_id="aryanj10/cat-xai-embeddings",
        filename="embeddings.pkl.gz",
        repo_type="dataset"
    )

    # Load gzip-compressed pickle file
    with gzip.open(path, "rb") as f:
        return pickle.load(f)


model = load_model()

# Register hook for feature extraction
feature_extractor = {}
def hook_fn(module, input, output):
    feature_extractor['features'] = output.flatten(1).detach()
model.avgpool.register_forward_hook(hook_fn)

cam_extractor = GradCAM(model, target_layer="layer4")
data = load_embeddings()

# Sidebar
st.sidebar.title("Options")
st.sidebar.markdown("---")
st.sidebar.markdown("**Try These Indices**")
st.sidebar.markdown("""
- Fake → Real: `13`, `18`, `34`, `40`
- Real → Fake: `57`, `77`, `80`
Use them to explore how the model makes errors, and how XAI helps analyze them.
""")
test_idx = st.sidebar.number_input("Select Test Index", min_value=0, max_value=len(data["test_images"])-1, value=57)
k_twins = st.sidebar.slider("Number of Twin Matches", 1, 10, value=3)

# Load and predict
query_tensor = data["test_images"][test_idx]
true_label = data["test_labels"][test_idx]
query_input = query_tensor.unsqueeze(0).to(device)
pred_label = model(query_input).argmax(1).item()

important_cases = {
    "Real → Fake": [57, 77, 80],
    "Fake → Real": [13, 18, 34, 40]
}
for label, indices in important_cases.items():
    if test_idx in indices:
        st.info(f"You're viewing a **{label}** misclassification test case.")


# Grad-CAM
cam = cam_extractor(pred_label, model(query_input))[0].detach().cpu()
img_display = query_tensor.cpu() * 0.5 + 0.5
img_display = torch.clamp(img_display, 0, 1)
orig_img = to_pil_image(img_display)
cam_overlay = overlay_mask(orig_img, to_pil_image(cam.squeeze(0), mode='F'), alpha=0.5)

# Grad-CAM visual
st.markdown("# Fake-vs-Real Cat Classifier Explanation")

st.markdown("""
## What This App Does

This demo uses **Explainable AI (XAI)** techniques to classify cat images as either **Real** or **Fake** (AI-generated).
It reveals not only the **prediction**, but also the **reasoning** behind the decision using two complementary explainability tools:

### Grad-CAM
Shows **where** the model is "looking" by highlighting important image regions (like eyes, ears, fur).

### Twin System
Explains **why** a decision was made by comparing the test image's internal embedding with similar training examples.
Think: _"This image is close to these, and the model predicted them like this."_

---

### Setup Summary

- **Model**: Fine-tuned `ResNet-18`
- **Classes**: `real` = 1, `fake` = 0
- **Dataset**: 300 images (150 real + 150 fake)
- **Validation Accuracy**: 91%
""")


st.markdown("## Grad-CAM Explanation")

st.markdown("""
Grad-CAM (**Gradient-weighted Class Activation Mapping**) visualizes **where** the model focuses when making a prediction.
It uses backpropagation through the final convolutional layer to produce a **heatmap** of important regions.

In the images below:
- **Left**: Original test image
- **Right**: Grad-CAM overlay (bright areas = higher model attention)

This helps you answer:

> _"What part of the image led the model to this decision?"_
""")
st.subheader("Grad-CAM Visualization")

cols = st.columns(2)
cols[0].image(orig_img, caption="Original Image", use_container_width=True)
cols[1].image(cam_overlay, caption=f"Grad-CAM (Pred: {CLASS_NAMES[pred_label]})", use_container_width=True)

# Twin System
st.markdown("## Twin System Explanation")

st.markdown("""
The **Twin System** provides an example-based explanation by retrieving **visually similar training images**.

### How it Works:
- Embeddings from the `avgpool` layer are compared using **cosine similarity**
- The top `k` most similar training images are shown
- We display both the **true label** and **model prediction** for each

This helps you answer:

> _"What similar examples in the training set justify the model's current prediction?"_
""")
st.subheader("Twin System Visualization")
query_emb = data["test_embeddings"][test_idx].reshape(1, -1)
distances = cosine_distances(query_emb, data["train_embeddings"])[0]
same_class_idxs = np.where(data["train_labels"] == pred_label)[0]
nearest_same = same_class_idxs[np.argsort(distances[same_class_idxs])[:k_twins]]

twin_panel = [("Query", query_tensor, true_label)] + [
    ("Twin", data["train_images"][i], data["train_labels"][i]) for i in nearest_same
]

row1, row2 = st.columns(len(twin_panel)), st.columns(len(twin_panel))
for i, (title, img_tensor, label) in enumerate(twin_panel):
    img_tensor_input = img_tensor.unsqueeze(0).to(device).requires_grad_()
    pred = model(img_tensor_input).argmax(1).item()

    cam = cam_extractor(pred, model(img_tensor_input))[0].detach().cpu()
    img_display = img_tensor.cpu() * 0.5 + 0.5
    img_display = torch.clamp(img_display, 0, 1)
    orig = to_pil_image(img_display)
    cam_overlay = overlay_mask(orig, to_pil_image(cam.squeeze(0), mode='F'), alpha=0.5)

    row1[i].image(orig, caption=f"{title}\nLabel: {CLASS_NAMES[label]}", use_container_width=True)
    row2[i].image(cam_overlay, caption=f"Grad-CAM\nPred: {CLASS_NAMES[pred]}", use_container_width=True)

cam_extractor._hooks_enabled = False

st.markdown("---")
st.markdown("### Model Performance")
col1, col2 = st.columns(2)
with col1:
    st.metric("Accuracy", "91%")
    st.metric("F1 Score", "0.91")
with col2:
    st.metric("Precision (Real)", "0.94")
    st.metric("Recall (Real)", "0.88")
    st.metric("Precision (Fake)", "0.89")
    st.metric("Recall (Fake)", "0.94")
readme.md
ADDED
@@ -0,0 +1,180 @@
# Explainable AI for Image Classification using Twin System and Grad-CAM

**Author:** Aryan Jain
**Date:** May 17, 2025

---

## Introduction

Explainable Artificial Intelligence (XAI) is crucial in making machine learning models interpretable and trustworthy, especially in high-stakes or opaque domains. In this project, I applied two complementary post-hoc XAI techniques to a binary image classification problem: distinguishing between **real** and **AI-generated** (fake) cat images.

The objective was two-fold:

1. Build a strong classifier using a fine-tuned ResNet-18 model.
2. Reveal the "why" behind its predictions using visual and example-based explanations.

---

## XAI Techniques Used

### Grad-CAM (Gradient-weighted Class Activation Mapping)

Grad-CAM highlights which parts of an image are most influential for the model's prediction. It works by backpropagating gradients to the final convolutional layers to generate heatmaps. These maps help users verify whether the model is focusing on the correct regions (e.g., ears, eyes, textures) or is being misled.

**Why use it?**
To visualize the evidence used by the network for classification decisions.

[Grad-CAM Paper](https://arxiv.org/abs/1610.02391)
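
For concreteness, the sketch below shows how such an overlay can be produced with the TorchCAM library, mirroring the setup used in `app.py` (`layer4` as the target layer, mean/std 0.5 normalization). The image path is a placeholder, and this is a sketch rather than the committed code:

```python
import torch
from PIL import Image
from torchvision import models, transforms
from torchvision.transforms.functional import to_pil_image
from torchcam.methods import GradCAM
from torchcam.utils import overlay_mask

# Rebuild the fine-tuned classifier (2 output classes: fake = 0, real = 1)
model = models.resnet18(weights=None)
model.fc = torch.nn.Linear(model.fc.in_features, 2)
model.load_state_dict(torch.load("resnet18_cat.pth", map_location="cpu"))
model.eval()

# Hook the last convolutional block for class activation maps
cam_extractor = GradCAM(model, target_layer="layer4")

# Preprocess an image the same way as during training (224x224, mean/std 0.5)
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),
])
img = Image.open("some_cat.jpg").convert("RGB")  # placeholder path
x = preprocess(img).unsqueeze(0)

# Forward pass, then extract the CAM for the predicted class
scores = model(x)
pred = scores.argmax(1).item()
cam = cam_extractor(pred, scores)[0].squeeze(0).detach()

# Overlay the heatmap on the original image and save it
heatmap = overlay_mask(img.resize((224, 224)), to_pil_image(cam, mode="F"), alpha=0.5)
heatmap.save("gradcam_overlay.png")
```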

---

### Twin System (Case-Based Reasoning with Embeddings)

The Twin System explains model predictions by retrieving visually similar training samples with the same predicted class. It computes cosine similarity between embeddings from the penultimate layer (`avgpool`) and shows the "nearest neighbors" from the training set.

**Why use it?**
To provide intuitive, example-based justification, similar to how a radiologist compares a scan with past known cases.

Based on: [This Looks Like That (2018)](https://arxiv.org/abs/1806.10574)
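
In `app.py` this is implemented with a forward hook on `avgpool` plus scikit-learn's `cosine_distances`. Below is a condensed, self-contained sketch of that retrieval logic; the helper names (`avgpool_embedding`, `find_twins`) are illustrative, not functions from the repo:

```python
import numpy as np
import torch
from sklearn.metrics.pairwise import cosine_distances

def avgpool_embedding(model, x):
    """Capture the flattened avgpool activation for a batch and return it with the prediction."""
    captured = {}
    def hook(module, inputs, output):
        captured["feat"] = output.flatten(1).detach().cpu().numpy()
    handle = model.avgpool.register_forward_hook(hook)
    with torch.no_grad():
        logits = model(x)
    handle.remove()
    return captured["feat"], logits.argmax(1).item()

def find_twins(query_emb, pred_label, train_embeddings, train_labels, k=3):
    """Indices of the k nearest training images that share the predicted class."""
    dists = cosine_distances(query_emb.reshape(1, -1), train_embeddings)[0]
    same_class = np.where(train_labels == pred_label)[0]
    return same_class[np.argsort(dists[same_class])[:k]]
```

With `k=3` this matches the default value of the "Number of Twin Matches" slider in the app.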

---

## Dataset

- **Total Images:** 300
- 150 Real cat images from public datasets
- 150 Fake images generated using [`google/ddpm-cat-256`](https://arxiv.org/abs/2006.11239)
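
For reference, generating images with that checkpoint via the Hugging Face `diffusers` library looks roughly like the sketch below; `diffusers` is not listed in this repo's `requirements.txt`, so the generation step presumably ran outside this Space:

```python
from diffusers import DDPMPipeline

# Unconditional DDPM trained on 256x256 cat images
pipe = DDPMPipeline.from_pretrained("google/ddpm-cat-256")

# Sample a small batch of synthetic cats and save them to disk
images = pipe(batch_size=4).images
for i, img in enumerate(images):
    img.save(f"fake_cat_{i}.png")
```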

### Preprocessing

- Resized to `224x224`
- Normalized using mean and std = 0.5 (CIFAR-style)
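
Expressed as a torchvision pipeline, this preprocessing amounts to the following (a sketch; the original training script is not part of this commit):

```python
from torchvision import transforms

# Resize to 224x224, then normalize every channel with mean 0.5 and std 0.5
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
```

This normalization is also why `app.py` de-normalizes with `tensor * 0.5 + 0.5` before displaying images.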

### Train/Validation Split

- **Training:** 100 real + 100 fake
- **Validation:** 50 real + 50 fake

---

## Model Architecture

- **Backbone:** Pretrained ResNet-18
- Final layer modified for 2-class output (real vs. fake)

### Training Configuration

- Optimizer: Adam (`lr=1e-4`)
- Loss Function: CrossEntropyLoss
- Epochs: 10
- Batch Size: 32
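
A rough sketch of the fine-tuning loop this configuration implies is shown below; the training script itself is not part of this commit, and the `data/train` `ImageFolder` layout is an assumption (the `data/` folder is gitignored here):

```python
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms

# Preprocessing as described above
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),
])

# "data/train" is a placeholder path with fake/ and real/ subfolders
train_ds = datasets.ImageFolder("data/train", transform=transform)
train_dl = DataLoader(train_ds, batch_size=32, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)  # pretrained backbone
model.fc = nn.Linear(model.fc.in_features, 2)                           # real-vs-fake head
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

for epoch in range(10):
    model.train()
    for images, labels in train_dl:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()

torch.save(model.state_dict(), "resnet18_cat.pth")
```

With `fake/` and `real/` subfolders, `ImageFolder`'s alphabetical class ordering yields `fake = 0` and `real = 1`, which matches `CLASS_NAMES` in `app.py`.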

**Final Validation Accuracy:** 91%

---

### Evaluation Metrics

- **Accuracy:** 91%
- **Precision (Real):** 0.94, **Recall (Real):** 0.88
- **Precision (Fake):** 0.89, **Recall (Fake):** 0.94
- **F1 Score (Both):** 0.91

---

## Experiments

### Grad-CAM Visualizations

Visual saliency maps for real and fake image samples:

*(six Grad-CAM overlay figures)*

These overlays revealed that the model often attended to fur patterns, eye shapes, and ear positions for its predictions.

---

### Twin System Visualizations

Twin explanations retrieved training examples similar to a test image **from the same predicted class**, validating the classifier's reasoning.

*(six Twin System comparison figures)*

---

### Misclassification Analysis

- **Real → Fake (False Negatives):** `[13, 18, 22, 34, 40, 44]`
- **Fake → Real (False Positives):** `[57, 77, 80]`

Grad-CAM and Twin System helped investigate these edge cases. In many cases, blur or unusual poses in real images confused the classifier.

---

## Conclusion

By combining Grad-CAM with the Twin System, this project achieved a richer interpretability framework:

| Technique | Purpose | Value |
|----------|---------|-------|
| **Grad-CAM** | Pixel-level explanation | Shows *where* the model is looking |
| **Twin System** | Example-based reasoning | Shows *why* via similarity to past cases |

This multi-view approach fosters transparency and trust in AI-powered image classifiers.

---

## Future Work

- Introduce **counterfactual explanations** (e.g., nearest neighbors from the opposite class; see the sketch below)
- Replace cosine similarity with **CLIP embeddings** for semantic similarity
- Improve Twin System with a **ProtoPNet architecture**
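
The first item could reuse the existing embedding index almost unchanged: restrict the neighbor search to the class the model did *not* predict, instead of the predicted one. A hedged sketch (the function name is illustrative):

```python
import numpy as np
from sklearn.metrics.pairwise import cosine_distances

def find_counterfactuals(query_emb, pred_label, train_embeddings, train_labels, k=3):
    """Nearest training images from the class the model did NOT predict."""
    dists = cosine_distances(query_emb.reshape(1, -1), train_embeddings)[0]
    other_class = np.where(train_labels != pred_label)[0]
    return other_class[np.argsort(dists[other_class])[:k]]
```

Shown next to the twins, these would answer: what does the most similar image of the other class look like?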

---

### ProtoPNet Attempt

- Architecture: ResNet-18 backbone with 10 learned prototypes per class
- Goal: Learn localized regions (patches) that support classification

#### Results

- **Validation Accuracy:** 50%
- **Problem:** Overfitted to "real" class due to prototype imbalance

*(results figure)*

#### Learned Prototypes

*(prototype visualization figure)*

Despite underperformance, I successfully:

- Trained the ProtoPNet architecture
- Projected prototypes to most activating image patches
- Visualized top-activating examples for each prototype

Future work will address class imbalance and refine prototype usefulness.

---

## References

- [Grad-CAM: Visual Explanations from Deep Networks](https://arxiv.org/abs/1610.02391) - Selvaraju et al., ICCV 2017
- [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) - Ho et al., NeurIPS 2020
- [TorchCAM Library](https://frgfm.github.io/torch-cam/)
- [This Looks Like That](https://arxiv.org/abs/1806.10574) - Chen et al., 2018
- [Case-Based Interpretable DL for Mammography](https://www.nature.com/articles/s42256-021-00400-0) - Barnett et al., Nature Machine Intelligence, 2021
requirements.txt
ADDED
@@ -0,0 +1,11 @@
torch>=2.0.0
torchvision>=0.15.0
tqdm
scikit-learn
matplotlib
seaborn
torchcam
Pillow
numpy
streamlit
huggingface_hub
resnet18_cat.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8d751e04dbe62ec58b9559d5c35b2002b911b6b64814f3d0c8e522265efc332f
size 44790027