import os import sys from env import config_env config_env() import gradio as gr from huggingface_hub import snapshot_download import cv2 import dotenv dotenv.load_dotenv() import numpy as np import gradio as gr import glob from inference_sam import segmentation_sam from explanations import explain from inference_resnet import get_triplet_model from inference_beit import get_triplet_model_beit import pathlib import tensorflow as tf from closest_sample import get_images if not os.path.exists('images'): REPO_ID='Serrelab/image_examples_gradio' snapshot_download(repo_id=REPO_ID, token=os.environ.get('READ_TOKEN'),repo_type='dataset',local_dir='images') if not os.path.exists('dataset'): REPO_ID='Serrelab/Fossils' token = os.environ.get('READ_TOKEN') print(f"Read token:{token}") if token is None: print("warning! A read token in env variables is needed for authentication.") snapshot_download(repo_id=REPO_ID, token=token,repo_type='dataset',local_dir='dataset') def get_model(model_name): if model_name=='Mummified 170': n_classes = 170 model = get_triplet_model(input_shape = (600, 600, 3), embedding_units = 256, embedding_depth = 2, backbone_class=tf.keras.applications.ResNet50V2, nb_classes = n_classes,load_weights=False,finer_model=True,backbone_name ='Resnet50v2') model.load_weights('model_classification/mummified-170.h5') elif model_name=='Rock 170': n_classes = 171 model = get_triplet_model(input_shape = (600, 600, 3), embedding_units = 256, embedding_depth = 2, backbone_class=tf.keras.applications.ResNet50V2, nb_classes = n_classes,load_weights=False,finer_model=True,backbone_name ='Resnet50v2') model.load_weights('model_classification/rock-170.h5') elif model_name == 'Fossils 142': n_classes = 142 model = get_triplet_model_beit(input_shape = (384, 384, 3), embedding_units = 256, embedding_depth = 2, n_classes = n_classes) model.load_weights('model_classification/fossil-142.h5') else: raise ValueError(f"Model name '{model_name}' is not recognized") return model,n_classes def segment_image(input_image): img = segmentation_sam(input_image) return img def classify_image(input_image, model_name): #segmented_image = segment_image(input_image) if 'Rock 170' ==model_name: from inference_resnet import inference_resnet_finer model,n_classes= get_model(model_name) result = inference_resnet_finer(input_image,model,size=600,n_classes=n_classes) return result elif 'Mummified 170' ==model_name: from inference_resnet import inference_resnet_finer model, n_classes= get_model(model_name) result = inference_resnet_finer(input_image,model,size=600,n_classes=n_classes) return result if 'Fossils 142' ==model_name: from inference_beit import inference_resnet_finer_beit model,n_classes = get_model(model_name) result = inference_resnet_finer_beit(input_image,model,size=384,n_classes=n_classes) return result return None def get_embeddings(input_image,model_name): if 'Rock 170' ==model_name: from inference_resnet import inference_resnet_embedding model,n_classes= get_model(model_name) result = inference_resnet_embedding(input_image,model,size=600,n_classes=n_classes) return result elif 'Mummified 170' ==model_name: from inference_resnet import inference_resnet_embedding model, n_classes= get_model(model_name) result = inference_resnet_embedding(input_image,model,size=600,n_classes=n_classes) return result if 'Fossils 142' ==model_name: from inference_beit import inference_resnet_embedding_beit model,n_classes = get_model(model_name) result = inference_resnet_embedding_beit(input_image,model,size=384,n_classes=n_classes) return result return None def find_closest(input_image,model_name): embedding = get_embeddings(input_image,model_name) classes, paths = get_images(embedding) #outputs = classes+paths return classes,paths def explain_image(input_image,model_name): model,n_classes= get_model(model_name) if model_name=='Fossils 142': size = 384 else: size = 600 #saliency, integrated, smoothgrad, rise,avg = explain(model,input_image,size = size, n_classes=n_classes) #original = saliency + integrated + smoothgrad print('done') rise1,rise2,rise3,rise4,rise5,avg = rise[0],rise[1],rise[2],rise[3],rise[4],avg[0] return rise1,rise2,rise3,rise4,rise5,avg #minimalist theme with gr.Blocks(theme='sudeepshouche/minimalist') as demo: with gr.Tab(" Florrissant Fossils"): with gr.Row(): with gr.Column(): input_image = gr.Image(label="Input") classify_image_button = gr.Button("Classify Image") # with gr.Column(): # #segmented_image = gr.outputs.Image(label="SAM output",type='numpy') # segmented_image=gr.Image(label="Segmented Image", type='numpy') # segment_button = gr.Button("Segment Image") # #classify_segmented_button = gr.Button("Classify Segmented Image") with gr.Column(): model_name = gr.Dropdown( ["Mummified 170", "Rock 170","Fossils 142"], multiselect=False, value="Fossils 142", # default option label="Model", interactive=True, ) class_predicted = gr.Label(label='Class Predicted',num_top_classes=10) with gr.Row(): paths = sorted(pathlib.Path('images/').rglob('*.jpg')) samples=[[path.as_posix()] for path in paths if 'fossils' in str(path) ][:19] examples_fossils = gr.Examples(samples, inputs=input_image,examples_per_page=10,label='Fossils Examples from the dataset') samples=[[path.as_posix()] for path in paths if 'leaves' in str(path) ][:19] examples_leaves = gr.Examples(samples, inputs=input_image,examples_per_page=5,label='Leaves Examples from the dataset') # with gr.Accordion("Using Diffuser"): # with gr.Column(): # prompt = gr.Textbox(lines=1, label="Prompt") # output_image = gr.Image(label="Output") # generate_button = gr.Button("Generate Leave") # with gr.Column(): # class_predicted2 = gr.Label(label='Class Predicted from diffuser') # classify_button = gr.Button("Classify Image") with gr.Accordion("Explanations "): gr.Markdown("Computing Explanations from the model") with gr.Row(): #original_input = gr.Image(label="Original Frame") #saliency = gr.Image(label="saliency") #gradcam = gr.Image(label='integraged gradients') #guided_gradcam = gr.Image(label='gradcam') #guided_backprop = gr.Image(label='guided backprop') rise1 = gr.Image(label = 'Rise1') rise2 = gr.Image(label = 'Rise2') rise3 = gr.Image(label = 'Rise3') rise4 = gr.Image(label = 'Rise4') rise5 = gr.Image(label = 'Rise5') avg = gr.Image(label = 'Avg') generate_explanations = gr.Button("Generate Explanations") # with gr.Accordion('Closest Images'): # gr.Markdown("Finding the closest images in the dataset") # with gr.Row(): # with gr.Column(): # label_closest_image_0 = gr.Markdown('') # closest_image_0 = gr.Image(label='Closest Image',image_mode='contain',width=200, height=200) # with gr.Column(): # label_closest_image_1 = gr.Markdown('') # closest_image_1 = gr.Image(label='Second Closest Image',image_mode='contain',width=200, height=200) # with gr.Column(): # label_closest_image_2 = gr.Markdown('') # closest_image_2 = gr.Image(label='Third Closest Image',image_mode='contain',width=200, height=200) # with gr.Column(): # label_closest_image_3 = gr.Markdown('') # closest_image_3 = gr.Image(label='Forth Closest Image',image_mode='contain', width=200, height=200) # with gr.Column(): # label_closest_image_4 = gr.Markdown('') # closest_image_4 = gr.Image(label='Fifth Closest Image',image_mode='contain',width=200, height=200) # find_closest_btn = gr.Button("Find Closest Images") with gr.Accordion('Closest Images'): gr.Markdown("Finding the closest images in the dataset") with gr.Row(): gallery = gr.Gallery(label="Closest Images", show_label=False,elem_id="gallery",columns=[5], rows=[1],height='auto', allow_preview=True, preview=None) #.style(grid=[1, 5], height=200, width=200) find_closest_btn = gr.Button("Find Closest Images") #segment_button.click(segment_image, inputs=input_image, outputs=segmented_image) classify_image_button.click(classify_image, inputs=[input_image,model_name], outputs=class_predicted) generate_explanations.click(explain_image, inputs=[input_image,model_name], outputs=[rise1,rise2,rise3,rise4,rise5,avg]) # #find_closest_btn.click(find_closest, inputs=[input_image,model_name], outputs=[label_closest_image_0,label_closest_image_1,label_closest_image_2,label_closest_image_3,label_closest_image_4,closest_image_0,closest_image_1,closest_image_2,closest_image_3,closest_image_4]) def update_outputs(input_image,model_name): labels, images = find_closest(input_image,model_name) #labels_html = "".join([f'