# NOTE: several of these imports (matplotlib, numpy, PIL, io, urllib, ...) are only
# needed by the commented-out Integrated Gradients demo kept below for reference;
# the unused ones can be dropped once that legacy code is removed.
import copy
import datetime
import io
import os
import pickle
import time
import urllib.request

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageFilter
import gradio as gr
import torch
# MMOCR's runtime dependencies can be installed via openmim if they are missing:
# from mim import install
# install('mmcv-full')
# install('mmengine')
# install('mmdet')
from mmocr.apis import MMOCRInferencer

# End-to-end OCR pipeline: TextSnake for text detection, ABINet (vision branch)
# for text recognition. Pretrained weights are downloaded on first use.
# See the sketch at the end of this file for one way to wire the inferencer
# into a Gradio demo.
ocr = MMOCRInferencer(det='TextSnake', rec='ABINet_Vision')
# --- Legacy code (Integrated Gradients XAI demo), kept commented out for reference ---
# url = (
#     "https://upload.wikimedia.org/wikipedia/commons/3/38/Adorable-animal-cat-20787.jpg"
# )
# path_input = "./cat.jpg"
# urllib.request.urlretrieve(url, filename=path_input)
# url = "https://upload.wikimedia.org/wikipedia/commons/4/43/Cute_dog.jpg"
# path_input = "./dog.jpg"
# urllib.request.urlretrieve(url, filename=path_input)
# model = keras_model(weights="imagenet")
# n_steps = 50
# method = "gausslegendre"
# internal_batch_size = 50
# ig = IntegratedGradients(
#     model, n_steps=n_steps, method=method, internal_batch_size=internal_batch_size
# )
# def do_process(img, baseline):
#     instance = image.img_to_array(img)
#     instance = np.expand_dims(instance, axis=0)
#     instance = preprocess_input(instance)
#     preds = model.predict(instance)
#     lstPreds = decode_predictions(preds, top=3)[0]
#     dctPreds = {
#         lstPreds[i][1]: round(float(lstPreds[i][2]), 2) for i in range(len(lstPreds))
#     }
#     predictions = preds.argmax(axis=1)
#     if baseline == "white":
#         baselines = bls = np.ones(instance.shape).astype(instance.dtype)
#         img_flt = Image.fromarray(np.uint8(np.squeeze(baselines) * 255))
#     elif baseline == "black":
#         baselines = bls = np.zeros(instance.shape).astype(instance.dtype)
#         img_flt = Image.fromarray(np.uint8(np.squeeze(baselines) * 255))
#     elif baseline == "blur":
#         img_flt = img.filter(ImageFilter.GaussianBlur(5))
#         baselines = image.img_to_array(img_flt)
#         baselines = np.expand_dims(baselines, axis=0)
#         baselines = preprocess_input(baselines)
#     else:
#         baselines = np.random.random_sample(instance.shape).astype(instance.dtype)
#         img_flt = Image.fromarray(np.uint8(np.squeeze(baselines) * 255))
#     explanation = ig.explain(instance, baselines=baselines, target=predictions)
#     attrs = explanation.attributions[0]
#     fig, ax = visualize_image_attr(
#         attr=attrs.squeeze(),
#         original_image=img,
#         method="blended_heat_map",
#         sign="all",
#         show_colorbar=True,
#         title=baseline,
#         plt_fig_axis=None,
#         use_pyplot=False,
#     )
#     fig.tight_layout()
#     buf = io.BytesIO()
#     fig.savefig(buf)
#     buf.seek(0)
#     img_res = Image.open(buf)
#     return img_res, img_flt, dctPreds
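# NOTE: the commented-out interface below targets the legacy Gradio 2.x API
# (gr.inputs.* / gr.outputs.*); current Gradio versions expose these components
# directly as gr.Image, gr.Dropdown, gr.Label, etc.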
# input_im = gr.inputs.Image(
#     shape=(224, 224), image_mode="RGB", invert_colors=False, source="upload", type="pil"
# )
# input_drop = gr.inputs.Dropdown(
#     label="Baseline (default: random)",
#     choices=["random", "black", "white", "blur"],
#     default="random",
#     type="value",
# )
# output_img = gr.outputs.Image(label="Output of Integrated Gradients", type="pil")
# output_base = gr.outputs.Image(label="Baseline image", type="pil")
# output_label = gr.outputs.Label(label="Classification results", num_top_classes=3)
# title = "XAI - Integrated gradients"
# description = "Playground: Integrated gradients for a ResNet model trained on Imagenet dataset. Tools: Alibi, TF, Gradio."
# examples = [["./cat.jpg", "blur"], ["./dog.jpg", "random"]]
# article = "<p style='text-align: center'><a href='https://github.com/mawady' target='_blank'>By Dr. Mohamed Elawady</a></p>"
# iface = gr.Interface(
#     fn=do_process,
#     inputs=[input_im, input_drop],
#     outputs=[output_img, output_base, output_label],
#     live=False,
#     interpretation=None,
#     title=title,
#     description=description,
#     article=article,
#     examples=examples,
# )
# iface.launch(debug=True)
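

# ---------------------------------------------------------------------------
# Sketch: one possible way to expose the MMOCR inferencer defined near the top
# of this file through a Gradio interface. The names do_ocr and demo_ocr are
# illustrative, not part of the original app, and the result layout used here
# (result['predictions'][0] with 'rec_texts' / 'rec_scores', plus the
# 'visualization' images returned when return_vis=True) is assumed from the
# MMOCR 1.x inferencer API and may need adjusting to the installed version.
# ---------------------------------------------------------------------------
def do_ocr(path_img):
    # Pass the file path so MMOCR handles image loading and preprocessing itself.
    result = ocr(path_img, return_vis=True)
    pred = result["predictions"][0]
    # Map each recognized text instance to its recognition confidence.
    dct_texts = {
        txt: round(float(score), 2)
        for txt, score in zip(pred["rec_texts"], pred["rec_scores"])
    }
    img_vis = result["visualization"][0]  # detections/recognitions drawn on the image
    return img_vis, dct_texts


demo_ocr = gr.Interface(
    fn=do_ocr,
    inputs=gr.Image(type="filepath", label="Input image"),
    outputs=[
        gr.Image(label="OCR visualization"),
        gr.JSON(label="Recognized text and scores"),
    ],
    title="OCR - MMOCR (TextSnake detection + ABINet recognition)",
)
demo_ocr.launch()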