Spaces:
				
			
			
	
			
			
					
		Running
		
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
	Commit 
							
							·
						
						9ff2b84
	
0
								Parent(s):
							
							
taxabind-demo
Browse files- .gitattributes +40 -0
 - README.md +14 -0
 - app.py +125 -0
 - coordinates_new.npy +3 -0
 - loc_embeds_new.npy +3 -0
 - requirements.txt +6 -0
 - sat_embeds_new.npy +3 -0
 - txt_emb_species.json +3 -0
 - txt_emb_species.npy +3 -0
 
    	
        .gitattributes
    ADDED
    
    | 
         @@ -0,0 +1,40 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            *.7z filter=lfs diff=lfs merge=lfs -text
         
     | 
| 2 | 
         
            +
            *.arrow filter=lfs diff=lfs merge=lfs -text
         
     | 
| 3 | 
         
            +
            *.bin filter=lfs diff=lfs merge=lfs -text
         
     | 
| 4 | 
         
            +
            *.bz2 filter=lfs diff=lfs merge=lfs -text
         
     | 
| 5 | 
         
            +
            *.ckpt filter=lfs diff=lfs merge=lfs -text
         
     | 
| 6 | 
         
            +
            *.ftz filter=lfs diff=lfs merge=lfs -text
         
     | 
| 7 | 
         
            +
            *.gz filter=lfs diff=lfs merge=lfs -text
         
     | 
| 8 | 
         
            +
            *.h5 filter=lfs diff=lfs merge=lfs -text
         
     | 
| 9 | 
         
            +
            *.joblib filter=lfs diff=lfs merge=lfs -text
         
     | 
| 10 | 
         
            +
            *.lfs.* filter=lfs diff=lfs merge=lfs -text
         
     | 
| 11 | 
         
            +
            *.mlmodel filter=lfs diff=lfs merge=lfs -text
         
     | 
| 12 | 
         
            +
            *.model filter=lfs diff=lfs merge=lfs -text
         
     | 
| 13 | 
         
            +
            *.msgpack filter=lfs diff=lfs merge=lfs -text
         
     | 
| 14 | 
         
            +
            *.npy filter=lfs diff=lfs merge=lfs -text
         
     | 
| 15 | 
         
            +
            *.npz filter=lfs diff=lfs merge=lfs -text
         
     | 
| 16 | 
         
            +
            *.onnx filter=lfs diff=lfs merge=lfs -text
         
     | 
| 17 | 
         
            +
            *.ot filter=lfs diff=lfs merge=lfs -text
         
     | 
| 18 | 
         
            +
            *.parquet filter=lfs diff=lfs merge=lfs -text
         
     | 
| 19 | 
         
            +
            *.pb filter=lfs diff=lfs merge=lfs -text
         
     | 
| 20 | 
         
            +
            *.pickle filter=lfs diff=lfs merge=lfs -text
         
     | 
| 21 | 
         
            +
            *.pkl filter=lfs diff=lfs merge=lfs -text
         
     | 
| 22 | 
         
            +
            *.pt filter=lfs diff=lfs merge=lfs -text
         
     | 
| 23 | 
         
            +
            *.pth filter=lfs diff=lfs merge=lfs -text
         
     | 
| 24 | 
         
            +
            *.rar filter=lfs diff=lfs merge=lfs -text
         
     | 
| 25 | 
         
            +
            *.safetensors filter=lfs diff=lfs merge=lfs -text
         
     | 
| 26 | 
         
            +
            saved_model/**/* filter=lfs diff=lfs merge=lfs -text
         
     | 
| 27 | 
         
            +
            *.tar.* filter=lfs diff=lfs merge=lfs -text
         
     | 
| 28 | 
         
            +
            *.tar filter=lfs diff=lfs merge=lfs -text
         
     | 
| 29 | 
         
            +
            *.tflite filter=lfs diff=lfs merge=lfs -text
         
     | 
| 30 | 
         
            +
            *.tgz filter=lfs diff=lfs merge=lfs -text
         
     | 
| 31 | 
         
            +
            *.wasm filter=lfs diff=lfs merge=lfs -text
         
     | 
| 32 | 
         
            +
            *.xz filter=lfs diff=lfs merge=lfs -text
         
     | 
| 33 | 
         
            +
            *.zip filter=lfs diff=lfs merge=lfs -text
         
     | 
| 34 | 
         
            +
            *.zst filter=lfs diff=lfs merge=lfs -text
         
     | 
| 35 | 
         
            +
            *tfevents* filter=lfs diff=lfs merge=lfs -text
         
     | 
| 36 | 
         
            +
            txt_emb_species.npy filter=lfs diff=lfs merge=lfs -text
         
     | 
| 37 | 
         
            +
            txt_emb_species.json filter=lfs diff=lfs merge=lfs -text
         
     | 
| 38 | 
         
            +
            coordinates_new.npy filter=lfs diff=lfs merge=lfs -text
         
     | 
| 39 | 
         
            +
            loc_embeds_new.npy filter=lfs diff=lfs merge=lfs -text
         
     | 
| 40 | 
         
            +
            sat_embeds_new.npy filter=lfs diff=lfs merge=lfs -text
         
     | 
    	
        README.md
    ADDED
    
    | 
         @@ -0,0 +1,14 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            ---
         
     | 
| 2 | 
         
            +
            title: Taxabind Demo
         
     | 
| 3 | 
         
            +
            emoji: 🔥
         
     | 
| 4 | 
         
            +
            colorFrom: indigo
         
     | 
| 5 | 
         
            +
            colorTo: indigo
         
     | 
| 6 | 
         
            +
            sdk: gradio
         
     | 
| 7 | 
         
            +
            sdk_version: 5.4.0
         
     | 
| 8 | 
         
            +
            app_file: app.py
         
     | 
| 9 | 
         
            +
            pinned: false
         
     | 
| 10 | 
         
            +
            license: apache-2.0
         
     | 
| 11 | 
         
            +
            short_description: 'A Unified Embedding Space for Ecological Applications'
         
     | 
| 12 | 
         
            +
            ---
         
     | 
| 13 | 
         
            +
             
     | 
| 14 | 
         
            +
            Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
         
     | 
    	
        app.py
    ADDED
    
    | 
         @@ -0,0 +1,125 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import gradio as gr
         
     | 
| 2 | 
         
            +
            import numpy as np
         
     | 
| 3 | 
         
            +
            import torch
         
     | 
| 4 | 
         
            +
            from torchvision import transforms
         
     | 
| 5 | 
         
            +
            import open_clip
         
     | 
| 6 | 
         
            +
            import pymap3d as pm
         
     | 
| 7 | 
         
            +
            import reverse_geocode
         
     | 
| 8 | 
         
            +
            import json
         
     | 
| 9 | 
         
            +
             
     | 
| 10 | 
         
            +
             
     | 
| 11 | 
         
            +
            def bounding_box_from_circle(lat_center, lon_center, radius=1000,
                                         disable_latitude_compensation=False):
                """Return a lon/lat bounding box around a circle centred on a point.

                Five evenly spaced angles over [0, 2*pi] are scaled by ``radius`` to
                get planar offsets, which ``pymap3d.enu2geodetic`` converts to
                latitude/longitude. The box edges are read from the samples at
                indices 0-3 (angles 0, pi/2, pi and 3*pi/2).

                Args:
                    lat_center: Latitude of the circle centre, in degrees.
                    lon_center: Longitude of the circle centre, in degrees.
                    radius: Circle radius in meters.
                    disable_latitude_compensation: When False (default) the offsets
                        are converted in the tangent plane at the centre itself.
                        When True they are converted at latitude 0 / longitude 0 and
                        the resulting angles are added to the centre, so the radius
                        is effectively measured at the equator.

                Returns:
                    Tuple ``(left, bottom, right, top)``: min longitude, min
                    latitude, max longitude, max latitude.

                Warning:
                    The poles and the 180th meridian are not handled; a box that
                    crosses them may wrap around or be wrong. No check is made that
                    the radius is reasonably small.
                """
                angles = np.linspace(0, 2*np.pi, 5)
                east, north = radius*np.cos(angles), radius*np.sin(angles)

                if disable_latitude_compensation:
                    # Lat-lon box: convert at the origin, then shift to the centre.
                    lats, lons, _ = pm.enu2geodetic(east, north, 0, 0, 0, 0)
                    lats = lats + lat_center
                    lons = lons + lon_center
                else:
                    # Tangent-plane box: offsets are meters at the actual location.
                    lats, lons, _ = pm.enu2geodetic(east, north, 0, lat_center, lon_center, 0)

                # Sample order follows the angles: 0 -> +x, 1 -> +y, 2 -> -x, 3 -> -y.
                left, right = lons[2], lons[0]
                bottom, top = lats[3], lats[1]

                return left, bottom, right, top
         
     | 
| 37 | 
         
            +
             
     | 
| 38 | 
         
            +
            # Precomputed retrieval assets, loaded once at import time.
            # sat_embeds and coordinates are indexed by the same row in process().
            sat_embeds = np.load("sat_embeds_new.npy")
            coordinates = np.load("coordinates_new.npy")
            loc_embeds = np.load("loc_embeds_new.npy")
            # The species text embeddings are memory-mapped read-only instead of
            # being read fully into RAM; torch.from_numpy wraps the map without
            # copying it.
            txt_emb = torch.from_numpy(np.load("txt_emb_species.npy", mmap_mode="r"))
            txt_names_json = "txt_emb_species.json"
            # JSON is UTF-8 by specification; be explicit so decoding does not depend
            # on the host locale.
            with open(txt_names_json, encoding="utf-8") as f:
                # Each entry is unpacked as (taxon, common) by format_name().
                txt_names = json.load(f)


            # Query-image preprocessing: resize, centre-crop to 224x224, convert to a
            # tensor and normalise per channel (these look like the usual ImageNet
            # mean/std values).
            transform = transforms.Compose([
                transforms.Resize((256,256)),
                transforms.CenterCrop((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])
            # Only the model is kept; the other values returned by open_clip are
            # discarded because `transform` above is used instead.
            model, *_ = open_clip.create_model_and_transforms('hf-hub:MVRL/taxabind-vit-b-16')
            model.eval()
         
     | 
| 57 | 
         
            +
             
     | 
| 58 | 
         
            +
            def format_name(taxon, common):
                """Build a display label from taxonomic ranks and an optional common name.

                Args:
                    taxon: Iterable of strings, joined with single spaces.
                    common: Common name; any falsy value (``None``, ``""``) means
                        there is none.

                Returns:
                    ``"<taxon>"`` when there is no common name, otherwise
                    ``"<taxon> (<common>)"``.
                """
                scientific = " ".join(taxon)
                return f"{scientific} ({common})" if common else scientific
         
     | 
| 63 | 
         
            +
             
     | 
| 64 | 
         
            +
            def process(input_image):
                """Embed a query image and retrieve its best species and satellite matches.

                The image is preprocessed with ``transform``, embedded with ``model``,
                and compared by dot product against the precomputed satellite
                embeddings and species text embeddings. The five highest-scoring
                entries of each are returned.

                Args:
                    input_image: The uploaded query image (the UI supplies a PIL image).

                Returns:
                    A 3-tuple ``(d_species, d, gallery)``:
                    * ``d_species`` maps a formatted species name to its similarity.
                    * ``d`` maps ``"city, country (lat, lon)"`` of each retrieved
                      satellite location to its similarity.
                    * ``gallery`` is a list of ``(image, caption)`` pairs: the query
                      image first, then one WMS tile URL per retrieved location.

                Raises:
                    gr.Error: If no image was supplied.
                """
                if input_image is None:
                    # Surface a readable message in the UI instead of a transform crash.
                    raise gr.Error("Please upload an image before pressing Run.")

                top_k = 5
                img_tensor = transform(input_image).unsqueeze(0)
                with torch.no_grad():
                    img_embed = model(img_tensor)[0].detach().cpu()
                query = img_embed.t()

                # torch.as_tensor wraps the preloaded arrays without copying them.
                # torch.tensor() would deep-copy these large matrices on every request
                # and warns when handed the already-tensor txt_emb.
                sims = torch.matmul(torch.as_tensor(sat_embeds), query)
                sims_txt = torch.matmul(torch.as_tensor(txt_emb).t(), query)
                # NOTE(review): the original also normalised loc_embeds and ranked it
                # against the query here, but never used the result; the location label
                # below has always been filled from the satellite matches. That dead
                # per-request work is dropped. Confirm whether the location embeddings
                # were meant to drive the "Location Retrieval" output.

                topk = torch.topk(sims, top_k, dim=0)
                topk_txt = torch.topk(sims_txt, top_k, dim=0)

                # Convert to plain Python ints/floats so that NumPy arrays and lists are
                # indexed with real integers rather than one-element tensors.
                sat_indices = topk.indices.flatten().tolist()
                sat_scores = topk.values.flatten().tolist()
                txt_indices = topk_txt.indices.flatten().tolist()
                txt_scores = topk_txt.values.flatten().tolist()

                images = []
                d = {}
                for idx, score in zip(sat_indices, sat_scores):
                    lat, lon = coordinates[idx]
                    l, b, r, t = bounding_box_from_circle(
                        float(lat), float(lon), 1280,
                        disable_latitude_compensation=True)
                    images.append(f"https://tiles.maps.eox.at/wms?service=WMS&version=1.1.1&request=GetMap&layers=s2cloudless-2020&styles=&width=256&height=256&srs=EPSG:4326&bbox={l},{b},{r},{t}&format=image/png")
                    code = reverse_geocode.get([lat, lon])
                    d[f"{code['city']}, {code['country']} ({lat:.4f}, {lon:.4f})"] = score

                d_species = {}
                for idx, score in zip(txt_indices, txt_scores):
                    d_species[format_name(*txt_names[idx])] = score

                gallery = [(np.array(input_image), "Query Image")]
                gallery += [(url, f"Result {rank}") for rank, url in enumerate(images, start=1)]
                return d_species, d, gallery
         
     | 
| 88 | 
         
            +
             
     | 
| 89 | 
         
            +
             
     | 
| 90 | 
         
            +
            # Build the demo page: a title banner, an upload column with a Run button,
            # and three outputs (species label, location label, satellite gallery).
            block = gr.Blocks()
            block.queue()
            with block:
                # Banner with the paper title, authors and venue.
                with gr.Row():
                    gr.Markdown(
                    """
                    <div style="display: flex; justify-content: center; align-items: center; text-align: center;">
                      <div>
                        <h1>TaxaBind</h1>
                        <span>A Unified Embedding Space for Ecological Applications</span>
                        <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
                            <a href="https://vishu26.github.io/">Srikumar Sastry</a>,
                            <a href="https://subash-khanal.github.io/">Subash Khanal</a>,
                            <a href="https://sites.wustl.edu/aayush/">Aayush Dhakal</a>,
                            <a href="https://adealgis.wixsite.com/adeel-ahmad-geog">Adeel Ahmad</a>,
                            <a href="https://jacobsn.github.io/">Nathan Jacobs</a>
                        </h2>
                        <h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>WACV 2025</h2>
                      </div>
                    </div>
                    """
                    )
                # Query image on the left, species predictions on the right.
                with gr.Row():
                    with gr.Column():
                        input_image = gr.Image(sources='upload', type="pil", height=400)
                        run_button = gr.Button(value="Run")
                    with gr.Column():
                        species = gr.Label(label="Species Classification", num_top_classes=5, show_label=True)
                # Retrieved locations on the left, their satellite tiles on the right.
                with gr.Row():
                    with gr.Column():
                        coords = gr.Label(label="Image -> Location Retrieval (Top 5)", num_top_classes=5, show_label=True)
                    with gr.Column():
                        result_gallery = gr.Gallery(label='Image -> Satellite Image Retrieval (Top 5)', elem_id="gallery", object_fit="contain", height="auto", columns=[2], rows=[3])
                # process() returns (species scores, location scores, gallery items),
                # in the same order as the outputs listed here.
                ips = [input_image]
                run_button.click(
                    fn=process,
                    inputs=ips,
                    outputs=[species, coords, result_gallery],
                )

            # NOTE(review): share=True requests a public tunnel link; this is presumably
            # unnecessary when hosted on Hugging Face Spaces - confirm before removing.
            block.launch(share=True)
         
     | 
    	
        coordinates_new.npy
    ADDED
    
    | 
         @@ -0,0 +1,3 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            version https://git-lfs.github.com/spec/v1
         
     | 
| 2 | 
         
            +
            oid sha256:ac50785e868d4a98fb866b331f52d42bab559b976c4ee607e112fec5b6806307
         
     | 
| 3 | 
         
            +
            size 797328
         
     | 
    	
        loc_embeds_new.npy
    ADDED
    
    | 
         @@ -0,0 +1,3 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            version https://git-lfs.github.com/spec/v1
         
     | 
| 2 | 
         
            +
            oid sha256:ec310669f4e077ba614e476c36eff7934359b87a93d4c4f9e4f32e66f0023602
         
     | 
| 3 | 
         
            +
            size 204083328
         
     | 
    	
        requirements.txt
    ADDED
    
    | 
         @@ -0,0 +1,6 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            numpy
         
     | 
| 2 | 
         
            +
            reverse_geocode
         
     | 
| 3 | 
         
            +
            pymap3d
         
     | 
| 4 | 
         
            +
            torch
         
     | 
| 5 | 
         
            +
            open_clip_torch
         
     | 
| 6 | 
         
            +
            torchvision
         
     | 
    	
        sat_embeds_new.npy
    ADDED
    
    | 
         @@ -0,0 +1,3 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            version https://git-lfs.github.com/spec/v1
         
     | 
| 2 | 
         
            +
            oid sha256:065752a12952c2f2fdb641466b2cc6722e3d1213cfa3f7a18c7dd97fe6d034f9
         
     | 
| 3 | 
         
            +
            size 204083328
         
     | 
    	
        txt_emb_species.json
    ADDED
    
    | 
         @@ -0,0 +1,3 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            version https://git-lfs.github.com/spec/v1
         
     | 
| 2 | 
         
            +
            oid sha256:844e6fabc06cac072214d566b78f40825b154efa9479eb11285030ca038b2ece
         
     | 
| 3 | 
         
            +
            size 65731052
         
     | 
    	
        txt_emb_species.npy
    ADDED
    
    | 
         @@ -0,0 +1,3 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            version https://git-lfs.github.com/spec/v1
         
     | 
| 2 | 
         
            +
            oid sha256:91ce02dff2433222e3138b8bf7eefa1dd74b30f4d406c16cd3301f66d65ab4ed
         
     | 
| 3 | 
         
            +
            size 787435648
         
     |