Spaces:
Runtime error
Runtime error
File size: 3,854 Bytes
cffa665 fc84b02 cffa665 ef64c8c cffa665 46c3a5a cffa665 46c3a5a cffa665 3d6eaf1 46c3a5a cffa665 46c3a5a cffa665 46c3a5a cffa665 46c3a5a ef64c8c 46c3a5a cffa665 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageFilter
import io
import time
import os
import copy
import pickle
import datetime
import urllib.request
import gradio as gr
import torch
# from mim import install
# install('mmcv-full')
# install('mmengine')
# install('mmdet')
# from mmocr.apis import MMOCRInferencer
# ocr = MMOCRInferencer(det='TextSnake', rec='ABINet_Vision')
# Sample images offered as clickable examples in the demo; the ``examples``
# list further down refers to these local paths.
_EXAMPLE_IMAGES = (
    (
        "https://upload.wikimedia.org/wikipedia/commons/3/38/Adorable-animal-cat-20787.jpg",
        "./cat.jpg",
    ),
    (
        "https://upload.wikimedia.org/wikipedia/commons/4/43/Cute_dog.jpg",
        "./dog.jpg",
    ),
)
# Wikimedia's User-Agent policy asks clients to identify themselves; the
# default "Python-urllib/x.y" agent used by ``urlretrieve`` may be answered
# with HTTP 403, which would abort start-up.
_DOWNLOAD_HEADERS = {
    "User-Agent": "XAI-IntegratedGradients-Demo/1.0 (https://github.com/mawady)"
}


def _download_example(url, path, timeout=30):
    """Fetch ``url`` into ``path`` unless the file is already present.

    The payload is written to ``path + ".part"`` and renamed into place only
    after the transfer completes, so an interrupted download never leaves a
    truncated file that a later start-up would mistake for a finished one.

    Args:
        url: Remote image URL.
        path: Local destination path.
        timeout: Socket timeout in seconds for the HTTP request.

    Raises:
        urllib.error.URLError: on network/HTTP failure (``HTTPError`` is a
            subclass). As with the original ``urlretrieve`` calls, start-up
            aborts rather than continuing without the example files.
        OSError: if the destination cannot be written.
    """
    if os.path.exists(path):
        return  # already downloaded by a previous start-up
    partial = path + ".part"
    request = urllib.request.Request(url, headers=_DOWNLOAD_HEADERS)
    with urllib.request.urlopen(request, timeout=timeout) as response:
        data = response.read()
    with open(partial, "wb") as fh:
        fh.write(data)
    os.replace(partial, path)  # atomic rename on the same filesystem


# ``url`` / ``path_input`` remain bound to the last entry after the loop,
# matching the module-level names the original straight-line code left behind.
for url, path_input in _EXAMPLE_IMAGES:
    _download_example(url, path_input)
# model = keras_model(weights="imagenet")
# n_steps = 50
# method = "gausslegendre"
# internal_batch_size = 50
# ig = IntegratedGradients(
# model, n_steps=n_steps, method=method, internal_batch_size=internal_batch_size
# )
def do_process(img):
    """Hand the uploaded image straight back to the interface.

    This is a placeholder for the Integrated Gradients pipeline. The
    model-based version is kept commented out directly below this
    function. Until it is restored, the output component simply echoes
    whatever the input component supplied, including ``None``.
    """
    echoed = img
    return echoed
# instance = image.img_to_array(img)
# instance = np.expand_dims(instance, axis=0)
# instance = preprocess_input(instance)
# preds = model.predict(instance)
# lstPreds = decode_predictions(preds, top=3)[0]
# dctPreds = {
# lstPreds[i][1]: round(float(lstPreds[i][2]), 2) for i in range(len(lstPreds))
# }
# predictions = preds.argmax(axis=1)
# if baseline == "white":
# baselines = bls = np.ones(instance.shape).astype(instance.dtype)
# img_flt = Image.fromarray(np.uint8(np.squeeze(baselines) * 255))
# elif baseline == "black":
# baselines = bls = np.zeros(instance.shape).astype(instance.dtype)
# img_flt = Image.fromarray(np.uint8(np.squeeze(baselines) * 255))
# elif baseline == "blur":
# img_flt = img.filter(ImageFilter.GaussianBlur(5))
# baselines = image.img_to_array(img_flt)
# baselines = np.expand_dims(baselines, axis=0)
# baselines = preprocess_input(baselines)
# else:
# baselines = np.random.random_sample(instance.shape).astype(instance.dtype)
# img_flt = Image.fromarray(np.uint8(np.squeeze(baselines) * 255))
# explanation = ig.explain(instance, baselines=baselines, target=predictions)
# attrs = explanation.attributions[0]
# fig, ax = visualize_image_attr(
# attr=attrs.squeeze(),
# original_image=img,
# method="blended_heat_map",
# sign="all",
# show_colorbar=True,
# title=baseline,
# plt_fig_axis=None,
# use_pyplot=False,
# )
# fig.tight_layout()
# buf = io.BytesIO()
# fig.savefig(buf)
# buf.seek(0)
# img_res = Image.open(buf)
# return img_res, img_flt, dctPreds
# Side length the (currently disabled) ImageNet ResNet pipeline expects.
_MODEL_INPUT_SIZE = (224, 224)


def _process_upload(img):
    """Center-crop/resize the uploaded image to 224x224, then run ``do_process``.

    Gradio <= 3 did this inside the input component via
    ``gr.inputs.Image(shape=(224, 224))``. That option was removed in
    Gradio 4, so the equivalent crop-and-resize is done here and works
    across both versions.

    Args:
        img: PIL image from the input component, or ``None`` when the user
            submits without uploading anything.

    Returns:
        The result of ``do_process`` on the 224x224 image, or ``None`` when
        no image was supplied.
    """
    if img is None:
        return None
    from PIL import ImageOps  # only this wrapper needs ImageOps

    return do_process(ImageOps.fit(img, _MODEL_INPUT_SIZE))


# ``gr.inputs`` / ``gr.outputs`` and the ``shape``/``invert_colors``/``source``
# options were removed in Gradio 4; the arguments kept here exist in both
# Gradio 3.x and 4.x. RGB conversion is still requested via ``image_mode``.
input_im = gr.Image(image_mode="RGB", type="pil")
# Planned controls for the Integrated Gradients pipeline, disabled until
# ``do_process`` runs the model again:
# input_drop = gr.Dropdown(
#     label="Baseline (default: random)",
#     choices=["random", "black", "white", "blur"],
#     value="random",
#     type="value",
# )
output_img = gr.Image(label="Output of Integrated Gradients", type="pil")
# output_base = gr.Image(label="Baseline image", type="pil")
# output_label = gr.Label(label="Classification results", num_top_classes=3)
title = "XAI - Integrated gradients"
description = "Playground: Integrated gradients for a ResNet model trained on Imagenet dataset. Tools: Alibi, TF, Gradio."
examples = [["./cat.jpg"], ["./dog.jpg"]]
article = "<p style='text-align: center'><a href='https://github.com/mawady' target='_blank'>By Dr. Mohamed Elawady</a></p>"
# ``interpretation=`` was dropped: it defaulted to None and no longer exists
# in Gradio 4.
iface = gr.Interface(
    fn=_process_upload,
    inputs=[input_im],
    outputs=[output_img],
    live=False,
    title=title,
    description=description,
    article=article,
    examples=examples,
)

# Serve the app when run as a script (e.g. ``python app.py`` on a Space), but
# not when this module is merely imported.
if __name__ == "__main__":
    iface.launch()
|