# Source: file page by taesiri, commit 9313bc1 ("initial commit"), 11 kB.
import json
import os
import pickle
import random
import time
from collections import Counter
from datetime import datetime
from glob import glob
import gdown
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import streamlit as st
from PIL import Image
import SessionState
from download_utils import *
from image_utils import *
# Seed both RNGs from the current time so every app session samples a fresh
# set of query images. Since Python 3.11, random.seed() only accepts None,
# int, float, str, bytes or bytearray (passing a datetime raises TypeError),
# so seed it with the POSIX timestamp instead of the datetime object.
random.seed(datetime.now().timestamp())
np.random.seed(int(time.time()))
# ---------------------------------------------------------------------------
# Experiment settings
# ---------------------------------------------------------------------------
NUMBER_OF_TRIALS = 20  # number of query images drawn per session
CLASSIFIER_TAG = "CHM"  # prefix of the per-query keys read from the predictions pickle

# Explanation loaders (names come from the project's star imports above).
explaination_functions = [load_chm_nns, load_knn_nns]
# Loader currently in use; None means no explanation image is shown.
selected_xai_tool = None

# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
# Empty placeholders; the real contents are loaded later in this module.
folder_to_name, class_descriptions, classifier_predictions = {}, {}, {}

selected_dataset = "Final"
root_visualization_dir = "./visualizations/"

# Remote archives / files handed to download_files() by get_data().
viz_url = "https://static.taesiri.com/xai/Final.zip"
viz_archivefile = "Final.zip"
demonstration_url = "https://static.taesiri.com/xai/demonstrations.zip"
demonst_zipfile = "demonstrations.zip"
picklefile_url = "https://static.taesiri.com/xai/Task1_Results_CHM_and_EMD.pickle"

prediction_root = "./predictions/"
prediction_pickle = prediction_root + "predictions.pickle"
################################################
# GLOBAL VARIABLES
app_mode = ""
## Shared/Global Information
# Maps the predicted class id (wnid) to the display name shown to the user.
with open("imagenet-labels.json", "rb") as f:
    folder_to_name = json.load(f)

# gloss.txt holds one "<wnid>\t<definition>" record per line. Strip the line
# ending so the definition renders cleanly inside inline code, and skip lines
# without a tab (e.g. a trailing blank line) instead of raising IndexError.
with open("gloss.txt", "r", encoding="utf-8") as f:
    description_file = f.readlines()
class_descriptions = {
    fields[0]: fields[1]
    for fields in (line.rstrip("\r\n").split("\t") for line in description_file)
    if len(fields) > 1
}
################################################
# NOTE(review): get_data() passes picklefile_url/prediction_pickle to
# download_files(), but it only runs inside main(), i.e. after this import-time
# load. If the pickle is not shipped with the app, a fresh checkout would hit
# FileNotFoundError here; confirm, or move this load after get_data().
with open(prediction_pickle, "rb") as f:
    # pickle.load executes arbitrary code; only load this self-hosted file.
    classifier_predictions = pickle.load(f)
# SESSION STATE
# Initial values handed to the project's SessionState helper (presumably kept
# per browser session across Streamlit reruns; confirm in SessionState).
_session_defaults = dict(
    page=1,  # 1-based index of the trial being shown
    first_run=1,  # 1 => resmaple_queries() must draw a new sample
    user_feedback={},  # query file name -> radio answer
    queries=[],  # sampled query image paths
    is_classifier_correct={},  # query file name -> "<TAG>-Output" value
    XAI_tool="Unselected",  # explanation method chosen on the landing page
)
session_state = SessionState.get(**_session_defaults)
################################################
def get_data():
    """Hand the module-level asset URLs and local paths to ``download_files``.

    Thin wrapper: the arguments are forwarded positionally, in the order
    ``download_files`` (from the project's download utilities) expects.
    What gets downloaded or skipped is decided by that helper.
    """
    download_args = (
        root_visualization_dir,
        viz_url,
        viz_archivefile,
        demonstration_url,
        demonst_zipfile,
        picklefile_url,
        prediction_root,
        prediction_pickle,
    )
    download_files(*download_args)
def _sample_folder(folder, count):
    """Return ``count`` distinct ``*.JPEG`` paths picked at random from ``folder``.

    Raises:
        ValueError: if ``folder`` holds fewer than ``count`` matching images.
    """
    candidates = glob(os.path.join(folder, "*.JPEG"))
    if count > len(candidates):
        # np.random.choice(replace=False) would also raise, but without
        # saying which folder is short.
        raise ValueError(
            f"Need {count} images from {folder!r}, found only {len(candidates)}."
        )
    return list(np.random.choice(a=candidates, size=count, replace=False))


def resmaple_queries():
    """Draw a fresh, shuffled sample of query images if a resample is pending.

    Does nothing unless ``session_state.first_run == 1``. Otherwise it draws
    ``NUMBER_OF_TRIALS // 2`` images from ``Both_correct`` and the remainder
    from ``Both_wrong`` (each without replacement). It shuffles them into
    ``session_state.queries``, sets ``first_run`` to -1, and clears the
    answers recorded for the previous sample.

    Raises:
        ValueError: if either folder has too few images to draw from.
    """
    if session_state.first_run != 1:
        return

    dataset_dir = os.path.join(root_visualization_dir, selected_dataset)
    n_correct = NUMBER_OF_TRIALS // 2
    # Take the remainder from the second folder so an odd NUMBER_OF_TRIALS
    # still yields exactly NUMBER_OF_TRIALS queries. render_menu pages up to
    # NUMBER_OF_TRIALS and indexes queries with page - 1, so a shorter list
    # would raise IndexError on the last page.
    n_wrong = NUMBER_OF_TRIALS - n_correct

    correct_samples = _sample_folder(os.path.join(dataset_dir, "Both_correct"), n_correct)
    wrong_samples = _sample_folder(os.path.join(dataset_dir, "Both_wrong"), n_wrong)

    all_images = correct_samples + wrong_samples
    random.shuffle(all_images)
    session_state.queries = all_images
    session_state.first_run = -1
    # RESET INTERACTIONS: answers belong to the previous sample.
    session_state.user_feedback = {}
    session_state.is_classifier_correct = {}
def render_experiment(query):
    """Render one trial of the study for ``session_state.queries[query]``.

    Sections, top to bottom:
      1. an expander with the predicted class's name, definition and its
         ``demonstrations/<wnid>.jpeg`` image;
      2. an expander with the query image, a radio where the user judges the
         model's prediction, and the predicted label and confidence;
      3. the explanation image from ``selected_xai_tool``, when one is set;
      4. a debug button that shows the raw query image.

    Side effects: stores the pickle's ``"<CLASSIFIER_TAG>-Output"`` value in
    ``session_state.is_classifier_correct`` and the radio answer in
    ``session_state.user_feedback``, both keyed by the query's file name.

    Args:
        query: 0-based index into ``session_state.queries``.
    """
    query_path = session_state.queries[query]
    query_id = os.path.basename(query_path)
    record = classifier_predictions[query_id]

    predicted_wnid = record[f"{CLASSIFIER_TAG}-predictions"]
    prediction_confidence = record[f"{CLASSIFIER_TAG}-confidence"]
    prediction_label = folder_to_name[predicted_wnid]
    class_def = class_descriptions[predicted_wnid]
    session_state.is_classifier_correct[query_id] = record[f"{CLASSIFIER_TAG}-Output"]

    # --- Class description --------------------------------------------------
    with st.expander("Show Class Description"):
        st.write(f"**Name**: {prediction_label}")
        st.write("**Class Definition**:")
        st.markdown("`" + class_def + "`")
        demo_image = Image.open(f"demonstrations/{predicted_wnid}.jpeg")
        st.image(demo_image, caption="Class Explanation", use_column_width=True)

    # --- Query, user judgement and model prediction -------------------------
    with st.expander("Show Query"):
        left, right = st.columns(2)
        with left:
            st.image(load_query(query_path), caption=f"Query ID: {query_id}")
        with right:
            # Re-select any answer already given for this query; index 0 is
            # the "-" placeholder.
            previous_answer = session_state.user_feedback.get(query_id)
            radio_index = {"Correct": 1, "Wrong": 2}.get(previous_answer, 0)
            session_state.user_feedback[query_id] = st.radio(
                "What do you think about model's prediction?",
                ("-", "Correct", "Wrong"),
                key=query_id,
                index=radio_index,
            )
            st.write(f"**Model Prediction**: {prediction_label}")
            st.write(f"**Model Confidence**: {prediction_confidence}")

    # --- Explanation from the selected XAI tool -----------------------------
    if selected_xai_tool is not None:
        explanation = selected_xai_tool(query_path)
        st.image(explanation, caption="Explaination", use_column_width=True)

    # --- Debug --------------------------------------------------------------
    if st.button("Debug: Show Everything"):
        st.image(Image.open(query_path))
def render_results():
    """Score the user's answers against the classifier and render a summary.

    Every query recorded in ``session_state.user_feedback`` is scored. The
    radio answer becomes a boolean (``"Correct"`` -> True, anything else ->
    False) and is compared with ``session_state.is_classifier_correct`` for
    that query.

    Renders the overall score, a per-category breakdown table, the per-query
    summary table and a CSV download button.
    """
    feedback = session_state.user_feedback
    classifier_correct = session_state.is_classifier_correct

    categories = set(classifier_correct.values())
    breakdown_stats_correct = {c: 0 for c in categories}
    breakdown_stats_wrong = {c: 0 for c in categories}
    user_correct_guess = 0
    experiment_summary = []

    for q, answer in feedback.items():
        category = classifier_correct[q]
        # NOTE(review): the hidden "-" placeholder (no answer picked) is
        # scored exactly like "Wrong"; confirm that is intended.
        user_feedback_boolean = answer == "Correct"
        # The overall score uses this same boolean comparison. It used to
        # compare the stored "-Output" value with the raw answer string.
        # Assuming those values are booleans (as this comparison does), that
        # could never match, so the headline score always read 0.
        is_user_correct = category == user_feedback_boolean
        if is_user_correct:
            user_correct_guess += 1
            breakdown_stats_correct[category] += 1
        else:
            breakdown_stats_wrong[category] += 1
        experiment_summary.append(
            [
                q,
                classifier_predictions[q]["real-gts"],
                folder_to_name[classifier_predictions[q][f"{CLASSIFIER_TAG}-predictions"]],
                category,
                answer,
                is_user_correct,
            ]
        )

    st.write(
        f"User performance on {CLASSIFIER_TAG}: {user_correct_guess} out of {len(feedback)} Correct"
    )

    st.markdown("## User Performance Breakdown")
    # Right/wrong user judgements per category; these counts were computed
    # but never displayed under this heading before.
    breakdown_df = pd.DataFrame(
        {"User Correct": breakdown_stats_correct, "User Wrong": breakdown_stats_wrong}
    )
    breakdown_df.index.name = "Category"
    st.write(breakdown_df)

    experiment_summary_df = pd.DataFrame.from_records(
        experiment_summary,
        columns=[
            "Query",
            "GT Labels",
            f"{CLASSIFIER_TAG} Prediction",
            "Category",
            "User Prediction",
            "Is User Prediction Correct",
        ],
    )
    st.write("Summary", experiment_summary_df)
    csv_data = convert_df(experiment_summary_df)
    st.download_button(
        "Press to Download", csv_data, "summary.csv", "text/csv", key="download-records"
    )
def render_menu():
    """Render the page selector and whichever page the user picked.

    Pages:
      * "Experiment Instruction": the instructions plus a hint to continue.
      * "Start Experiment": Previous/Next/Resample controls, then the current
        trial. The page index is clamped to ``1..NUMBER_OF_TRIALS``.
      * "See the Results": the results summary.
    """
    # Render the readme as markdown; keep the handle so it can be cleared.
    readme_text = st.markdown(
        """
# Instructions
```
When testing this study, you should first see the class definition, then hide the expander and see the query.
```
"""
    )
    app_mode = st.selectbox(
        "Choose the page to show:",
        ["Experiment Instruction", "Start Experiment", "See the Results"],
    )
    if app_mode == "Experiment Instruction":
        st.success("To continue select an option in the dropdown menu.")
    elif app_mode == "Start Experiment":
        # Clear Canvas
        readme_text.empty()
        page_id = session_state.page
        col1, col4, col2, col3 = st.columns(4)  # col4 is an unused spacer
        if col1.button("Previous Image"):
            page_id = max(page_id - 1, 1)
        if col2.button("Next Image"):
            page_id = min(page_id + 1, NUMBER_OF_TRIALS)
        if page_id == NUMBER_OF_TRIALS:
            st.success(
                'You have reached the last image. Please go to the "Results" page to see your performance.'
            )
            if st.button("View"):
                # Reassigning app_mode here (as before) did nothing, because
                # the page branch has already been chosen; show the results
                # directly instead.
                st.write("Results Summary")
                render_results()
        if col3.button("Resample"):
            st.write("Restarting ...")
            page_id = 1
            session_state.first_run = 1
            resmaple_queries()
        session_state.page = page_id
        st.write(f"Render Experiment: {session_state.page}")
        render_experiment(session_state.page - 1)
    elif app_mode == "See the Results":
        readme_text.empty()
        st.write("Results Summary")
        render_results()
def main():
    """Streamlit entry point: configure the page, pick an XAI tool, run the study.

    Until a method is chosen, a radio of explanation methods is shown. Once
    ``session_state.XAI_tool`` is set, the matching explanation loader and
    classifier tag are installed as module globals, queries are (re)sampled
    if needed, and the page menu is rendered.
    """
    global selected_xai_tool
    # Without this declaration the assignment below only binds a local name,
    # and render_experiment/render_results keep reading the module default
    # "CHM" whichever tool was selected.
    global CLASSIFIER_TAG

    # Older Streamlit releases require set_page_config to be the first st.*
    # call of a run; calling it before get_data() keeps that true even if the
    # download helpers write to the page.
    st.set_page_config(layout="wide")
    # Get the Data
    get_data()
    st.title("TASK - 1 - ImageNetREAL")

    # Explanation loader and classifier tag for each selectable method
    # (loaders come from the project's star imports).
    xai_tools = {
        "NOXAI": (None, "KNN"),
        "KNN": (load_knn_nns, "KNN"),
        "EMD Nearest Neighbors": (load_emd_nns, "EMD"),
        "EMD Correspondence": (load_emd_corrs, "EMD"),
        "CHM Nearest Neighbors": (load_chm_nns, "CHM"),
        "CHM Correspondence": (load_chm_corrs, "CHM"),
    }
    options = ["Unselected", *xai_tools]

    # Hide the first option of every radio group ("Unselected" / "-").
    st.markdown(
        """ <style>
div[role="radiogroup"] > :first-child{
display: none !important;
}
</style>
""",
        unsafe_allow_html=True,
    )

    if session_state.XAI_tool == "Unselected":
        default = options.index(session_state.XAI_tool)
        session_state.XAI_tool = st.radio(
            "What explaination tool do you want to evaluate?",
            options,
            key="which_xai",
            index=default,
        )

    if session_state.XAI_tool != "Unselected":
        st.markdown(f"## SELECTED METHOD ``{session_state.XAI_tool}``")
        selected_xai_tool, CLASSIFIER_TAG = xai_tools[session_state.XAI_tool]
        resmaple_queries()
        render_menu()


if __name__ == "__main__":
    main()