File size: 3,854 Bytes
cffa665
 
 
 
 
 
 
 
 
 
 
fc84b02
 
cffa665
 
 
 
 
 
 
ef64c8c
 
cffa665
46c3a5a
 
 
 
 
cffa665
46c3a5a
 
 
cffa665
 
 
 
 
 
 
 
 
 
 
3d6eaf1
46c3a5a
cffa665
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46c3a5a
 
 
cffa665
 
 
 
 
 
 
46c3a5a
cffa665
 
 
46c3a5a
 
ef64c8c
46c3a5a
 
 
 
 
 
 
 
 
 
 
 
cffa665
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageFilter
import io
import time
import os
import copy
import pickle
import datetime
import urllib.request
import gradio as gr
import torch

# from mim import install



# install('mmcv-full')
# install('mmengine')
# install('mmdet')
# from mmocr.apis import MMOCRInferencer
# ocr = MMOCRInferencer(det='TextSnake', rec='ABINet_Vision')

# Example images offered in the Gradio UI (`examples` below refers to these
# local paths). Each entry is (source URL, local path).
_EXAMPLE_IMAGES = (
    (
        "https://upload.wikimedia.org/wikipedia/commons/3/38/Adorable-animal-cat-20787.jpg",
        "./cat.jpg",
    ),
    (
        "https://upload.wikimedia.org/wikipedia/commons/4/43/Cute_dog.jpg",
        "./dog.jpg",
    ),
)
# Wikimedia's User-Agent policy asks clients to identify themselves and may
# answer urllib's generic default User-Agent with HTTP 403, so set one
# explicitly instead of relying on urlretrieve's default.
_DOWNLOAD_HEADERS = {"User-Agent": "xai-integrated-gradients-demo/1.0 (python-urllib)"}

# After the loop `url` / `path_input` hold the last (dog) entry, exactly as
# the original sequential assignments left them.
for url, path_input in _EXAMPLE_IMAGES:
    # Reuse an already-downloaded copy: restarts then need no network access.
    if os.path.exists(path_input):
        continue
    # Download under a temporary name and rename only on success, so an
    # interrupted transfer never leaves a truncated file that the exists()
    # check above would later treat as valid.
    tmp_path = path_input + ".part"
    request = urllib.request.Request(url, headers=_DOWNLOAD_HEADERS)
    with urllib.request.urlopen(request) as response, open(tmp_path, "wb") as fh:
        fh.write(response.read())
    os.replace(tmp_path, path_input)

# model = keras_model(weights="imagenet")

# n_steps = 50
# method = "gausslegendre"
# internal_batch_size = 50
# ig = IntegratedGradients(
#     model, n_steps=n_steps, method=method, internal_batch_size=internal_batch_size
# )


def do_process(img):
    """Echo the uploaded image back unchanged.

    Placeholder Gradio callback: the Integrated Gradients pipeline that was
    meant to run here is commented out, so the interface currently displays
    its input as its output.

    Args:
        img: Whatever the Gradio input component delivers; it is not
            inspected or modified.

    Returns:
        The very same object that was passed in.
    """
    echoed = img
    return echoed
#     instance = image.img_to_array(img)
#     instance = np.expand_dims(instance, axis=0)
#     instance = preprocess_input(instance)
#     preds = model.predict(instance)
#     lstPreds = decode_predictions(preds, top=3)[0]
#     dctPreds = {
#         lstPreds[i][1]: round(float(lstPreds[i][2]), 2) for i in range(len(lstPreds))
#     }
#     predictions = preds.argmax(axis=1)
#     if baseline == "white":
#         baselines = bls = np.ones(instance.shape).astype(instance.dtype)
#         img_flt = Image.fromarray(np.uint8(np.squeeze(baselines) * 255))
#     elif baseline == "black":
#         baselines = bls = np.zeros(instance.shape).astype(instance.dtype)
#         img_flt = Image.fromarray(np.uint8(np.squeeze(baselines) * 255))
#     elif baseline == "blur":
#         img_flt = img.filter(ImageFilter.GaussianBlur(5))
#         baselines = image.img_to_array(img_flt)
#         baselines = np.expand_dims(baselines, axis=0)
#         baselines = preprocess_input(baselines)
#     else:
#         baselines = np.random.random_sample(instance.shape).astype(instance.dtype)
#         img_flt = Image.fromarray(np.uint8(np.squeeze(baselines) * 255))
#     explanation = ig.explain(instance, baselines=baselines, target=predictions)
#     attrs = explanation.attributions[0]
#     fig, ax = visualize_image_attr(
#         attr=attrs.squeeze(),
#         original_image=img,
#         method="blended_heat_map",
#         sign="all",
#         show_colorbar=True,
#         title=baseline,
#         plt_fig_axis=None,
#         use_pyplot=False,
#     )
#     fig.tight_layout()
#     buf = io.BytesIO()
#     fig.savefig(buf)
#     buf.seek(0)
#     img_res = Image.open(buf)
#     return img_res, img_flt, dctPreds


# Input component: an uploaded RGB image passed to the callback as a PIL
# image; shape=(224, 224) asks Gradio to bring uploads to that size.
# NOTE(review): `gr.inputs` is the legacy namespace (deprecated in Gradio 3,
# removed in 4) — confirm the pinned gradio version before upgrading.
input_im = gr.inputs.Image(
    type="pil",
    source="upload",
    image_mode="RGB",
    shape=(224, 224),
    invert_colors=False,
)
# input_drop = gr.inputs.Dropdown(
#     label="Baseline (default: random)",
#     choices=["random", "black", "white", "blur"],
#     default="random",
#     type="value",
# )

# Output component: shows the PIL image returned by `do_process`.
output_img = gr.outputs.Image(type="pil", label="Output of Integrated Gradients")
# output_base = gr.outputs.Image(label="Baseline image", type="pil")
# output_label = gr.outputs.Label(label="Classification results", num_top_classes=3)

# Page metadata shown by the Gradio interface.
title = "XAI - Integrated gradients"
description = (
    "Playground: Integrated gradients for a ResNet model trained on Imagenet "
    "dataset. Tools: Alibi, TF, Gradio."
)
# One example row per downloaded sample image (single input component).
examples = [[sample] for sample in ("./cat.jpg", "./dog.jpg")]
article = (
    "<p style='text-align: center'>"
    "<a href='https://github.com/mawady' target='_blank'>"
    "By Dr. Mohamed Elawady</a></p>"
)

# Wire the callback to its single input and output; runs on submit, not live.
iface = gr.Interface(
    do_process,
    inputs=[input_im],
    outputs=[output_img],
    examples=examples,
    title=title,
    description=description,
    article=article,
    interpretation=None,
    live=False,
)

# iface.launch(debug=True)