jeremyLE-Ekimetrics commited on
Commit
5c718d1
·
0 Parent(s):

first commit

Browse files
.gitignore ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+
5
+ # C extensions
6
+ *.so
7
+
8
+ # Distribution / packaging
9
+ .Python
10
+ env/
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ *.egg-info/
23
+ .installed.cfg
24
+ *.egg
25
+
26
+ # PyInstaller
27
+ # Usually these files are written by a python script from a template
28
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
29
+ *.manifest
30
+ *.spec
31
+
32
+ # Installer logs
33
+ pip-log.txt
34
+ pip-delete-this-directory.txt
35
+
36
+ # Unit test / coverage reports
37
+ htmlcov/
38
+ .tox/
39
+ .coverage
40
+ .coverage.*
41
+ .cache
42
+ nosetests.xml
43
+ coverage.xml
44
+ *.cover
45
+
46
+ # Translations
47
+ *.mo
48
+ *.pot
49
+
50
+ # Django stuff:
51
+ *.log
52
+
53
+ # Sphinx documentation
54
+ docs/_build/
55
+
56
+ # PyBuilder
57
+ target/
58
+
59
+ # DotEnv configuration
60
+ .env
61
+
62
+ # Database
63
+ *.db
64
+ *.rdb
65
+
66
+ # Pycharm
67
+ .idea
68
+
69
+ # VS Code
70
+ .vscode/
71
+
72
+ # Spyder
73
+ .spyproject/
74
+
75
+ # Jupyter NB Checkpoints
76
+ .ipynb_checkpoints/
77
+
78
+ # Mac OS-specific storage files
79
+ .DS_Store
80
+
81
+ # vim
82
+ *.swp
83
+ *.swo
84
+
85
+ # Mypy cache
86
+ .mypy_cache/
Dockerfile ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.9
2
+ COPY requirements.txt /app/requirements.txt
3
+ WORKDIR /app
4
+ RUN pip install seaborn
5
+ RUN pip install gradio
6
+ RUN pip install datetime
7
+ RUN pip install numpy
8
+ RUN pip install opencv-python
9
+ RUN apt-get update
10
+ RUN apt-get install ffmpeg libsm6 libxext6 -y
11
+ RUN pip install -r requirements.txt
12
+ COPY . /app
13
+ # EXPOSE 7860
14
+ CMD python app.py
15
+ # hello world
README.md ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Welcome to the project inno-satellite-images-segmentation-gan
2
+ ![](docs/assets/banner.png)
3
+
4
+ - **Project name**: inno-satellite-images-segmentation-gan
5
+ - **Library name**: library
6
+ - **Authors**: Ekimetrics
7
+ - **Description**: Segmenting satellite images in a large scale is challenging because grondtruth labels are spurious for medium resolution images (Sentinel 2). We want to improve our algorithm either with data augmentation from a GAN, or to correct or adjust Corine labels.
8
+
9
+
10
+
11
+ ## Project Structure
12
+ ```
13
+ - library/ # Your python library
14
+ - data/
15
+ - raw/
16
+ - processed/
17
+ - docs/
18
+ - tests/ # Where goes each unitary test in your folder
19
+ - scripts/ # Where each automation script will go
20
+ - requirements.txt # Where you should put the libraries version used in your library
21
+ ```
22
+
23
+
24
+ ## Branch strategy
25
+ TBD
26
+
27
+
28
+ ## Ethics checklist
29
+ TBD
30
+
31
+
32
+
33
+ ## Starter package
34
+ This project has been created using the Ekimetrics Python Starter Package to enforce best coding practices, reusability and industrialization. <br>
35
+ If you have any questions please reach out to the inno team and [Théo Alves Da Costa](mailto:[email protected])
36
+
37
+
38
+
39
+
40
+
biomap/.gitignore ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ #wsl
2
+ *.Zone.Identifier
3
+
4
+ #python
5
+ *__pycache__
6
+
biomap/.private-key.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "type": "service_account",
3
+ "project_id": "cvimg-377115",
4
+ "private_key_id": "a162152bd26f4bcc287c44b130109892b5517875",
5
+ "private_key": "-----BEGIN PRIVATE KEY-----\nMIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDCr/zwOTwyVVdF\nk1cKcdju9jeDWceVJ/zj73b5IS8IbQJbFGWSjaL6Ft1HfJ3rSdGbj+Xy26jY9OFJ\n4rIhpY0M0cCWpYo+gqS9p4JAL6lHqZvSnkRTglpx9QYOT8o9ibCWhMVVAPH71QZ/\n4BEfGC6s2+zdEn+cbCkGIqLLhZTq655kDOaGSycwV/bk+TOLI/An4gjMoEIimhsD\nS6TRmqZnQGoI6m6aj3xPZGVMkid3I37h+BOC64YjKeXnpAhTQ4LbpQz1BxvIKDt6\ncJ1FBOmdSvUC+dLEGMN3yijpJZ74nXUnSVbwYYt3K8Kz2PtmgDYvswER53NlrMoW\n1AF9ImDFAgMBAAECggEAASO7IXWBo90d6CiEdhQF9uFy2S5d9ol9N6EBl+JHOuwB\nwne4tSZJ3jT/Rus7wvX67tXI9dgf3fIgjsv92NLnwEn1Wq/xQppMm8iyK1DKU3vH\n8xrvf8iG048ojGqQXqf0ZEUoWd+/YDGZ2qNuZjmVgwZKwF2h2pcnQ25uIvWdYHrb\n3XhYLDAROVTTtyscYcl8UKmAZ35moVVBQxdakGunYg6o/s6rESRbc+gCyqHR5v+r\nCl3Z4XEKDdukIVI72Ybk0F8eZpQztN97uzK/zm9jl4NmAPXrnWLEwuJdwdm1cWUF\n/LTTuNPmRzCm7IGUpkx0AKEs6s0BRbJbwlZaj4QVJwKBgQDjb2rSO6kRLRHyv+4w\ny/OLmqOrMY7fpSCj0mH41GhiZUaoZqhDznmuhqEjo1kVipfuW94Fn5NsGmWpbmEC\nJlObUEg1umX/ceOJrtRdY3AQMSQXR6u7oc2mTgj3Opd0V1L1Lopj4Ijj43ARg/fU\nu4RnrCGHcXXzT2LCchY0ZhLg3wKBgQDbI6bzt/RNW8+IGKCvLLi41bxM/9r83GNO\nQI4a6yTT09N11okjP9h00JKYBgU3fYivA1aBloFB4kOYaBzomfWSZHEyyFWCr9y0\ndGyIDbfUaI/jFx2CaKomLnPDF5LA3IWHAsTRZ/c1JGhiOUseEq/TR0cJAo69kgf0\nkVmoGjo+2wKBgQCo7crkGJg9P8LDEbgz2mktWlETCR5cE2SpCczna62U2DChSI7W\nvng3H5x0whGbJHQxAV9pwdtYQksci/XWCO20wO7BqY+1KrydOZRXQVKtVDLAb+Wo\n2kfLrM6QA58XNP1TS5xTDyXeTsKg3+qmwhlYf8vvtGCttltenirMBL0k9QKBgFpL\nanNqDOQDPJQbcbo8dzDSAPDJS/Z86P5JY0R8N4SA99TKPV+k4w/fEUhK0sN2mmdi\nvLZQyZnYHXojDCZbqfBUKsB+A54B0LMadc3puSFwpDkyQRqG/fUVluWARRvqwapL\n3cVbTWU8RzaR3P3bPU+VQxPXVfGOxnBjo8m8ZNuZAoGBANTC20T9rZ9Won9FCbi3\nSMkGY59smx19CdytQ2rjGFOEeAVMVotP5viXFuKfv5g2E/kvyJUjuOoulA3dxddN\nQzXnOIT3dlyBmvXkHJHUIKiidyuX4JqQFdPTAmkt6KaTceRNb7VN1OqIk1AJ1SDb\nkGxerLg4WuGfSqOIV0Wk4cLI\n-----END PRIVATE KEY-----\n",
6
+ "client_email": "[email protected]",
7
+ "client_id": "115144831673857322488",
8
+ "auth_uri": "https://accounts.google.com/o/oauth2/auth",
9
+ "token_uri": "https://oauth2.googleapis.com/token",
10
+ "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
11
+ "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/cvimg-355%40cvimg-377115.iam.gserviceaccount.com"
12
+ }
biomap/app.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from plot_functions import *
2
+ import hydra
3
+
4
+ import torch
5
+ from model import LitUnsupervisedSegmenter
6
+ from helper import inference_on_location_and_month, inference_on_location
7
+ from plot_functions import segment_region
8
+
9
+ from functools import partial
10
+ import gradio as gr
11
+ import logging
12
+
13
+ import geopandas as gpd
14
+ mapbox_access_token = "pk.eyJ1IjoiamVyZW15LWVraW1ldHJpY3MiLCJhIjoiY2xrNjBwNGU2MDRhMjNqbWw0YTJrbnpvNCJ9.poVyIzhJuJmD6ffrL9lm2w"
15
+ geo_df = gpd.read_file(gpd.datasets.get_path('naturalearth_cities'))
16
+
17
+ def get_geomap(long, lat ):
18
+
19
+
20
+ fig = go.Figure(go.Scattermapbox(
21
+ lat=geo_df.geometry.y,
22
+ lon=geo_df.geometry.x,
23
+ mode='markers',
24
+ marker=go.scattermapbox.Marker(
25
+ size=14
26
+ ),
27
+ text=geo_df.name,
28
+ ))
29
+
30
+ fig.add_trace(go.Scattermapbox(lat=[lat],
31
+ lon=[long],
32
+ mode='markers',
33
+ marker=go.scattermapbox.Marker(
34
+ size=14
35
+ ),
36
+ marker_color="green",
37
+ text=['Actual position']))
38
+
39
+ fig.update_layout(
40
+ showlegend=False,
41
+ hovermode='closest',
42
+ mapbox=dict(
43
+ accesstoken=mapbox_access_token,
44
+ center=go.layout.mapbox.Center(
45
+ lat=lat,
46
+ lon=long
47
+ ),
48
+ zoom=3
49
+ )
50
+ )
51
+
52
+ return fig
53
+
54
+
55
+ if __name__ == "__main__":
56
+
57
+
58
+ logging.basicConfig(filename='example.log', encoding='utf-8', level=logging.INFO)
59
+ # Initialize hydra with configs
60
+ #hydra.initialize(config_path="configs", job_name="corine")
61
+ cfg = hydra.compose(config_name="my_train_config.yml")
62
+ logging.info(f"config : {cfg}")
63
+ # Load the model
64
+
65
+ nbclasses = cfg.dir_dataset_n_classes
66
+ model = LitUnsupervisedSegmenter(nbclasses, cfg)
67
+ logging.info(f"Model Initialiazed")
68
+
69
+ model_path = "checkpoint/model/model.pt"
70
+ saved_state_dict = torch.load(model_path, map_location=torch.device("cpu"))
71
+ logging.info(f"Model weights Loaded")
72
+ model.load_state_dict(saved_state_dict)
73
+ logging.info(f"Model Loaded")
74
+ # css=".VIDEO video{height: 100%;width:50%;margin:auto};.VIDEO{height: 50%;};.svelte-1vnmhm4{height:auto}"
75
+ with gr.Blocks() as demo:
76
+ gr.Markdown("Estimate Biodiversity in the world.")
77
+ with gr.Tab("Single Image"):
78
+ with gr.Row():
79
+ input_map = gr.Plot().style()
80
+ with gr.Column():
81
+ input_latitude = gr.Number(label="lattitude", value=2.98)
82
+ input_longitude = gr.Number(label="longitude", value=48.81)
83
+ input_date = gr.Textbox(label="start_date", value="2020-03-20")
84
+
85
+ single_button = gr.Button("Predict")
86
+ with gr.Row():
87
+ raw_image = gr.Image(label = "Localisation visualization")
88
+ output_image = gr.Image(label = "Labeled visualisation")
89
+ score_biodiv = gr.Number(label = "Biodiversity score")
90
+
91
+ with gr.Tab("TimeLapse"):
92
+ with gr.Row():
93
+ input_map_2 = gr.Plot().style()
94
+ with gr.Row():
95
+ timelapse_input_latitude = gr.Number(value=2.98, label="Latitude")
96
+ timelapse_input_longitude = gr.Number(value=48.81, label="Longitude")
97
+ timelapse_start_date = gr.Textbox(value='2020-05-01', label="Start Date")
98
+ timelapse_end_date = gr.Textbox(value='2020-06-30', label="End Date")
99
+ segmentation = gr.CheckboxGroup(choices=['month', 'year', '2months'], value=['month'], label="Select Segmentation Level:")
100
+ timelapse_button = gr.Button(value="Predict")
101
+ map = gr.Plot().style()
102
+
103
+ demo.load(get_geomap, [input_latitude, input_longitude], input_map)
104
+ single_button.click(get_geomap, [input_latitude, input_longitude], input_map)
105
+ single_button.click(partial(inference_on_location_and_month, model), inputs=[input_latitude, input_longitude, input_date], outputs=[raw_image, output_image,score_biodiv])
106
+
107
+ demo.load(get_geomap, [timelapse_input_latitude, timelapse_input_longitude], input_map_2)
108
+ timelapse_button.click(get_geomap, [timelapse_input_latitude, timelapse_input_longitude], input_map_2)
109
+ timelapse_button.click(segment_region, inputs=[timelapse_input_latitude, timelapse_input_longitude, timelapse_start_date, timelapse_end_date,segmentation], outputs=[map])
110
+ demo.launch(share=True)
biomap/configs/my_train_config.yml ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ output_root: '../'
2
+ pytorch_data_dir: '/home/duong_nguyen/pytorch-data'
3
+ experiment_name: "unet_7classes"
4
+ log_dir: "france"
5
+ # experiment_name: "unet"
6
+ # log_dir: "potsdam"
7
+ azureml_logging: False
8
+ submitting_to_aml: False
9
+ full_name: ~
10
+
11
+ # Loader params
12
+ num_workers: 24
13
+ max_steps: 80000
14
+ batch_size: 16
15
+
16
+ num_neighbors: 7
17
+ dataset_name: "directory"
18
+ # dataset_name: "potsdam"
19
+
20
+ # Used if dataset_name is "directory"
21
+ dir_dataset_name: "corine"
22
+ dir_dataset_n_classes: 7
23
+
24
+ has_labels: False
25
+ # crop_type: "five"
26
+ crop_type: ~
27
+ crop_ratio: .5
28
+ res: 224
29
+ loader_crop_type: "center"
30
+
31
+ # Model Params
32
+ extra_clusters: 0
33
+ use_true_labels: False
34
+ use_recalibrator: False
35
+ model_type: "vit_small"
36
+ arch: "dino"
37
+ use_fit_model: False
38
+ dino_feat_type: "feat"
39
+ projection_type: "nonlinear"
40
+ #projection_type: linear
41
+ dino_patch_size: 8
42
+ granularity: 1
43
+ continuous: True
44
+ dim: 70
45
+ dropout: True
46
+ zero_clamp: True
47
+
48
+ lr: 5e-4
49
+ pretrained_weights: ~
50
+ use_salience: False
51
+ stabalize: False
52
+ stop_at_zero: True
53
+
54
+ # Feature Contrastive params
55
+ pointwise: True
56
+ feature_samples: 11
57
+ neg_samples: 5
58
+ aug_alignment_weight: 0.0
59
+
60
+ correspondence_weight: 1.0
61
+
62
+
63
+ # # Corine vit small 24/11/22
64
+ neg_inter_weight: 0.63
65
+ pos_inter_weight: 0.25
66
+ pos_intra_weight: 0.67
67
+ neg_inter_shift: 0.46
68
+ pos_inter_shift: 0.02
69
+ pos_intra_shift: 0.08
70
+
71
+ # # Corine vit small 11/09/22
72
+ # neg_inter_weight: 0.63
73
+ # pos_inter_weight: 0.25
74
+ # pos_intra_weight: 0.67
75
+ # neg_inter_shift: 0.46
76
+ # pos_inter_shift: 0.24
77
+ # pos_intra_shift: 0.36
78
+
79
+ # # IAROA vit small 1/31/22
80
+ # neg_inter_weight: 0.63
81
+ # pos_inter_weight: 0.25
82
+ # pos_intra_weight: 0.67
83
+ # neg_inter_shift: 0.46
84
+ # pos_inter_shift: 0.12
85
+ # pos_intra_shift: 0.18
86
+
87
+ # Potsdam vit small 1/31/22
88
+ # neg_inter_weight: 0.63
89
+ # pos_inter_weight: 0.25
90
+ # pos_intra_weight: 0.67
91
+ # neg_inter_shift: 0.46
92
+ # pos_inter_shift: 0.02
93
+ # pos_intra_shift: 0.08
94
+
95
+ # Cocostuff27 vit small 1/31/22
96
+ #neg_inter_weight: 0.63
97
+ #pos_inter_weight: 0.25
98
+ #pos_intra_weight: 0.67
99
+ #neg_inter_shift: 0.66
100
+ #pos_inter_shift: 0.02
101
+ #pos_intra_shift: 0.08
102
+
103
+
104
+ ## Cocostuff27 10/3 vit_base
105
+
106
+ #neg_inter_weight: 0.1538476246415498
107
+ #pos_inter_weight: 1
108
+ #pos_intra_weight: 0.1
109
+ #
110
+ #neg_inter_shift: 1
111
+ #pos_inter_shift: 0.2
112
+ #pos_intra_shift: 0.12
113
+
114
+
115
+ ## Cocostuff27 10/3 vit_small
116
+ #neg_inter_weight: .63
117
+ #pos_inter_weight: .25
118
+ #pos_intra_weight: .67
119
+ #
120
+ #neg_inter_shift: .16
121
+ #pos_inter_shift: .02
122
+ #pos_intra_shift: .08
123
+
124
+
125
+
126
+ ## Cocostuff27 10/3 moco
127
+ #neg_inter_weight: .63
128
+ #pos_inter_weight: .25
129
+ #pos_intra_weight: .67
130
+ #
131
+ #neg_inter_shift: .26
132
+ #pos_inter_shift: .36
133
+ #pos_intra_shift: .32
134
+
135
+ #pos_inter_shift: .12
136
+ #pos_intra_shift: .18
137
+
138
+ ## Cocostuff27
139
+ #neg_inter_weight: .72
140
+ #pos_inter_weight: .80
141
+ #pos_intra_weight: .29
142
+ #
143
+ #neg_inter_shift: .86
144
+ #pos_inter_shift: .04
145
+ #pos_intra_shift: .34
146
+
147
+ # Cityscapes 10/3
148
+
149
+ # neg_inter_weight: 0.9058762625226623
150
+ # pos_inter_weight: 0.577453483136995
151
+ # pos_intra_weight: 1
152
+
153
+ # neg_inter_shift: 0.31361241889448443
154
+ # pos_inter_shift: 0.1754346515479633
155
+ # pos_intra_shift: 0.45828472207
156
+
157
+
158
+ # Cityscapes
159
+ #neg_inter_weight: .72
160
+ #pos_inter_weight: .18
161
+ #pos_intra_weight: .46
162
+ #
163
+ #neg_inter_shift: .25
164
+ #pos_inter_shift: .20
165
+ #pos_intra_shift: .25
166
+
167
+
168
+ rec_weight: 0.0
169
+ repulsion_weight: 0.0
170
+
171
+ # CRF Params
172
+ crf_weight: 0.0
173
+ alpha: .5
174
+ beta: .15
175
+ gamma: .05
176
+ w1: 10.0
177
+ w2: 3.0
178
+ shift: 0.00
179
+ crf_samples: 1000
180
+ color_space: "rgb"
181
+
182
+ reset_probe_steps: ~
183
+
184
+ # Logging params
185
+ n_images: 5
186
+ scalar_log_freq: 10
187
+ checkpoint_freq: 50
188
+ val_freq: 100
189
+ hist_freq: 100
190
+
191
+
192
+ hydra:
193
+ run:
194
+ dir: "."
195
+ output_subdir: ~
196
+ #job_logging: "disabled"
197
+ #hydra_logging: "disabled"
biomap/data.py ADDED
@@ -0,0 +1,584 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ from os.path import join
4
+
5
+ import numpy as np
6
+ import torch.multiprocessing
7
+ from PIL import Image
8
+ from scipy.io import loadmat
9
+ from torch.utils.data import DataLoader
10
+ from torch.utils.data import Dataset
11
+ from torchvision.datasets.cityscapes import Cityscapes
12
+ from torchvision.transforms.functional import to_pil_image
13
+ from tqdm import tqdm
14
+
15
+
16
+ def bit_get(val, idx):
17
+ """Gets the bit value.
18
+ Args:
19
+ val: Input value, int or numpy int array.
20
+ idx: Which bit of the input val.
21
+ Returns:
22
+ The "idx"-th bit of input val.
23
+ """
24
+ return (val >> idx) & 1
25
+
26
+
27
+ def create_pascal_label_colormap():
28
+ """Creates a label colormap used in PASCAL VOC segmentation benchmark.
29
+ Returns:
30
+ A colormap for visualizing segmentation results.
31
+ """
32
+ colormap = np.zeros((512, 3), dtype=int)
33
+ ind = np.arange(512, dtype=int)
34
+
35
+ for shift in reversed(list(range(8))):
36
+ for channel in range(3):
37
+ colormap[:, channel] |= bit_get(ind, channel) << shift
38
+ ind >>= 3
39
+
40
+ return colormap
41
+
42
+
43
+ def create_cityscapes_colormap():
44
+ colors = [(128, 64, 128),
45
+ (244, 35, 232),
46
+ (250, 170, 160),
47
+ (230, 150, 140),
48
+ (70, 70, 70),
49
+ (102, 102, 156),
50
+ (190, 153, 153),
51
+ (180, 165, 180),
52
+ (150, 100, 100),
53
+ (150, 120, 90),
54
+ (153, 153, 153),
55
+ (153, 153, 153),
56
+ (250, 170, 30),
57
+ (220, 220, 0),
58
+ (107, 142, 35),
59
+ (152, 251, 152),
60
+ (70, 130, 180),
61
+ (220, 20, 60),
62
+ (255, 0, 0),
63
+ (0, 0, 142),
64
+ (0, 0, 70),
65
+ (0, 60, 100),
66
+ (0, 0, 90),
67
+ (0, 0, 110),
68
+ (0, 80, 100),
69
+ (0, 0, 230),
70
+ (119, 11, 32),
71
+ (0, 0, 0)]
72
+ return np.array(colors)
73
+
74
+
75
+ class DirectoryDataset(Dataset):
76
+ def __init__(self, root, path, image_set, transform, target_transform):
77
+ super(DirectoryDataset, self).__init__()
78
+ self.split = image_set
79
+ self.dir = join(root, path)
80
+ self.img_dir = join(self.dir, "imgs", self.split)
81
+ self.label_dir = join(self.dir, "labels", self.split)
82
+
83
+ self.transform = transform
84
+ self.target_transform = target_transform
85
+
86
+ self.img_files = np.array(sorted(os.listdir(self.img_dir)))
87
+ assert len(self.img_files) > 0
88
+ if os.path.exists(join(self.dir, "labels")):
89
+ self.label_files = np.array(sorted(os.listdir(self.label_dir)))
90
+ assert len(self.img_files) == len(self.label_files)
91
+ else:
92
+ self.label_files = None
93
+ self.fine_to_coarse = {0: 0,
94
+ 1: 1,
95
+ 2: 2,
96
+ 3: 3,
97
+ 4: 4,
98
+ 5: 5,
99
+ 6: 6,
100
+ 7: -1,
101
+ }
102
+
103
+ def __getitem__(self, index):
104
+ image_fn = self.img_files[index]
105
+ img = Image.open(join(self.img_dir, image_fn))
106
+
107
+ if self.label_files is not None:
108
+ label_fn = self.label_files[index]
109
+ label = Image.open(join(self.label_dir, label_fn))
110
+
111
+ seed = np.random.randint(2147483647)
112
+ random.seed(seed)
113
+ torch.manual_seed(seed)
114
+ img = self.transform(img)
115
+
116
+ if self.label_files is not None:
117
+ random.seed(seed)
118
+ torch.manual_seed(seed)
119
+ label = self.target_transform(label)
120
+ new_label_map = torch.zeros_like(label)
121
+ for fine, coarse in self.fine_to_coarse.items():
122
+ new_label_map[label == fine] = coarse
123
+ label = new_label_map
124
+ else:
125
+ label = torch.zeros(img.shape[1], img.shape[2], dtype=torch.int64) - 1
126
+
127
+ mask = (label > 0).to(torch.float32)
128
+ return img, label, mask
129
+
130
+
131
+ def __len__(self):
132
+ return len(self.img_files)
133
+
134
+
135
+ class Potsdam(Dataset):
136
+ def __init__(self, root, image_set, transform, target_transform, coarse_labels):
137
+ super(Potsdam, self).__init__()
138
+ self.split = image_set
139
+ self.root = os.path.join(root, "potsdam")
140
+ self.transform = transform
141
+ self.target_transform = target_transform
142
+ split_files = {
143
+ "train": ["labelled_train.txt"],
144
+ "unlabelled_train": ["unlabelled_train.txt"],
145
+ # "train": ["unlabelled_train.txt"],
146
+ "val": ["labelled_test.txt"],
147
+ "train+val": ["labelled_train.txt", "labelled_test.txt"],
148
+ "all": ["all.txt"]
149
+ }
150
+ assert self.split in split_files.keys()
151
+
152
+ self.files = []
153
+ for split_file in split_files[self.split]:
154
+ with open(join(self.root, split_file), "r") as f:
155
+ self.files.extend(fn.rstrip() for fn in f.readlines())
156
+
157
+ self.coarse_labels = coarse_labels
158
+ self.fine_to_coarse = {0: 0, 4: 0, # roads and cars
159
+ 1: 1, 5: 1, # buildings and clutter
160
+ 2: 2, 3: 2, # vegetation and trees
161
+ 255: -1
162
+ }
163
+
164
+ def __getitem__(self, index):
165
+ image_id = self.files[index]
166
+ img = loadmat(join(self.root, "imgs", image_id + ".mat"))["img"]
167
+ img = to_pil_image(torch.from_numpy(img).permute(2, 0, 1)[:3]) # TODO add ir channel back
168
+ try:
169
+ label = loadmat(join(self.root, "gt", image_id + ".mat"))["gt"]
170
+ label = to_pil_image(torch.from_numpy(label).unsqueeze(-1).permute(2, 0, 1))
171
+ except FileNotFoundError:
172
+ label = to_pil_image(torch.ones(1, img.height, img.width))
173
+
174
+ seed = np.random.randint(2147483647)
175
+ random.seed(seed)
176
+ torch.manual_seed(seed)
177
+ img = self.transform(img)
178
+
179
+ random.seed(seed)
180
+ torch.manual_seed(seed)
181
+ label = self.target_transform(label).squeeze(0)
182
+ if self.coarse_labels:
183
+ new_label_map = torch.zeros_like(label)
184
+ for fine, coarse in self.fine_to_coarse.items():
185
+ new_label_map[label == fine] = coarse
186
+ label = new_label_map
187
+
188
+ mask = (label > 0).to(torch.float32)
189
+ return img, label, mask
190
+
191
+ def __len__(self):
192
+ return len(self.files)
193
+
194
+
195
+ class PotsdamRaw(Dataset):
196
+ def __init__(self, root, image_set, transform, target_transform, coarse_labels):
197
+ super(PotsdamRaw, self).__init__()
198
+ self.split = image_set
199
+ self.root = os.path.join(root, "potsdamraw", "processed")
200
+ self.transform = transform
201
+ self.target_transform = target_transform
202
+ self.files = []
203
+ for im_num in range(38):
204
+ for i_h in range(15):
205
+ for i_w in range(15):
206
+ self.files.append("{}_{}_{}.mat".format(im_num, i_h, i_w))
207
+
208
+ self.coarse_labels = coarse_labels
209
+ self.fine_to_coarse = {0: 0, 4: 0, # roads and cars
210
+ 1: 1, 5: 1, # buildings and clutter
211
+ 2: 2, 3: 2, # vegetation and trees
212
+ 255: -1
213
+ }
214
+
215
+ def __getitem__(self, index):
216
+ image_id = self.files[index]
217
+ img = loadmat(join(self.root, "imgs", image_id))["img"]
218
+ img = to_pil_image(torch.from_numpy(img).permute(2, 0, 1)[:3]) # TODO add ir channel back
219
+ try:
220
+ label = loadmat(join(self.root, "gt", image_id))["gt"]
221
+ label = to_pil_image(torch.from_numpy(label).unsqueeze(-1).permute(2, 0, 1))
222
+ except FileNotFoundError:
223
+ label = to_pil_image(torch.ones(1, img.height, img.width))
224
+
225
+ seed = np.random.randint(2147483647)
226
+ random.seed(seed)
227
+ torch.manual_seed(seed)
228
+ img = self.transform(img)
229
+
230
+ random.seed(seed)
231
+ torch.manual_seed(seed)
232
+ label = self.target_transform(label).squeeze(0)
233
+ if self.coarse_labels:
234
+ new_label_map = torch.zeros_like(label)
235
+ for fine, coarse in self.fine_to_coarse.items():
236
+ new_label_map[label == fine] = coarse
237
+ label = new_label_map
238
+
239
+ mask = (label > 0).to(torch.float32)
240
+ return img, label, mask
241
+
242
+ def __len__(self):
243
+ return len(self.files)
244
+
245
+
246
+ class Coco(Dataset):
247
+ def __init__(self, root, image_set, transform, target_transform,
248
+ coarse_labels, exclude_things, subset=None):
249
+ super(Coco, self).__init__()
250
+ self.split = image_set
251
+ self.root = join(root, "cocostuff")
252
+ self.coarse_labels = coarse_labels
253
+ self.transform = transform
254
+ self.label_transform = target_transform
255
+ self.subset = subset
256
+ self.exclude_things = exclude_things
257
+
258
+ if self.subset is None:
259
+ self.image_list = "Coco164kFull_Stuff_Coarse.txt"
260
+ elif self.subset == 6: # IIC Coarse
261
+ self.image_list = "Coco164kFew_Stuff_6.txt"
262
+ elif self.subset == 7: # IIC Fine
263
+ self.image_list = "Coco164kFull_Stuff_Coarse_7.txt"
264
+
265
+ assert self.split in ["train", "val", "train+val"]
266
+ split_dirs = {
267
+ "train": ["train2017"],
268
+ "val": ["val2017"],
269
+ "train+val": ["train2017", "val2017"]
270
+ }
271
+
272
+ self.image_files = []
273
+ self.label_files = []
274
+ for split_dir in split_dirs[self.split]:
275
+ with open(join(self.root, "curated", split_dir, self.image_list), "r") as f:
276
+ img_ids = [fn.rstrip() for fn in f.readlines()]
277
+ for img_id in img_ids:
278
+ self.image_files.append(join(self.root, "images", split_dir, img_id + ".jpg"))
279
+ self.label_files.append(join(self.root, "annotations", split_dir, img_id + ".png"))
280
+
281
+ self.fine_to_coarse = {0: 9, 1: 11, 2: 11, 3: 11, 4: 11, 5: 11, 6: 11, 7: 11, 8: 11, 9: 8, 10: 8, 11: 8, 12: 8,
282
+ 13: 8, 14: 8, 15: 7, 16: 7, 17: 7, 18: 7, 19: 7, 20: 7, 21: 7, 22: 7, 23: 7, 24: 7,
283
+ 25: 6, 26: 6, 27: 6, 28: 6, 29: 6, 30: 6, 31: 6, 32: 6, 33: 10, 34: 10, 35: 10, 36: 10,
284
+ 37: 10, 38: 10, 39: 10, 40: 10, 41: 10, 42: 10, 43: 5, 44: 5, 45: 5, 46: 5, 47: 5, 48: 5,
285
+ 49: 5, 50: 5, 51: 2, 52: 2, 53: 2, 54: 2, 55: 2, 56: 2, 57: 2, 58: 2, 59: 2, 60: 2,
286
+ 61: 3, 62: 3, 63: 3, 64: 3, 65: 3, 66: 3, 67: 3, 68: 3, 69: 3, 70: 3, 71: 0, 72: 0,
287
+ 73: 0, 74: 0, 75: 0, 76: 0, 77: 1, 78: 1, 79: 1, 80: 1, 81: 1, 82: 1, 83: 4, 84: 4,
288
+ 85: 4, 86: 4, 87: 4, 88: 4, 89: 4, 90: 4, 91: 17, 92: 17, 93: 22, 94: 20, 95: 20, 96: 22,
289
+ 97: 15, 98: 25, 99: 16, 100: 13, 101: 12, 102: 12, 103: 17, 104: 17, 105: 23, 106: 15,
290
+ 107: 15, 108: 17, 109: 15, 110: 21, 111: 15, 112: 25, 113: 13, 114: 13, 115: 13, 116: 13,
291
+ 117: 13, 118: 22, 119: 26, 120: 14, 121: 14, 122: 15, 123: 22, 124: 21, 125: 21, 126: 24,
292
+ 127: 20, 128: 22, 129: 15, 130: 17, 131: 16, 132: 15, 133: 22, 134: 24, 135: 21, 136: 17,
293
+ 137: 25, 138: 16, 139: 21, 140: 17, 141: 22, 142: 16, 143: 21, 144: 21, 145: 25, 146: 21,
294
+ 147: 26, 148: 21, 149: 24, 150: 20, 151: 17, 152: 14, 153: 21, 154: 26, 155: 15, 156: 23,
295
+ 157: 20, 158: 21, 159: 24, 160: 15, 161: 24, 162: 22, 163: 25, 164: 15, 165: 20, 166: 17,
296
+ 167: 17, 168: 22, 169: 14, 170: 18, 171: 18, 172: 18, 173: 18, 174: 18, 175: 18, 176: 18,
297
+ 177: 26, 178: 26, 179: 19, 180: 19, 181: 24}
298
+
299
+ self._label_names = [
300
+ "ground-stuff",
301
+ "plant-stuff",
302
+ "sky-stuff",
303
+ ]
304
+ self.cocostuff3_coarse_classes = [23, 22, 21]
305
+ self.first_stuff_index = 12
306
+
307
+ def __getitem__(self, index):
308
+ image_path = self.image_files[index]
309
+ label_path = self.label_files[index]
310
+ seed = np.random.randint(2147483647)
311
+ random.seed(seed)
312
+ torch.manual_seed(seed)
313
+ img = self.transform(Image.open(image_path).convert("RGB"))
314
+
315
+ random.seed(seed)
316
+ torch.manual_seed(seed)
317
+ label = self.label_transform(Image.open(label_path)).squeeze(0)
318
+ label[label == 255] = -1 # to be consistent with 10k
319
+ coarse_label = torch.zeros_like(label)
320
+ for fine, coarse in self.fine_to_coarse.items():
321
+ coarse_label[label == fine] = coarse
322
+ coarse_label[label == -1] = -1
323
+
324
+ if self.coarse_labels:
325
+ coarser_labels = -torch.ones_like(label)
326
+ for i, c in enumerate(self.cocostuff3_coarse_classes):
327
+ coarser_labels[coarse_label == c] = i
328
+ return img, coarser_labels, coarser_labels >= 0
329
+ else:
330
+ if self.exclude_things:
331
+ return img, coarse_label - self.first_stuff_index, (coarse_label >= self.first_stuff_index)
332
+ else:
333
+ return img, coarse_label, coarse_label >= 0
334
+
335
+ def __len__(self):
336
+ return len(self.image_files)
337
+
338
+
339
+ class CityscapesSeg(Dataset):
340
+ def __init__(self, root, image_set, transform, target_transform):
341
+ super(CityscapesSeg, self).__init__()
342
+ self.split = image_set
343
+ self.root = join(root, "cityscapes")
344
+ if image_set == "train":
345
+ # our_image_set = "train_extra"
346
+ # mode = "coarse"
347
+ our_image_set = "train"
348
+ mode = "fine"
349
+ else:
350
+ our_image_set = image_set
351
+ mode = "fine"
352
+ self.inner_loader = Cityscapes(self.root, our_image_set,
353
+ mode=mode,
354
+ target_type="semantic",
355
+ transform=None,
356
+ target_transform=None)
357
+ self.transform = transform
358
+ self.target_transform = target_transform
359
+ self.first_nonvoid = 7
360
+
361
+ def __getitem__(self, index):
362
+ if self.transform is not None:
363
+ image, target = self.inner_loader[index]
364
+
365
+ seed = np.random.randint(2147483647)
366
+ random.seed(seed)
367
+ torch.manual_seed(seed)
368
+ image = self.transform(image)
369
+ random.seed(seed)
370
+ torch.manual_seed(seed)
371
+ target = self.target_transform(target)
372
+
373
+ target = target - self.first_nonvoid
374
+ target[target < 0] = -1
375
+ mask = target == -1
376
+ return image, target.squeeze(0), mask
377
+ else:
378
+ return self.inner_loader[index]
379
+
380
+ def __len__(self):
381
+ return len(self.inner_loader)
382
+
383
+
384
+ class CroppedDataset(Dataset):
385
+ def __init__(self, root, dataset_name, crop_type, crop_ratio, image_set, transform, target_transform):
386
+ super(CroppedDataset, self).__init__()
387
+ self.dataset_name = dataset_name
388
+ self.split = image_set
389
+ self.root = join(root, "cropped", "{}_{}_crop_{}".format(dataset_name, crop_type, crop_ratio))
390
+ self.transform = transform
391
+ self.target_transform = target_transform
392
+ self.img_dir = join(self.root, "img", self.split)
393
+ self.label_dir = join(self.root, "label", self.split)
394
+ self.num_images = len(os.listdir(self.img_dir))
395
+ assert self.num_images == len(os.listdir(self.label_dir))
396
+
397
+ def __getitem__(self, index):
398
+ image = Image.open(join(self.img_dir, "{}.jpg".format(index))).convert('RGB')
399
+ target = Image.open(join(self.label_dir, "{}.png".format(index)))
400
+
401
+ seed = np.random.randint(2147483647)
402
+ random.seed(seed)
403
+ torch.manual_seed(seed)
404
+ image = self.transform(image)
405
+ random.seed(seed)
406
+ torch.manual_seed(seed)
407
+ target = self.target_transform(target)
408
+
409
+ target = target - 1
410
+ mask = target == -1
411
+ return image, target.squeeze(0), mask
412
+
413
+ def __len__(self):
414
+ return self.num_images
415
+
416
+
417
+ class MaterializedDataset(Dataset):
418
+
419
+ def __init__(self, ds):
420
+ self.ds = ds
421
+ self.materialized = []
422
+ loader = DataLoader(ds, num_workers=12, collate_fn=lambda l: l[0])
423
+ for batch in tqdm(loader):
424
+ self.materialized.append(batch)
425
+
426
+ def __len__(self):
427
+ return len(self.ds)
428
+
429
+ def __getitem__(self, ind):
430
+ return self.materialized[ind]
431
+
432
+
433
+ class ContrastiveSegDataset(Dataset):
434
+ def __init__(self,
435
+ pytorch_data_dir,
436
+ dataset_name,
437
+ crop_type,
438
+ image_set,
439
+ transform,
440
+ target_transform,
441
+ cfg,
442
+ aug_geometric_transform=None,
443
+ aug_photometric_transform=None,
444
+ num_neighbors=5,
445
+ compute_knns=False,
446
+ mask=False,
447
+ pos_labels=False,
448
+ pos_images=False,
449
+ extra_transform=None,
450
+ model_type_override=None
451
+ ):
452
+ super(ContrastiveSegDataset).__init__()
453
+ self.num_neighbors = num_neighbors
454
+ self.image_set = image_set
455
+ self.dataset_name = dataset_name
456
+ self.mask = mask
457
+ self.pos_labels = pos_labels
458
+ self.pos_images = pos_images
459
+ self.extra_transform = extra_transform
460
+
461
+ if dataset_name == "potsdam":
462
+ self.n_classes = 3
463
+ dataset_class = Potsdam
464
+ extra_args = dict(coarse_labels=True)
465
+ elif dataset_name == "potsdamraw":
466
+ self.n_classes = 3
467
+ dataset_class = PotsdamRaw
468
+ extra_args = dict(coarse_labels=True)
469
+ elif dataset_name == "directory":
470
+ self.n_classes = cfg.dir_dataset_n_classes
471
+ dataset_class = DirectoryDataset
472
+ extra_args = dict(path=cfg.dir_dataset_name)
473
+ elif dataset_name == "cityscapes" and crop_type is None:
474
+ self.n_classes = 27
475
+ dataset_class = CityscapesSeg
476
+ extra_args = dict()
477
+ elif dataset_name == "cityscapes" and crop_type is not None:
478
+ self.n_classes = 27
479
+ dataset_class = CroppedDataset
480
+ extra_args = dict(dataset_name="cityscapes", crop_type=crop_type, crop_ratio=cfg.crop_ratio)
481
+ elif dataset_name == "cocostuff3":
482
+ self.n_classes = 3
483
+ dataset_class = Coco
484
+ extra_args = dict(coarse_labels=True, subset=6, exclude_things=True)
485
+ elif dataset_name == "cocostuff15":
486
+ self.n_classes = 15
487
+ dataset_class = Coco
488
+ extra_args = dict(coarse_labels=False, subset=7, exclude_things=True)
489
+ elif dataset_name == "cocostuff27" and crop_type is not None:
490
+ self.n_classes = 27
491
+ dataset_class = CroppedDataset
492
+ extra_args = dict(dataset_name="cocostuff27", crop_type=cfg.crop_type, crop_ratio=cfg.crop_ratio)
493
+ elif dataset_name == "cocostuff27" and crop_type is None:
494
+ self.n_classes = 27
495
+ dataset_class = Coco
496
+ extra_args = dict(coarse_labels=False, subset=None, exclude_things=False)
497
+ if image_set == "val":
498
+ extra_args["subset"] = 7
499
+ else:
500
+ raise ValueError("Unknown dataset: {}".format(dataset_name))
501
+
502
+ self.aug_geometric_transform = aug_geometric_transform
503
+ self.aug_photometric_transform = aug_photometric_transform
504
+
505
+ self.dataset = dataset_class(
506
+ root=pytorch_data_dir,
507
+ image_set=self.image_set,
508
+ transform=transform,
509
+ target_transform=target_transform, **extra_args)
510
+
511
+ if model_type_override is not None:
512
+ model_type = model_type_override
513
+ else:
514
+ model_type = cfg.model_type
515
+
516
+ nice_dataset_name = cfg.dir_dataset_name if dataset_name == "directory" else dataset_name
517
+ feature_cache_file = join(pytorch_data_dir, "nns", "nns_{}_{}_{}_{}_{}.npz".format(
518
+ model_type, nice_dataset_name, image_set, crop_type, cfg.res))
519
+ if pos_labels or pos_images:
520
+ if not os.path.exists(feature_cache_file) or compute_knns:
521
+ raise ValueError("could not find nn file {} please run precompute_knns".format(feature_cache_file))
522
+ else:
523
+ loaded = np.load(feature_cache_file)
524
+ self.nns = loaded["nns"]
525
+ assert len(self.dataset) == self.nns.shape[0]
526
+
527
+ def __len__(self):
528
+ return len(self.dataset)
529
+
530
+ def _set_seed(self, seed):
531
+ random.seed(seed) # apply this seed to img tranfsorms
532
+ torch.manual_seed(seed) # needed for torchvision 0.7
533
+
534
+ def __getitem__(self, ind):
535
+ pack = self.dataset[ind]
536
+
537
+ if self.pos_images or self.pos_labels:
538
+ ind_pos = self.nns[ind][torch.randint(low=1, high=self.num_neighbors + 1, size=[]).item()]
539
+ pack_pos = self.dataset[ind_pos]
540
+
541
+ seed = np.random.randint(2147483647) # make a seed with numpy generator
542
+
543
+ self._set_seed(seed)
544
+ coord_entries = torch.meshgrid([torch.linspace(-1, 1, pack[0].shape[1]),
545
+ torch.linspace(-1, 1, pack[0].shape[2])])
546
+ coord = torch.cat([t.unsqueeze(0) for t in coord_entries], 0)
547
+
548
+ if self.extra_transform is not None:
549
+ extra_trans = self.extra_transform
550
+ else:
551
+ extra_trans = lambda i, x: x
552
+
553
+ def squeeze_tuple(label_raw):
554
+ if type(label_raw) == tuple:
555
+ return tuple(x.squeeze() for x in label_raw)
556
+ else:
557
+ return label_raw.squeeze()
558
+ ret = {
559
+ "ind": ind,
560
+ "img": extra_trans(ind, pack[0]),
561
+ "label": squeeze_tuple(extra_trans(ind, pack[1]))
562
+ }
563
+
564
+ if self.pos_images:
565
+ ret["img_pos"] = extra_trans(ind, pack_pos[0])
566
+ ret["ind_pos"] = ind_pos
567
+
568
+ if self.mask:
569
+ ret["mask"] = pack[2]
570
+
571
+ if self.pos_labels:
572
+ ret["label_pos"] = squeeze_tuple(extra_trans(ind, pack_pos[1]))
573
+ ret["mask_pos"] = pack_pos[2]
574
+
575
+ if self.aug_photometric_transform is not None:
576
+ img_aug = self.aug_photometric_transform(self.aug_geometric_transform(pack[0]))
577
+
578
+ self._set_seed(seed)
579
+ coord_aug = self.aug_geometric_transform(coord)
580
+
581
+ ret["img_aug"] = img_aug
582
+ ret["coord_aug"] = coord_aug.permute(1, 2, 0)
583
+
584
+ return ret
biomap/dataset_generator/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ from .data_loader import DataLoader
2
+
3
+
4
+ __all__ = [
5
+ 'DataLoader',
6
+ ]
biomap/dataset_generator/data_loader.py ADDED
@@ -0,0 +1,356 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ import ee
3
+ from func_timeout import func_set_timeout
4
+ import pandas as pd
5
+ from PIL import Image
6
+ import requests
7
+ import tempfile
8
+ import io
9
+ from tqdm import tqdm
10
+ import functools
11
+ import re # Used in an eval statement
12
+ from typing import List
13
+ from typing import Union
14
+ from typing import Any
15
+
16
+
17
+ class DataLoader:
18
+ """
19
+ Main class for loading and exploring data from satellite images.
20
+ The goal is to load an ImageCollection and to filter that collection according to needs, with methods like
21
+ filter, filterDate, filterBounds, select. These will work just like earth engine's methods with the same names.
22
+
23
+ This class, just like earth engine, works with lazy loading and compute. This means that running filterBounds
24
+ will not actually filter the image collection until required, e.g. when counting the images by accessing .count
25
+ property.
26
+ However, it will only load once the information it needs, unless additional filtering is made.
27
+
28
+ This works thanks to the signal_change decorator. If you develop a new filtering method for this class,
29
+ you will need to decorate your method with @signal_change.
30
+ In addition, if you develop a new method that will require to run getInfo to actually load data from
31
+ Google Earth Engine, you will need to use _get_timeout_info(your object before getInfo). This will run
32
+ getInfo with a timeout (currently set to 10 seconds).
33
+ It is important to use a timeout to avoid unexpected run times.
34
+
35
+ Usage:
36
+ >>> dl = DataLoader(satellite_name="COPERNICUS/S2_SR", \
37
+ start_date='2021-01-01', \
38
+ end_date='2021-01-15', \
39
+ bands=["TCI_R", "TCI_G", "TCI_B"], \
40
+ geographic_bounds=ee.Geometry.Point(*[5.238728194366604, 44.474864056855935]).buffer(500) \
41
+ )
42
+
43
+ Get a pandas dataframe with all pixel values as a timeseries:
44
+ >>> dl.getRegion(dl.bounds, 500)
45
+ >>> dl.region.head(2)
46
+ [Out]
47
+ id longitude latitude time B1 B2 B3 B4 B5 B6 ... WVP SCL TCI_R TCI_G TCI_B MSK_CLDPRB MSK_SNWPRB QA10 QA20 QA60
48
+ 0 20210102T104441_20210102T104435_T31TFK 5.234932 44.473344 2021-01-02 10:48:36.299 6297 5955 5768 5773 5965 5883 ... 393 8 255 255 255 0 95 0 0 1024
49
+ 1 20210104T103329_20210104T103331_T31TFK 5.234932 44.473344 2021-01-04 10:38:38.304 5547 5355 5184 5090 5254 5229 ... 314 9 255 255 255 29 9 0 0 1024
50
+
51
+ >>> dl.date_range
52
+ [Out]
53
+ {'max': datetime.datetime(2021, 1, 14, 11, 38, 39, 208000),
54
+ 'min': datetime.datetime(2021, 1, 2, 11, 48, 36, 299000)}
55
+
56
+ >>> dl.count
57
+ [Out]
58
+ 6
59
+
60
+ >>> dl.collection_info # constains a html description of the dataset in "description"
61
+
62
+ >>> dl.image_ids
63
+ [Out]
64
+ ['COPERNICUS/S2_SR/20210102T104441_20210102T104435_T31TFK',
65
+ 'COPERNICUS/S2_SR/20210104T103329_20210104T103331_T31TFK',
66
+ 'COPERNICUS/S2_SR/20210107T104329_20210107T104328_T31TFK',
67
+ 'COPERNICUS/S2_SR/20210109T103421_20210109T103431_T31TFK',
68
+ 'COPERNICUS/S2_SR/20210112T104411_20210112T104438_T31TFK',
69
+ 'COPERNICUS/S2_SR/20210114T103309_20210114T103305_T31TFK']
70
+
71
+ # Download the image
72
+ >>> img = dl.download_image(dl.image_ids[3])
73
+
74
+ # Download all images as a list
75
+ >>> imgs = dl.download_all_images(scale=1)
76
+
77
+ """
78
+ def __init__(self,
79
+ satellite_name: str,
80
+ bands: Union[List, str] = None,
81
+ start_date: str = None,
82
+ end_date: str = None,
83
+ geographic_bounds: ee.geometry = None,
84
+ scale: int = 10,
85
+ crs: str = "EPSG:32630"
86
+ ):
87
+ """
88
+
89
+ Args:
90
+ satellite_name: satellite to use. Examples: COPERNICUS/S2_SR, COPERNICUS/CORINE/V20/100m.
91
+ See https://developers.google.com/earth-engine/datasets for the full list.
92
+ bands: list of bands to load.
93
+ start_date: lowest possible date. Might be lower than the actual date of the first picture.
94
+ end_date: Latest possible date.
95
+ geographic_bounds: Region of interest.
96
+ """
97
+ self.satellite_name = satellite_name
98
+ if isinstance(bands, str):
99
+ bands = [bands]
100
+ self.bands = bands if bands is not None else list()
101
+ if start_date is None or end_date is None:
102
+ assert (start_date is not None) and (end_date is not None), "start_date and end_date must both be provided"
103
+ self.start_date = start_date
104
+ self.end_date = end_date
105
+ self.bounds = geographic_bounds
106
+
107
+ # Lazy computed
108
+ self._available_images = None
109
+
110
+ # Start getting info from google cloud
111
+ if satellite_name:
112
+ self.image_collection = ee.ImageCollection(self.satellite_name)
113
+ if self.bounds:
114
+ self.filterBounds(self.bounds)
115
+ if self.start_date is not None:
116
+ self.filterDate(self.start_date, self.end_date)
117
+ self.scale = scale
118
+ self.crs = crs
119
+ self.image_list = None
120
+ self._df_image_list = None
121
+ self.image_collection_info = None
122
+ self._date_range = None
123
+ self.date_filter_change = False
124
+ self._count = None
125
+
126
+ # Bool for caching
127
+ self.filter_change = True
128
+ self._describe = None
129
+
130
+ def signal_change(func):
131
+ """Signals that additional filtering was performed. To be used
132
+ as a decorator."""
133
+ @functools.wraps(func)
134
+ def wrap(self, *args, **kwargs):
135
+ self.filter_change = True
136
+ self.date_filter_change = True
137
+ return func(self, *args, **kwargs)
138
+ return wrap
139
+
140
+ @staticmethod
141
+ @func_set_timeout(10)
142
+ def _get_timeout_info(instance: Any):
143
+ """Runs getInfo on anything that is passed, with a timeout."""
144
+ return instance.getInfo()
145
+
146
+ @staticmethod
147
+ def _authenticate_gee():
148
+ """Authenticates earth engine if needed, and initializes."""
149
+ try:
150
+ ee.Initialize()
151
+ except Exception as e:
152
+ # Trigger the authentication flow.
153
+ ee.Authenticate()
154
+ # Initialize the library.
155
+ ee.Initialize()
156
+
157
+ def filter(self, ee_filter: ee.Filter):
158
+ """Applies a filter to the image_collection attribute. This can be useful for example
159
+ to filter out clouds
160
+
161
+ Args:
162
+ ee_filter: Filter to apply, must be an instance of ee.Filter.
163
+
164
+ Returns: self, for operation chaining as possible with the earth engine API.
165
+
166
+ """
167
+ self.image_collection = self.image_collection.filter(ee_filter)
168
+
169
+ return self
170
+
171
+ @property
172
+ def count(self):
173
+ """Number of images in the ImageCollection"""
174
+ if self.filter_change or self._count is None:
175
+ self._count = self._get_timeout_info(self.image_collection.size())
176
+ self.filter_change = False
177
+ return self._count
178
+
179
+ @property
180
+ def available_images(self):
181
+ """Gets the ImageCollection info"""
182
+ if self.filter_change or self._available_images is None:
183
+ self._available_images = self._get_timeout_info(self.image_collection)
184
+ return self._available_images
185
+
186
+ @signal_change
187
+ def filterDate(self, *args, **kwargs):
188
+ """Wrapper for the filterDate method in earth engine on the ImageCollection"""
189
+ self.image_collection = self.image_collection.filterDate(*args, **kwargs)
190
+ return self
191
+
192
+ @signal_change
193
+ def getRegion(self, *args, **kwargs):
194
+ """Wrapper for the getRegion method in earth engine on the ImageCollection.
195
+ Caveat! getRegion does not return an image collection, so the image_list attribute gets
196
+ updated instead of the image_collection attribute. However, the instance of the DataLoader class
197
+ is still returned, so this could be chained with another method on ImageCollection, which wouldn't be
198
+ possible using earth engine.
199
+ """
200
+ self.image_list = self.image_collection.getRegion(*args, **kwargs)
201
+ return self
202
+
203
+ @signal_change
204
+ def filterBounds(self, geometry, *args, **kwargs):
205
+ """Wrapper for the filterBounds method in earth engine on the ImageCollection"""
206
+ self.image_collection = self.image_collection.filterBounds(geometry, *args, **kwargs)
207
+ self.bounds = geometry
208
+ return self
209
+
210
+ @signal_change
211
+ def select(self, *bands, **kwargs):
212
+ """Wrapper for the select method in earth engine on the ImageCollection"""
213
+ self.image_collection = self.image_collection.select(*bands, **kwargs)
214
+ self.bands = list(set(self.bands) | set(bands)) # Unique bands
215
+ return self
216
+
217
+ @property
218
+ def date_range(self):
219
+ """Gets the actual date range of the images in the image collection."""
220
+ if self.date_filter_change or self._date_range is None:
221
+ date_range = self.image_collection.reduceColumns(ee.Reducer.minMax(), ["system:time_start"]).getInfo()
222
+ self._date_range = {key: datetime.fromtimestamp(value/1e3) for key, value in date_range.items()}
223
+ self.date_filter_change = False
224
+ return self._date_range
225
+
226
+ @property
227
+ def region(self):
228
+ """Gets a time series as a pandas DataFrame of the band values for the specified region."""
229
+ if self.filter_change:
230
+ if self.image_list is None:
231
+ self.getRegion()
232
+ res_list = self._get_timeout_info(self.image_list)
233
+ df = pd.DataFrame(res_list[1:], columns=res_list[0])
234
+ df.loc[:, "time"] = pd.to_datetime(df.loc[:, "time"], unit="ms")
235
+ self._df_image_list = df
236
+ self.filter_change = False
237
+ return self._df_image_list
238
+
239
+ @property
240
+ def collection_info(self):
241
+ """Runs getInfo on the image collection (the first time the next time the previously
242
+ populated attribute will be returned)."""
243
+ if self.count > 5000:
244
+ raise Exception("Too many images to load. Try filtering more")
245
+ if self.filter_change or self.image_collection_info is None:
246
+ self.image_collection_info = self._get_timeout_info(self.image_collection)
247
+ return self.image_collection_info
248
+
249
+ @property
250
+ def image_ids(self):
251
+ """list of names of available images in the image collection"""
252
+ return [i["id"] for i in self.collection_info["features"]]
253
+
254
+ def __repr__(self):
255
+ try:
256
+ return f"""
257
+ Size: {self.count}
258
+
259
+ Dataset date ranges:
260
+ From: {self.date_range["min"]}
261
+ To: {self.date_range["max"]}
262
+
263
+ Selected bands:
264
+ {self.bands}
265
+
266
+ """
267
+ except Exception as e:
268
+ raise Exception("Impossible to represent the dataset. Try filtering more. Error handling to do.")
269
+
270
+ def reproject(self, image, **kwargs):
271
+ def resolve(name: str):
272
+ # Resolve crs
273
+ if name in kwargs:
274
+ item = kwargs[name]
275
+ elif getattr(self, name):
276
+ item = getattr(self, name)
277
+ else:
278
+ item = None
279
+ return item
280
+ crs = resolve("crs")
281
+ scale = resolve("scale")
282
+ if crs is not None or scale is not None:
283
+ image = image.reproject(crs, None, scale)
284
+ return image
285
+
286
+ def download_image(self, image_id: str, **kwargs):
287
+ """Downloads an image based on its id / name. The additional arguments are passed
288
+ to getThumbUrl, and could be scale, max, min...
289
+ """
290
+ img = ee.Image(image_id).select(*self.bands)
291
+ img = self.reproject(img, **kwargs)
292
+ input_args = {'region': self.bounds}
293
+ input_args.update(**kwargs)
294
+ all_bands = self.collection_info["features"][0]["bands"]
295
+ selected_bands = [band for i, band in enumerate(all_bands) if all_bands[i]["id"] in self.bands]
296
+ if "min" not in input_args:
297
+ input_args.update({"min": selected_bands[0]["data_type"]["min"]})
298
+ if "max" not in input_args:
299
+ input_args.update({"max": selected_bands[0]["data_type"]["max"]})
300
+ url = img.getThumbUrl(input_args)
301
+ buffer = tempfile.SpooledTemporaryFile(max_size=1e9)
302
+ r = requests.get(url, stream=True)
303
+ if r.status_code == 200:
304
+ downloaded = 0
305
+ # filesize = int(r.headers['content-length'])
306
+ for chunk in r.iter_content(chunk_size=1024):
307
+ downloaded += len(chunk)
308
+ buffer.write(chunk)
309
+ buffer.seek(0)
310
+ img = Image.open(io.BytesIO(buffer.read()))
311
+ buffer.close()
312
+ return img
313
+
314
+ @staticmethod
315
+ def _regex(regex: str, im_id_list: List[str], include: bool) -> list:
316
+ """
317
+ Filters the im_id_list based on a regular expression. This is useful before downloading
318
+ a collection of images. For example, using (.*)TXT with include=True will only download images
319
+ that end with TXT, wich for Nantes means filtering out empty or half empty images.
320
+ Args:
321
+ regex: python regex as a strng
322
+ im_id_list: list, image id list
323
+ include: whether to include or exclude elements that match the regex.
324
+
325
+ Returns: filtered list.
326
+
327
+ """
328
+ expression = "re.match('{regex}', '{im_id}') is not None"
329
+ if not include:
330
+ expression = "not " + expression
331
+ filtered_list = list()
332
+ for im_id in im_id_list:
333
+ if eval(expression.format(regex=regex, im_id=im_id)):
334
+ filtered_list.append(im_id)
335
+ return filtered_list
336
+
337
+ def download_all_images(self, regex_exclude: str = None, regex_include: str = None, **kwargs):
338
+ """
339
+ Runs download_image in a for loop around the available images.
340
+ Makes it possible to filter images to download based on a regex.
341
+ Args:
342
+ regex_exclude: any image that matches this regex will be excluded.
343
+ regex_include: any image that matches this regex will be included
344
+ **kwargs: arguments to be passed to getThumbUrl
345
+
346
+ Returns: list of PIL images
347
+ """
348
+ images = list()
349
+ image_ids = self.image_ids
350
+ if regex_exclude is not None:
351
+ image_ids = self._regex(regex_exclude, image_ids, include=False)
352
+ if regex_include is not None:
353
+ image_ids = self._regex(regex_include, image_ids, include=True)
354
+ for i in tqdm(range(len(image_ids))):
355
+ images.append(self.download_image(image_ids[i], **kwargs))
356
+ return images
biomap/dino/utils.py ADDED
@@ -0,0 +1,619 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Misc functions.
16
+
17
+ Mostly copy-paste from torchvision references or other public repos like DETR:
18
+ https://github.com/facebookresearch/detr/blob/master/util/misc.py
19
+ """
20
+ import os
21
+ import sys
22
+ import time
23
+ import math
24
+ import random
25
+ import datetime
26
+ import subprocess
27
+ from collections import defaultdict, deque
28
+
29
+ import numpy as np
30
+ import torch
31
+ from torch import nn
32
+ import torch.distributed as dist
33
+ from PIL import ImageFilter, ImageOps
34
+
35
+
36
+ class GaussianBlur(object):
37
+ """
38
+ Apply Gaussian Blur to the PIL image.
39
+ """
40
+ def __init__(self, p=0.5, radius_min=0.1, radius_max=2.):
41
+ self.prob = p
42
+ self.radius_min = radius_min
43
+ self.radius_max = radius_max
44
+
45
+ def __call__(self, img):
46
+ do_it = random.random() <= self.prob
47
+ if not do_it:
48
+ return img
49
+
50
+ return img.filter(
51
+ ImageFilter.GaussianBlur(
52
+ radius=random.uniform(self.radius_min, self.radius_max)
53
+ )
54
+ )
55
+
56
+
57
+ class Solarization(object):
58
+ """
59
+ Apply Solarization to the PIL image.
60
+ """
61
+ def __init__(self, p):
62
+ self.p = p
63
+
64
+ def __call__(self, img):
65
+ if random.random() < self.p:
66
+ return ImageOps.solarize(img)
67
+ else:
68
+ return img
69
+
70
+
71
+ def load_pretrained_weights(model, pretrained_weights, checkpoint_key, model_name, patch_size):
72
+ if os.path.isfile(pretrained_weights):
73
+ state_dict = torch.load(pretrained_weights, map_location="cpu")
74
+ if checkpoint_key is not None and checkpoint_key in state_dict:
75
+ print(f"Take key {checkpoint_key} in provided checkpoint dict")
76
+ state_dict = state_dict[checkpoint_key]
77
+ # remove `module.` prefix
78
+ state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
79
+ # remove `backbone.` prefix induced by multicrop wrapper
80
+ state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
81
+ msg = model.load_state_dict(state_dict, strict=False)
82
+ print('Pretrained weights found at {} and loaded with msg: {}'.format(pretrained_weights, msg))
83
+ else:
84
+ print("Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.")
85
+ url = None
86
+ if model_name == "vit_small" and patch_size == 16:
87
+ url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
88
+ elif model_name == "vit_small" and patch_size == 8:
89
+ url = "dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth"
90
+ elif model_name == "vit_base" and patch_size == 16:
91
+ url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
92
+ elif model_name == "vit_base" and patch_size == 8:
93
+ url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
94
+ if url is not None:
95
+ print("Since no pretrained weights have been provided, we load the reference pretrained DINO weights.")
96
+ state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
97
+ model.load_state_dict(state_dict, strict=True)
98
+ else:
99
+ print("There is no reference weights available for this model => We use random weights.")
100
+
101
+
102
+ def clip_gradients(model, clip):
103
+ norms = []
104
+ for name, p in model.named_parameters():
105
+ if p.grad is not None:
106
+ param_norm = p.grad.data.norm(2)
107
+ norms.append(param_norm.item())
108
+ clip_coef = clip / (param_norm + 1e-6)
109
+ if clip_coef < 1:
110
+ p.grad.data.mul_(clip_coef)
111
+ return norms
112
+
113
+
114
+ def cancel_gradients_last_layer(epoch, model, freeze_last_layer):
115
+ if epoch >= freeze_last_layer:
116
+ return
117
+ for n, p in model.named_parameters():
118
+ if "last_layer" in n:
119
+ p.grad = None
120
+
121
+
122
+ def restart_from_checkpoint(ckp_path, run_variables=None, **kwargs):
123
+ """
124
+ Re-start from checkpoint
125
+ """
126
+ if not os.path.isfile(ckp_path):
127
+ return
128
+ print("Found checkpoint at {}".format(ckp_path))
129
+
130
+ # open checkpoint file
131
+ checkpoint = torch.load(ckp_path, map_location="cpu")
132
+
133
+ # key is what to look for in the checkpoint file
134
+ # value is the object to load
135
+ # example: {'state_dict': model}
136
+ for key, value in kwargs.items():
137
+ if key in checkpoint and value is not None:
138
+ try:
139
+ msg = value.load_state_dict(checkpoint[key], strict=False)
140
+ print("=> loaded {} from checkpoint '{}' with msg {}".format(key, ckp_path, msg))
141
+ except TypeError:
142
+ try:
143
+ msg = value.load_state_dict(checkpoint[key])
144
+ print("=> loaded {} from checkpoint '{}'".format(key, ckp_path))
145
+ except ValueError:
146
+ print("=> failed to load {} from checkpoint '{}'".format(key, ckp_path))
147
+ else:
148
+ print("=> failed to load {} from checkpoint '{}'".format(key, ckp_path))
149
+
150
+ # re load variable important for the run
151
+ if run_variables is not None:
152
+ for var_name in run_variables:
153
+ if var_name in checkpoint:
154
+ run_variables[var_name] = checkpoint[var_name]
155
+
156
+
157
+ def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0):
158
+ warmup_schedule = np.array([])
159
+ warmup_iters = warmup_epochs * niter_per_ep
160
+ if warmup_epochs > 0:
161
+ warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
162
+
163
+ iters = np.arange(epochs * niter_per_ep - warmup_iters)
164
+ schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
165
+
166
+ schedule = np.concatenate((warmup_schedule, schedule))
167
+ assert len(schedule) == epochs * niter_per_ep
168
+ return schedule
169
+
170
+
171
+ def bool_flag(s):
172
+ """
173
+ Parse boolean arguments from the command line.
174
+ """
175
+ FALSY_STRINGS = {"off", "false", "0"}
176
+ TRUTHY_STRINGS = {"on", "true", "1"}
177
+ if s.lower() in FALSY_STRINGS:
178
+ return False
179
+ elif s.lower() in TRUTHY_STRINGS:
180
+ return True
181
+ else:
182
+ raise argparse.ArgumentTypeError("invalid value for a boolean flag")
183
+
184
+
185
+ def fix_random_seeds(seed=31):
186
+ """
187
+ Fix random seeds.
188
+ """
189
+ torch.manual_seed(seed)
190
+ torch.cuda.manual_seed_all(seed)
191
+ np.random.seed(seed)
192
+
193
+
194
+ class SmoothedValue(object):
195
+ """Track a series of values and provide access to smoothed values over a
196
+ window or the global series average.
197
+ """
198
+
199
+ def __init__(self, window_size=20, fmt=None):
200
+ if fmt is None:
201
+ fmt = "{median:.6f} ({global_avg:.6f})"
202
+ self.deque = deque(maxlen=window_size)
203
+ self.total = 0.0
204
+ self.count = 0
205
+ self.fmt = fmt
206
+
207
+ def update(self, value, n=1):
208
+ self.deque.append(value)
209
+ self.count += n
210
+ self.total += value * n
211
+
212
+ def synchronize_between_processes(self):
213
+ """
214
+ Warning: does not synchronize the deque!
215
+ """
216
+ if not is_dist_avail_and_initialized():
217
+ return
218
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
219
+ dist.barrier()
220
+ dist.all_reduce(t)
221
+ t = t.tolist()
222
+ self.count = int(t[0])
223
+ self.total = t[1]
224
+
225
+ @property
226
+ def median(self):
227
+ d = torch.tensor(list(self.deque))
228
+ return d.median().item()
229
+
230
+ @property
231
+ def avg(self):
232
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
233
+ return d.mean().item()
234
+
235
+ @property
236
+ def global_avg(self):
237
+ return self.total / self.count
238
+
239
+ @property
240
+ def max(self):
241
+ return max(self.deque)
242
+
243
+ @property
244
+ def value(self):
245
+ return self.deque[-1]
246
+
247
+ def __str__(self):
248
+ return self.fmt.format(
249
+ median=self.median,
250
+ avg=self.avg,
251
+ global_avg=self.global_avg,
252
+ max=self.max,
253
+ value=self.value)
254
+
255
+
256
+ def reduce_dict(input_dict, average=True):
257
+ """
258
+ Args:
259
+ input_dict (dict): all the values will be reduced
260
+ average (bool): whether to do average or sum
261
+ Reduce the values in the dictionary from all processes so that all processes
262
+ have the averaged results. Returns a dict with the same fields as
263
+ input_dict, after reduction.
264
+ """
265
+ world_size = get_world_size()
266
+ if world_size < 2:
267
+ return input_dict
268
+ with torch.no_grad():
269
+ names = []
270
+ values = []
271
+ # sort the keys so that they are consistent across processes
272
+ for k in sorted(input_dict.keys()):
273
+ names.append(k)
274
+ values.append(input_dict[k])
275
+ values = torch.stack(values, dim=0)
276
+ dist.all_reduce(values)
277
+ if average:
278
+ values /= world_size
279
+ reduced_dict = {k: v for k, v in zip(names, values)}
280
+ return reduced_dict
281
+
282
+
283
+ class MetricLogger(object):
284
+ def __init__(self, delimiter="\t"):
285
+ self.meters = defaultdict(SmoothedValue)
286
+ self.delimiter = delimiter
287
+
288
+ def update(self, **kwargs):
289
+ for k, v in kwargs.items():
290
+ if isinstance(v, torch.Tensor):
291
+ v = v.item()
292
+ assert isinstance(v, (float, int))
293
+ self.meters[k].update(v)
294
+
295
+ def __getattr__(self, attr):
296
+ if attr in self.meters:
297
+ return self.meters[attr]
298
+ if attr in self.__dict__:
299
+ return self.__dict__[attr]
300
+ raise AttributeError("'{}' object has no attribute '{}'".format(
301
+ type(self).__name__, attr))
302
+
303
+ def __str__(self):
304
+ loss_str = []
305
+ for name, meter in self.meters.items():
306
+ loss_str.append(
307
+ "{}: {}".format(name, str(meter))
308
+ )
309
+ return self.delimiter.join(loss_str)
310
+
311
+ def synchronize_between_processes(self):
312
+ for meter in self.meters.values():
313
+ meter.synchronize_between_processes()
314
+
315
+ def add_meter(self, name, meter):
316
+ self.meters[name] = meter
317
+
318
+ def log_every(self, iterable, print_freq, header=None):
319
+ i = 0
320
+ if not header:
321
+ header = ''
322
+ start_time = time.time()
323
+ end = time.time()
324
+ iter_time = SmoothedValue(fmt='{avg:.6f}')
325
+ data_time = SmoothedValue(fmt='{avg:.6f}')
326
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
327
+ if torch.cuda.is_available():
328
+ log_msg = self.delimiter.join([
329
+ header,
330
+ '[{0' + space_fmt + '}/{1}]',
331
+ 'eta: {eta}',
332
+ '{meters}',
333
+ 'time: {time}',
334
+ 'data: {data}',
335
+ 'max mem: {memory:.0f}'
336
+ ])
337
+ else:
338
+ log_msg = self.delimiter.join([
339
+ header,
340
+ '[{0' + space_fmt + '}/{1}]',
341
+ 'eta: {eta}',
342
+ '{meters}',
343
+ 'time: {time}',
344
+ 'data: {data}'
345
+ ])
346
+ MB = 1024.0 * 1024.0
347
+ for obj in iterable:
348
+ data_time.update(time.time() - end)
349
+ yield obj
350
+ iter_time.update(time.time() - end)
351
+ if i % print_freq == 0 or i == len(iterable) - 1:
352
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
353
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
354
+ if torch.cuda.is_available():
355
+ print(log_msg.format(
356
+ i, len(iterable), eta=eta_string,
357
+ meters=str(self),
358
+ time=str(iter_time), data=str(data_time),
359
+ memory=torch.cuda.max_memory_allocated() / MB))
360
+ else:
361
+ print(log_msg.format(
362
+ i, len(iterable), eta=eta_string,
363
+ meters=str(self),
364
+ time=str(iter_time), data=str(data_time)))
365
+ i += 1
366
+ end = time.time()
367
+ total_time = time.time() - start_time
368
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
369
+ print('{} Total time: {} ({:.6f} s / it)'.format(
370
+ header, total_time_str, total_time / len(iterable)))
371
+
372
+
373
+ def get_sha():
374
+ cwd = os.path.dirname(os.path.abspath(__file__))
375
+
376
+ def _run(command):
377
+ return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
378
+ sha = 'N/A'
379
+ diff = "clean"
380
+ branch = 'N/A'
381
+ try:
382
+ sha = _run(['git', 'rev-parse', 'HEAD'])
383
+ subprocess.check_output(['git', 'diff'], cwd=cwd)
384
+ diff = _run(['git', 'diff-index', 'HEAD'])
385
+ diff = "has uncommited changes" if diff else "clean"
386
+ branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
387
+ except Exception:
388
+ pass
389
+ message = f"sha: {sha}, status: {diff}, branch: {branch}"
390
+ return message
391
+
392
+
393
+ def is_dist_avail_and_initialized():
394
+ if not dist.is_available():
395
+ return False
396
+ if not dist.is_initialized():
397
+ return False
398
+ return True
399
+
400
+
401
+ def get_world_size():
402
+ if not is_dist_avail_and_initialized():
403
+ return 1
404
+ return dist.get_world_size()
405
+
406
+
407
+ def get_rank():
408
+ if not is_dist_avail_and_initialized():
409
+ return 0
410
+ return dist.get_rank()
411
+
412
+
413
+ def is_main_process():
414
+ return get_rank() == 0
415
+
416
+
417
+ def save_on_master(*args, **kwargs):
418
+ if is_main_process():
419
+ torch.save(*args, **kwargs)
420
+
421
+
422
+ def setup_for_distributed(is_master):
423
+ """
424
+ This function disables printing when not in master process
425
+ """
426
+ import builtins as __builtin__
427
+ builtin_print = __builtin__.print
428
+
429
+ def print(*args, **kwargs):
430
+ force = kwargs.pop('force', False)
431
+ if is_master or force:
432
+ builtin_print(*args, **kwargs)
433
+
434
+ __builtin__.print = print
435
+
436
+
437
+ def init_distributed_mode(args):
438
+ # launched with torch.distributed.launch
439
+ if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
440
+ args.rank = int(os.environ["RANK"])
441
+ args.world_size = int(os.environ['WORLD_SIZE'])
442
+ args.gpu = int(os.environ['LOCAL_RANK'])
443
+ # launched with submitit on a slurm cluster
444
+ elif 'SLURM_PROCID' in os.environ:
445
+ args.rank = int(os.environ['SLURM_PROCID'])
446
+ args.gpu = args.rank % torch.cuda.device_count()
447
+ # launched naively with `python main_dino.py`
448
+ # we manually add MASTER_ADDR and MASTER_PORT to env variables
449
+ elif torch.cuda.is_available():
450
+ print('Will run the code on one GPU.')
451
+ args.rank, args.gpu, args.world_size = 0, 0, 1
452
+ os.environ['MASTER_ADDR'] = '127.0.0.1'
453
+ os.environ['MASTER_PORT'] = '29500'
454
+ else:
455
+ print('Does not support training without GPU.')
456
+ sys.exit(1)
457
+
458
+ dist.init_process_group(
459
+ backend="nccl",
460
+ init_method=args.dist_url,
461
+ world_size=args.world_size,
462
+ rank=args.rank,
463
+ )
464
+
465
+ torch.cuda.set_device(args.gpu)
466
+ print('| distributed init (rank {}): {}'.format(
467
+ args.rank, args.dist_url), flush=True)
468
+ dist.barrier()
469
+ setup_for_distributed(args.rank == 0)
470
+
471
+
472
+ def accuracy(output, target, topk=(1,)):
473
+ """Computes the accuracy over the k top predictions for the specified values of k"""
474
+ maxk = max(topk)
475
+ batch_size = target.size(0)
476
+ _, pred = output.topk(maxk, 1, True, True)
477
+ pred = pred.t()
478
+ correct = pred.eq(target.reshape(1, -1).expand_as(pred))
479
+ return [correct[:k].reshape(-1).float().sum(0) * 100. / batch_size for k in topk]
480
+
481
+
482
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
483
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
484
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
485
+ def norm_cdf(x):
486
+ # Computes standard normal cumulative distribution function
487
+ return (1. + math.erf(x / math.sqrt(2.))) / 2.
488
+
489
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
490
+ warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
491
+ "The distribution of values may be incorrect.",
492
+ stacklevel=2)
493
+
494
+ with torch.no_grad():
495
+ # Values are generated by using a truncated uniform distribution and
496
+ # then using the inverse CDF for the normal distribution.
497
+ # Get upper and lower cdf values
498
+ l = norm_cdf((a - mean) / std)
499
+ u = norm_cdf((b - mean) / std)
500
+
501
+ # Uniformly fill tensor with values from [l, u], then translate to
502
+ # [2l-1, 2u-1].
503
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
504
+
505
+ # Use inverse cdf transform for normal distribution to get truncated
506
+ # standard normal
507
+ tensor.erfinv_()
508
+
509
+ # Transform to proper mean, std
510
+ tensor.mul_(std * math.sqrt(2.))
511
+ tensor.add_(mean)
512
+
513
+ # Clamp to ensure it's in the proper range
514
+ tensor.clamp_(min=a, max=b)
515
+ return tensor
516
+
517
+
518
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
519
+ # type: (Tensor, float, float, float, float) -> Tensor
520
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
521
+
522
+
523
+ class LARS(torch.optim.Optimizer):
524
+ """
525
+ Almost copy-paste from https://github.com/facebookresearch/barlowtwins/blob/main/main.py
526
+ """
527
+ def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, eta=0.001,
528
+ weight_decay_filter=None, lars_adaptation_filter=None):
529
+ defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum,
530
+ eta=eta, weight_decay_filter=weight_decay_filter,
531
+ lars_adaptation_filter=lars_adaptation_filter)
532
+ super().__init__(params, defaults)
533
+
534
+ @torch.no_grad()
535
+ def step(self):
536
+ for g in self.param_groups:
537
+ for p in g['params']:
538
+ dp = p.grad
539
+
540
+ if dp is None:
541
+ continue
542
+
543
+ if p.ndim != 1:
544
+ dp = dp.add(p, alpha=g['weight_decay'])
545
+
546
+ if p.ndim != 1:
547
+ param_norm = torch.norm(p)
548
+ update_norm = torch.norm(dp)
549
+ one = torch.ones_like(param_norm)
550
+ q = torch.where(param_norm > 0.,
551
+ torch.where(update_norm > 0,
552
+ (g['eta'] * param_norm / update_norm), one), one)
553
+ dp = dp.mul(q)
554
+
555
+ param_state = self.state[p]
556
+ if 'mu' not in param_state:
557
+ param_state['mu'] = torch.zeros_like(p)
558
+ mu = param_state['mu']
559
+ mu.mul_(g['momentum']).add_(dp)
560
+
561
+ p.add_(mu, alpha=-g['lr'])
562
+
563
+
564
+ class MultiCropWrapper(nn.Module):
565
+ """
566
+ Perform forward pass separately on each resolution input.
567
+ The inputs corresponding to a single resolution are clubbed and single
568
+ forward is run on the same resolution inputs. Hence we do several
569
+ forward passes = number of different resolutions used. We then
570
+ concatenate all the output features and run the head forward on these
571
+ concatenated features.
572
+ """
573
+ def __init__(self, backbone, head):
574
+ super(MultiCropWrapper, self).__init__()
575
+ # disable layers dedicated to ImageNet labels classification
576
+ backbone.fc, backbone.head = nn.Identity(), nn.Identity()
577
+ self.backbone = backbone
578
+ self.head = head
579
+
580
+ def forward(self, x):
581
+ # convert to list
582
+ if not isinstance(x, list):
583
+ x = [x]
584
+ idx_crops = torch.cumsum(torch.unique_consecutive(
585
+ torch.tensor([inp.shape[-1] for inp in x]),
586
+ return_counts=True,
587
+ )[1], 0)
588
+ start_idx = 0
589
+ for end_idx in idx_crops:
590
+ _out = self.backbone(torch.cat(x[start_idx: end_idx]))
591
+ if start_idx == 0:
592
+ output = _out
593
+ else:
594
+ output = torch.cat((output, _out))
595
+ start_idx = end_idx
596
+ # Run the head forward on the concatenated features.
597
+ return self.head(output)
598
+
599
+
600
+ def get_params_groups(model):
601
+ regularized = []
602
+ not_regularized = []
603
+ for name, param in model.named_parameters():
604
+ if not param.requires_grad:
605
+ continue
606
+ # we do not regularize biases nor Norm parameters
607
+ if name.endswith(".bias") or len(param.shape) == 1:
608
+ not_regularized.append(param)
609
+ else:
610
+ regularized.append(param)
611
+ return [{'params': regularized}, {'params': not_regularized, 'weight_decay': 0.}]
612
+
613
+
614
+ def has_batchnorms(model):
615
+ bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
616
+ for name, module in model.named_modules():
617
+ if isinstance(module, bn_types):
618
+ return True
619
+ return False
biomap/dino/vision_transformer.py ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Mostly copy-paste from timm library.
16
+ https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
17
+ """
18
+ import math
19
+ from functools import partial
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ from dino.utils import trunc_normal_
24
+
25
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
26
+ if drop_prob == 0. or not training:
27
+ return x
28
+ keep_prob = 1 - drop_prob
29
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
30
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
31
+ random_tensor.floor_() # binarize
32
+ output = x.div(keep_prob) * random_tensor
33
+ return output
34
+
35
+
36
+ class DropPath(nn.Module):
37
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
38
+ """
39
+ def __init__(self, drop_prob=None):
40
+ super(DropPath, self).__init__()
41
+ self.drop_prob = drop_prob
42
+
43
+ def forward(self, x):
44
+ return drop_path(x, self.drop_prob, self.training)
45
+
46
+
47
+ class Mlp(nn.Module):
48
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
49
+ super().__init__()
50
+ out_features = out_features or in_features
51
+ hidden_features = hidden_features or in_features
52
+ self.fc1 = nn.Linear(in_features, hidden_features)
53
+ self.act = act_layer()
54
+ self.fc2 = nn.Linear(hidden_features, out_features)
55
+ self.drop = nn.Dropout(drop)
56
+
57
+ def forward(self, x):
58
+ x = self.fc1(x)
59
+ x = self.act(x)
60
+ x = self.drop(x)
61
+ x = self.fc2(x)
62
+ x = self.drop(x)
63
+ return x
64
+
65
+
66
+ class Attention(nn.Module):
67
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
68
+ super().__init__()
69
+ self.num_heads = num_heads
70
+ head_dim = dim // num_heads
71
+ self.scale = qk_scale or head_dim ** -0.5
72
+
73
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
74
+ self.attn_drop = nn.Dropout(attn_drop)
75
+ self.proj = nn.Linear(dim, dim)
76
+ self.proj_drop = nn.Dropout(proj_drop)
77
+
78
+ def forward(self, x, return_qkv=False):
79
+ B, N, C = x.shape
80
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
81
+ q, k, v = qkv[0], qkv[1], qkv[2]
82
+
83
+ attn = (q @ k.transpose(-2, -1)) * self.scale
84
+ attn = attn.softmax(dim=-1)
85
+ attn = self.attn_drop(attn)
86
+
87
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
88
+ x = self.proj(x)
89
+ x = self.proj_drop(x)
90
+ return x,attn, qkv
91
+
92
+
93
+
94
+ class Block(nn.Module):
95
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
96
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
97
+ super().__init__()
98
+ self.norm1 = norm_layer(dim)
99
+ self.attn = Attention(
100
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
101
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
102
+ self.norm2 = norm_layer(dim)
103
+ mlp_hidden_dim = int(dim * mlp_ratio)
104
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
105
+
106
+ def forward(self, x, return_attention=False, return_qkv = False):
107
+ y, attn, qkv = self.attn(self.norm1(x))
108
+ if return_attention:
109
+ return attn
110
+ x = x + self.drop_path(y)
111
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
112
+ if return_qkv:
113
+ return x,attn, qkv
114
+ return x
115
+
116
+
117
+ class PatchEmbed(nn.Module):
118
+ """ Image to Patch Embedding
119
+ """
120
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
121
+ super().__init__()
122
+ num_patches = (img_size // patch_size) * (img_size // patch_size)
123
+ self.img_size = img_size
124
+ self.patch_size = patch_size
125
+ self.num_patches = num_patches
126
+
127
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
128
+
129
+ def forward(self, x):
130
+ B, C, H, W = x.shape
131
+ x = self.proj(x).flatten(2).transpose(1, 2)
132
+ return x
133
+
134
+
135
+ class VisionTransformer(nn.Module):
136
+ """ Vision Transformer """
137
+ def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
138
+ num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
139
+ drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs):
140
+ super().__init__()
141
+
142
+ self.num_features = self.embed_dim = embed_dim
143
+
144
+ self.patch_embed = PatchEmbed(
145
+ img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
146
+ num_patches = self.patch_embed.num_patches
147
+
148
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
149
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
150
+ self.pos_drop = nn.Dropout(p=drop_rate)
151
+
152
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
153
+ self.blocks = nn.ModuleList([
154
+ Block(
155
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
156
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
157
+ for i in range(depth)])
158
+ self.norm = norm_layer(embed_dim)
159
+
160
+ # Classifier head
161
+ self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
162
+
163
+ trunc_normal_(self.pos_embed, std=.02)
164
+ trunc_normal_(self.cls_token, std=.02)
165
+ self.apply(self._init_weights)
166
+
167
+ def _init_weights(self, m):
168
+ if isinstance(m, nn.Linear):
169
+ trunc_normal_(m.weight, std=.02)
170
+ if isinstance(m, nn.Linear) and m.bias is not None:
171
+ nn.init.constant_(m.bias, 0)
172
+ elif isinstance(m, nn.LayerNorm):
173
+ nn.init.constant_(m.bias, 0)
174
+ nn.init.constant_(m.weight, 1.0)
175
+
176
+ def interpolate_pos_encoding(self, x, w, h):
177
+ npatch = x.shape[1] - 1
178
+ N = self.pos_embed.shape[1] - 1
179
+ if npatch == N and w == h:
180
+ return self.pos_embed
181
+ class_pos_embed = self.pos_embed[:, 0]
182
+ patch_pos_embed = self.pos_embed[:, 1:]
183
+ dim = x.shape[-1]
184
+ w0 = w // self.patch_embed.patch_size
185
+ h0 = h // self.patch_embed.patch_size
186
+ # we add a small number to avoid floating point error in the interpolation
187
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
188
+ w0, h0 = w0 + 0.1, h0 + 0.1
189
+ patch_pos_embed = nn.functional.interpolate(
190
+ patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
191
+ scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
192
+ mode='bicubic',
193
+ )
194
+ assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
195
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
196
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
197
+
198
+ def prepare_tokens(self, x):
199
+ B, nc, w, h = x.shape
200
+ x = self.patch_embed(x) # patch linear embedding
201
+
202
+ # add the [CLS] token to the embed patch tokens
203
+ cls_tokens = self.cls_token.expand(B, -1, -1)
204
+ x = torch.cat((cls_tokens, x), dim=1)
205
+
206
+ # add positional encoding to each token
207
+ x = x + self.interpolate_pos_encoding(x, w, h)
208
+
209
+ return self.pos_drop(x)
210
+
211
+ def forward(self, x):
212
+ x = self.prepare_tokens(x)
213
+ for blk in self.blocks:
214
+ x = blk(x)
215
+ x = self.norm(x)
216
+ return x[:, 0]
217
+
218
+ def forward_feats(self, x):
219
+ x = self.prepare_tokens(x)
220
+ for blk in self.blocks:
221
+ x = blk(x)
222
+ x = self.norm(x)
223
+ return x
224
+
225
+ def get_intermediate_feat(self, x, n=1):
226
+ x = self.prepare_tokens(x)
227
+ # we return the output tokens from the `n` last blocks
228
+ feat = []
229
+ attns = []
230
+ qkvs = []
231
+ for i, blk in enumerate(self.blocks):
232
+ x,attn,qkv = blk(x, return_qkv=True)
233
+ if len(self.blocks) - i <= n:
234
+ feat.append(self.norm(x))
235
+ qkvs.append(qkv)
236
+ attns.append(attn)
237
+ return feat, attns, qkvs
238
+
239
+ def get_last_selfattention(self, x):
240
+ x = self.prepare_tokens(x)
241
+ for i, blk in enumerate(self.blocks):
242
+ if i < len(self.blocks) - 1:
243
+ x = blk(x)
244
+ else:
245
+ # return attention of the last block
246
+ return blk(x, return_attention=True)
247
+
248
+ def get_intermediate_layers(self, x, n=1):
249
+ x = self.prepare_tokens(x)
250
+ # we return the output tokens from the `n` last blocks
251
+ output = []
252
+ for i, blk in enumerate(self.blocks):
253
+ x = blk(x)
254
+ if len(self.blocks) - i <= n:
255
+ output.append(self.norm(x))
256
+ return output
257
+
258
+
259
+ def vit_tiny(patch_size=16, **kwargs):
260
+ model = VisionTransformer(
261
+ patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
262
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
263
+ return model
264
+
265
+
266
+ def vit_small(patch_size=16, **kwargs):
267
+ model = VisionTransformer(
268
+ patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
269
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
270
+ return model
271
+
272
+
273
+ def vit_base(patch_size=16, **kwargs):
274
+ model = VisionTransformer(
275
+ patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
276
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
277
+ return model
278
+
279
+
280
+ class DINOHead(nn.Module):
281
+ def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256):
282
+ super().__init__()
283
+ nlayers = max(nlayers, 1)
284
+ if nlayers == 1:
285
+ self.mlp = nn.Linear(in_dim, bottleneck_dim)
286
+ else:
287
+ layers = [nn.Linear(in_dim, hidden_dim)]
288
+ if use_bn:
289
+ layers.append(nn.BatchNorm1d(hidden_dim))
290
+ layers.append(nn.GELU())
291
+ for _ in range(nlayers - 2):
292
+ layers.append(nn.Linear(hidden_dim, hidden_dim))
293
+ if use_bn:
294
+ layers.append(nn.BatchNorm1d(hidden_dim))
295
+ layers.append(nn.GELU())
296
+ layers.append(nn.Linear(hidden_dim, bottleneck_dim))
297
+ self.mlp = nn.Sequential(*layers)
298
+ self.apply(self._init_weights)
299
+ self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
300
+ self.last_layer.weight_g.data.fill_(1)
301
+ if norm_last_layer:
302
+ self.last_layer.weight_g.requires_grad = False
303
+
304
+ def _init_weights(self, m):
305
+ if isinstance(m, nn.Linear):
306
+ trunc_normal_(m.weight, std=.02)
307
+ if isinstance(m, nn.Linear) and m.bias is not None:
308
+ nn.init.constant_(m.bias, 0)
309
+
310
+ def forward(self, x):
311
+ x = self.mlp(x)
312
+ x = nn.functional.normalize(x, dim=-1, p=2)
313
+ x = self.last_layer(x)
314
+ return x
biomap/helper.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.multiprocessing
2
+ import torchvision.transforms as T
3
+ import numpy as np
4
+ from utils import transform_to_pil, create_video
5
+ from utils_gee import extract_img, transform_ee_img
6
+ from dateutil.relativedelta import relativedelta
7
+ import datetime
8
+ from dateutil.relativedelta import relativedelta
9
+ import cv2
10
+
11
+ from joblib import Parallel, cpu_count, delayed
12
+
13
+ def get_image(location, d1, d2):
14
+ print(f"getting image for {d1} to {d2}")
15
+ try:
16
+ img = extract_img(location, d1, d2)
17
+ img_test = transform_ee_img(
18
+ img, max=0.3
19
+ )
20
+ return img_test
21
+ except Exception as err:
22
+ print(err)
23
+ return
24
+
25
+
26
+ def inference_on_location(model, latitude = 2.98, longitude = 48.81, start_date=2020, end_date=2022):
27
+ """Performe an inference on the latitude and longitude between the start date and the end date
28
+
29
+ Args:
30
+ latitude (float): the latitude of the landscape
31
+ longitude (float): the longitude of the landscape
32
+ start_date (str): the start date for our inference
33
+ end_date (str): the end date for our inference
34
+ model (_type_, optional): _description_. Defaults to model.
35
+
36
+ Returns:
37
+ img, labeled_img,biodiv_score: the original landscape, the labeled landscape and the biodiversity score and the landscape
38
+ """
39
+ assert end_date > start_date, "end date must be stricly higher than start date"
40
+ location = [float(latitude), float(longitude)]
41
+
42
+ # Extract img numpy from earth engine and transform it to PIL img
43
+ dates = [datetime.datetime(start_date, 1, 1, 0, 0, 0)]
44
+ while dates[-1] < datetime.datetime(end_date, 1, 1, 0, 0, 0):
45
+ dates.append(dates[-1] + relativedelta(months=1))
46
+
47
+ dates = [d.strftime("%Y-%m-%d") for d in dates]
48
+
49
+ all_image = Parallel(n_jobs=cpu_count(), prefer="threads")(delayed(get_image)(location, d1,d2) for d1, d2 in zip(dates[:-1],dates[1:]))
50
+ all_image = [image for image in all_image if image is not None]
51
+
52
+
53
+
54
+
55
+ # tensorize & normalize img
56
+ preprocess = T.Compose(
57
+ [
58
+ T.ToPILImage(),
59
+ T.Resize((320, 320)),
60
+ # T.CenterCrop(224),
61
+ T.ToTensor(),
62
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
63
+ ]
64
+ )
65
+
66
+ # Preprocess opened img
67
+ x = torch.stack([preprocess(imag) for imag in all_image]).cpu()
68
+
69
+ # launch inference on cpu
70
+ # x = torch.unsqueeze(x, dim=0).cpu()
71
+ model = model.cpu()
72
+
73
+ with torch.no_grad():
74
+ feats, code = model.net(x)
75
+ linear_pred = model.linear_probe(x, code)
76
+ linear_pred = linear_pred.argmax(1)
77
+ outputs = [{
78
+ "img": torch.unsqueeze(img, dim=0).detach().cpu(),
79
+ "linear_preds": torch.unsqueeze(linear_pred, dim=0).detach().cpu(),
80
+ } for img, linear_pred in zip(x, linear_pred)]
81
+ all_img = []
82
+ all_label = []
83
+ all_labeled_img = []
84
+ for output in outputs:
85
+ img, label, labeled_img = transform_to_pil(output)
86
+ all_img.append(img)
87
+ all_label.append(label)
88
+ all_labeled_img.append(labeled_img)
89
+
90
+ all_labeled_img = [np.array(pil_image)[:, :, ::-1] for pil_image in all_labeled_img]
91
+ create_video(all_labeled_img, output_path='output/output.mp4')
92
+
93
+ # all_labeled_img = [np.array(pil_image)[:, :, ::-1] for pil_image in all_img]
94
+ # create_video(all_labeled_img, output_path='raw.mp4')
95
+
96
+ return 'output.mp4'
97
+
98
+ def inference_on_location_and_month(model, latitude = 2.98, longitude = 48.81, start_date = '2020-03-20'):
99
+ """Performe an inference on the latitude and longitude between the start date and the end date
100
+
101
+ Args:
102
+ latitude (float): the latitude of the landscape
103
+ longitude (float): the longitude of the landscape
104
+ start_date (str): the start date for our inference
105
+ end_date (str): the end date for our inference
106
+ model (_type_, optional): _description_. Defaults to model.
107
+
108
+ Returns:
109
+ img, labeled_img,biodiv_score: the original landscape, the labeled landscape and the biodiversity score and the landscape
110
+ """
111
+ location = [float(latitude), float(longitude)]
112
+
113
+ # Extract img numpy from earth engine and transform it to PIL img
114
+ end_date = datetime.datetime.strptime(start_date, "%Y-%m-%d") + relativedelta(months=1)
115
+ end_date = datetime.datetime.strftime(end_date, "%Y-%m-%d")
116
+ img = extract_img(location, start_date, end_date)
117
+ img_test = transform_ee_img(
118
+ img, max=0.3
119
+ ) # max value is the value from numpy file that will be equal to 255
120
+
121
+ # tensorize & normalize img
122
+ preprocess = T.Compose(
123
+ [
124
+ T.ToPILImage(),
125
+ T.Resize((320, 320)),
126
+ # T.CenterCrop(224),
127
+ T.ToTensor(),
128
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
129
+ ]
130
+ )
131
+
132
+ # Preprocess opened img
133
+ x = preprocess(img_test)
134
+
135
+ # launch inference on cpu
136
+ x = torch.unsqueeze(x, dim=0).cpu()
137
+ model = model.cpu()
138
+
139
+ with torch.no_grad():
140
+ feats, code = model.net(x)
141
+ linear_pred = model.linear_probe(x, code)
142
+ linear_pred = linear_pred.argmax(1)
143
+ output = {
144
+ "img": x[: model.cfg.n_images].detach().cpu(),
145
+ "linear_preds": linear_pred[: model.cfg.n_images].detach().cpu(),
146
+ }
147
+ nb_values = []
148
+ for i in range(7):
149
+ nb_values.append(np.count_nonzero(output['linear_preds'][0] == i+1))
150
+ scores_init = [2,3,4,3,1,4,0]
151
+ score = sum(x * y for x, y in zip(scores_init, nb_values)) / sum(nb_values) / max(scores_init)
152
+
153
+ img, label, labeled_img = transform_to_pil(output)
154
+ return img, labeled_img,score
155
+
156
+
157
+ if __name__ == "__main__":
158
+ import logging
159
+ import hydra
160
+
161
+
162
+ from model import LitUnsupervisedSegmenter
163
+ logging.basicConfig(filename='example.log', encoding='utf-8', level=logging.INFO)
164
+ # Initialize hydra with configs
165
+ hydra.initialize(config_path="configs", job_name="corine")
166
+ cfg = hydra.compose(config_name="my_train_config.yml")
167
+ logging.info(f"config : {cfg}")
168
+ # Load the model
169
+
170
+ nbclasses = cfg.dir_dataset_n_classes
171
+ model = LitUnsupervisedSegmenter(nbclasses, cfg)
172
+ logging.info(f"Model Initialiazed")
173
+
174
+ model_path = "checkpoint/model/model.pt"
175
+ saved_state_dict = torch.load(model_path, map_location=torch.device("cpu"))
176
+ logging.info(f"Model weights Loaded")
177
+ model.load_state_dict(saved_state_dict)
178
+ logging.info(f"Model Loaded")
179
+ inference_on_location(model)
biomap/inference.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.multiprocessing
2
+ import torchvision.transforms as T
3
+ from utils import transform_to_pil
4
+
5
+ def inference(image, model):
6
+ # tensorize & normalize img
7
+ preprocess = T.Compose(
8
+ [
9
+ T.ToPILImage(),
10
+ T.Resize((320, 320)),
11
+ # T.CenterCrop(224),
12
+ T.ToTensor(),
13
+ T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
14
+ ]
15
+ )
16
+
17
+ # Preprocess opened img
18
+ x = preprocess(image)
19
+
20
+ # launch inference on cpu
21
+ x = torch.unsqueeze(x, dim=0).cpu()
22
+ model = model.cpu()
23
+
24
+ with torch.no_grad():
25
+ feats, code = model.net(x)
26
+ linear_pred = model.linear_probe(x, code)
27
+ linear_pred = linear_pred.argmax(1)
28
+ output = {
29
+ "img": x[: model.cfg.n_images].detach().cpu(),
30
+ "linear_preds": linear_pred[: model.cfg.n_images].detach().cpu(),
31
+ }
32
+
33
+ img, label, labeled_img = transform_to_pil(output)
34
+ return img, labeled_img, label
35
+
36
+
37
+ if __name__ == "__main__":
38
+ import hydra
39
+ from model import LitUnsupervisedSegmenter
40
+ from utils_gee import extract_img, transform_ee_img
41
+ latitude = 2.98
42
+ longitude = 48.81
43
+ start_date = '2020-03-20'
44
+ end_date = '2020-04-20'
45
+
46
+ location = [float(latitude), float(longitude)]
47
+ # Extract img numpy from earth engine and transform it to PIL img
48
+ img = extract_img(location, start_date, end_date)
49
+ image = transform_ee_img(
50
+ img, max=0.3
51
+ ) # max value is the value from numpy file that will be equal to 255
52
+ print("image loaded")
53
+ # Initialize hydra with configs
54
+ hydra.initialize(config_path="configs", job_name="corine")
55
+ cfg = hydra.compose(config_name="my_train_config.yml")
56
+
57
+ # Load the model
58
+ model_path = "checkpoint/model/model.pt"
59
+ saved_state_dict = torch.load(model_path, map_location=torch.device("cpu"))
60
+
61
+ nbclasses = cfg.dir_dataset_n_classes
62
+
63
+ model = LitUnsupervisedSegmenter(nbclasses, cfg)
64
+ print("model initialized")
65
+ model.load_state_dict(saved_state_dict)
66
+ print("model loaded")
67
+ # img.save("output/image.png")
68
+ img, labeled_img, label = inference(image, model)
69
+ img.save("output/img.png")
70
+ label.save("output/label.png")
71
+ labeled_img.save("output/labeled_img.png")
72
+
73
+
74
+
75
+
76
+ # def get_list_date(start_date, end_date):
77
+ # """Get all the date between the start date and the end date
78
+
79
+ # Args:
80
+ # start_date (str): start date at the format '%Y-%m-%d'
81
+ # end_date (str): end date at the format '%Y-%m-%d'
82
+
83
+ # Returns:
84
+ # list[str]: all the date between the start date and the end date
85
+ # """
86
+ # start_date = datetime.datetime.strptime(start_date, "%Y-%m-%d").date()
87
+ # end_date = datetime.datetime.strptime(end_date, "%Y-%m-%d").date()
88
+ # list_date = [start_date]
89
+ # date = start_date
90
+ # while date < end_date:
91
+ # date = date + datetime.timedelta(days=1)
92
+ # list_date.append(date)
93
+ # list_date.append(end_date)
94
+ # list_date2 = [x.strftime("%Y-%m-%d") for x in list_date]
95
+ # return list_date2
96
+
97
+
98
+ # def get_length_interval(start_date, end_date):
99
+ # """Return how many days there is between the start date and the end date
100
+
101
+ # Args:
102
+ # start_date (str): start date at the format '%Y-%m-%d'
103
+ # end_date (str): end date at the format '%Y-%m-%d'
104
+
105
+ # Returns:
106
+ # int : number of days between start date and the end date
107
+ # """
108
+ # try:
109
+ # return len(get_list_date(start_date, end_date))
110
+ # except ValueError:
111
+ # return 0
112
+
113
+
114
+ # def infer_unique_date(latitude, longitude, date, model=model):
115
+ # """Perform an inference on a latitude and a longitude at a specific date
116
+
117
+ # Args:
118
+ # latitude (float): the latitude of the landscape
119
+ # longitude (float): the longitude of the landscape
120
+ # date (str): date for the inference at the format '%Y-%m-%d'
121
+ # model (_type_, optional): _description_. Defaults to model.
122
+
123
+ # Returns:
124
+ # img, labeled_img,biodiv_score: the original landscape, the labeled landscape and the biodiversity score and the landscape
125
+ # """
126
+ # start_date = date
127
+ # end_date = date
128
+ # location = [float(latitude), float(longitude)]
129
+ # # Extract img numpy from earth engine and transform it to PIL img
130
+ # img = extract_img(location, start_date, end_date)
131
+ # img_test = transform_ee_img(
132
+ # img, max=0.3
133
+ # ) # max value is the value from numpy file that will be equal to 255
134
+
135
+ # # tensorize & normalize img
136
+ # preprocess = T.Compose(
137
+ # [
138
+ # T.ToPILImage(),
139
+ # T.Resize((320, 320)),
140
+ # # T.CenterCrop(224),
141
+ # T.ToTensor(),
142
+ # T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
143
+ # ]
144
+ # )
145
+
146
+ # # Preprocess opened img
147
+ # x = preprocess(img_test)
148
+
149
+ # # launch inference on cpu
150
+ # x = torch.unsqueeze(x, dim=0).cpu()
151
+ # model = model.cpu()
152
+
153
+ # with torch.no_grad():
154
+ # feats, code = model.net(x)
155
+ # linear_pred = model.linear_probe(x, code)
156
+ # linear_pred = linear_pred.argmax(1)
157
+ # output = {
158
+ # "img": x[: model.cfg.n_images].detach().cpu(),
159
+ # "linear_preds": linear_pred[: model.cfg.n_images].detach().cpu(),
160
+ # }
161
+
162
+ # img, label, labeled_img = transform_to_pil(output)
163
+ # biodiv_score = compute_biodiv_score(labeled_img)
164
+ # return img, labeled_img, biodiv_score
165
+
166
+
167
+ # def get_img_array(start_date, end_date, latitude, longitude, model=model):
168
+ # list_date = get_list_date(start_date, end_date)
169
+ # list_img = []
170
+ # for date in list_date:
171
+ # list_img.append(img)
172
+ # return list_img
173
+
174
+
175
+ # def variable_outputs(start_date, end_date, latitude, longitude, day, model=model):
176
+ # """Perform an inference on the day number day starting from the start at the latitude and longitude selected
177
+
178
+ # Args:
179
+ # latitude (float): the latitude of the landscape
180
+ # longitude (float): the longitude of the landscape
181
+ # start_date (str): the start date for our inference
182
+ # end_date (str): the end date for our inference
183
+ # model (_type_, optional): _description_. Defaults to model.
184
+
185
+ # Returns:
186
+ # img,labeled_img,biodiv_score: the original landscape, the labeled landscape and the biodiversity score and the landscape at the selected, longitude, latitude and date
187
+ # """
188
+ # list_date = get_list_date(start_date, end_date)
189
+ # k = int(day)
190
+ # date = list_date[k]
191
+ # img, labeled_img, biodiv_score = infer_unique_date(
192
+ # latitude, longitude, date, model=model
193
+ # )
194
+ # return img, labeled_img, biodiv_score
195
+
196
+
197
+ # def variable_outputs2(
198
+ # start_date, end_date, latitude, longitude, day_number, model=model
199
+ # ):
200
+ # """Perform an inference on the day number day starting from the start at the latitude and longitude selected
201
+
202
+ # Args:
203
+ # latitude (float): the latitude of the landscape
204
+ # longitude (float): the longitude of the landscape
205
+ # start_date (str): the start date for our inference
206
+ # end_date (str): the end date for our inference
207
+ # model (_type_, optional): _description_. Defaults to model.
208
+
209
+ # Returns:
210
+ # list[img,labeled_img,biodiv_score]: the original landscape, the labeled landscape and the biodiversity score and the landscape at the selected, longitude, latitude and date
211
+ # """
212
+ # list_date = get_list_date(start_date, end_date)
213
+ # k = int(day_number)
214
+ # date = list_date[k]
215
+ # img, labeled_img, biodiv_score = infer_unique_date(
216
+ # latitude, longitude, date, model=model
217
+ # )
218
+ # return [img, labeled_img, biodiv_score]
219
+
220
+
221
+ # def gif_maker(img_array):
222
+ # output_file = "test2.mkv"
223
+ # image_test = img_array[0]
224
+ # size = (320, 320)
225
+ # print(size)
226
+ # out = cv2.VideoWriter(
227
+ # output_file, cv2.VideoWriter_fourcc(*"avc1"), 15, frameSize=size
228
+ # )
229
+ # for i in range(len(img_array)):
230
+ # image = img_array[i]
231
+ # pix = np.array(image.getdata())
232
+ # out.write(pix)
233
+ # out.release()
234
+ # return output_file
235
+
236
+
237
+ # def infer_multiple_date(start_date, end_date, latitude, longitude, model=model):
238
+ # """Perform an inference on all the dates between the start date and the end date at the latitude and longitude
239
+
240
+ # Args:
241
+ # latitude (float): the latitude of the landscape
242
+ # longitude (float): the longitude of the landscape
243
+ # start_date (str): the start date for our inference
244
+ # end_date (str): the end date for our inference
245
+ # model (_type_, optional): _description_. Defaults to model.
246
+
247
+ # Returns:
248
+ # list_img,list_labeled_img,list_biodiv_score: list of the original landscape, the labeled landscape and the biodiversity score and the landscape
249
+ # """
250
+ # list_date = get_list_date(start_date, end_date)
251
+ # list_img = []
252
+ # list_labeled_img = []
253
+ # list_biodiv_score = []
254
+ # for date in list_date:
255
+ # img, labeled_img, biodiv_score = infer_unique_date(
256
+ # latitude, longitude, date, model=model
257
+ # )
258
+ # list_img.append(img)
259
+ # list_labeled_img.append(labeled_img)
260
+ # list_biodiv_score.append(biodiv_score)
261
+ # return gif_maker(list_img), gif_maker(list_labeled_img), list_biodiv_score[0]
biomap/label.png ADDED
biomap/model.py ADDED
@@ -0,0 +1,453 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils import *
2
+ from modules import *
3
+ from data import *
4
+ import torch.nn.functional as F
5
+ import pytorch_lightning as pl
6
+ import torch.multiprocessing
7
+ import seaborn as sns
8
+ import unet
9
+
10
+ class LitUnsupervisedSegmenter(pl.LightningModule):
11
+ def __init__(self, n_classes, cfg):
12
+ super().__init__()
13
+ self.cfg = cfg
14
+ self.n_classes = n_classes
15
+
16
+ if not cfg.continuous:
17
+ dim = n_classes
18
+ else:
19
+ dim = cfg.dim
20
+
21
+ data_dir = join(cfg.output_root, "data")
22
+ if cfg.arch == "feature-pyramid":
23
+ cut_model = load_model(cfg.model_type, data_dir).cuda()
24
+ self.net = FeaturePyramidNet(
25
+ cfg.granularity, cut_model, dim, cfg.continuous
26
+ )
27
+ elif cfg.arch == "dino":
28
+ self.net = DinoFeaturizer(dim, cfg)
29
+ else:
30
+ raise ValueError("Unknown arch {}".format(cfg.arch))
31
+
32
+ self.train_cluster_probe = ClusterLookup(dim, n_classes)
33
+
34
+ self.cluster_probe = ClusterLookup(dim, n_classes + cfg.extra_clusters)
35
+ # self.linear_probe = nn.Conv2d(dim, n_classes, (1, 1))
36
+ # self.linear_probe = nn.Sequential(OrderedDict([
37
+ # ('conv1', nn.Conv2d(dim, 2*n_classes, (7, 7), padding='same')),
38
+ # ('relu1', nn.ReLU()),
39
+ # ('conv2', nn.Conv2d(2*n_classes, n_classes, (3, 3), padding='same'))
40
+ # ]))
41
+ self.linear_probe = unet.AuxUNet(
42
+ enc_chs=(3, 32, 64, 128, 256),
43
+ dec_chs=(256, 128, 64, 32),
44
+ aux_ch=70,
45
+ num_class=n_classes,
46
+ )
47
+
48
+ self.decoder = nn.Conv2d(dim, self.net.n_feats, (1, 1))
49
+
50
+ self.cluster_metrics = UnsupervisedMetrics(
51
+ "test/cluster/", n_classes, cfg.extra_clusters, True
52
+ )
53
+ self.linear_metrics = UnsupervisedMetrics("test/linear/", n_classes, 0, False)
54
+
55
+ self.test_cluster_metrics = UnsupervisedMetrics(
56
+ "final/cluster/", n_classes, cfg.extra_clusters, True
57
+ )
58
+ self.test_linear_metrics = UnsupervisedMetrics(
59
+ "final/linear/", n_classes, 0, False
60
+ )
61
+
62
+ self.linear_probe_loss_fn = torch.nn.CrossEntropyLoss()
63
+ self.crf_loss_fn = ContrastiveCRFLoss(
64
+ cfg.crf_samples, cfg.alpha, cfg.beta, cfg.gamma, cfg.w1, cfg.w2, cfg.shift
65
+ )
66
+
67
+ self.contrastive_corr_loss_fn = ContrastiveCorrelationLoss(cfg)
68
+ for p in self.contrastive_corr_loss_fn.parameters():
69
+ p.requires_grad = False
70
+
71
+ self.automatic_optimization = False
72
+
73
+ if self.cfg.dataset_name.startswith("cityscapes"):
74
+ self.label_cmap = create_cityscapes_colormap()
75
+ else:
76
+ self.label_cmap = create_pascal_label_colormap()
77
+
78
+ self.val_steps = 0
79
+ self.save_hyperparameters()
80
+
81
+ def forward(self, x):
82
+ # in lightning, forward defines the prediction/inference actions
83
+ return self.net(x)[1]
84
+
85
+ def training_step(self, batch, batch_idx):
86
+ # training_step defined the train loop.
87
+ # It is independent of forward
88
+ net_optim, linear_probe_optim, cluster_probe_optim = self.optimizers()
89
+
90
+ net_optim.zero_grad()
91
+ linear_probe_optim.zero_grad()
92
+ cluster_probe_optim.zero_grad()
93
+
94
+ with torch.no_grad():
95
+ ind = batch["ind"]
96
+ img = batch["img"]
97
+ img_aug = batch["img_aug"]
98
+ coord_aug = batch["coord_aug"]
99
+ img_pos = batch["img_pos"]
100
+ label = batch["label"]
101
+ label_pos = batch["label_pos"]
102
+
103
+ feats, code = self.net(img)
104
+ if self.cfg.correspondence_weight > 0:
105
+ feats_pos, code_pos = self.net(img_pos)
106
+ log_args = dict(sync_dist=False, rank_zero_only=True)
107
+
108
+ if self.cfg.use_true_labels:
109
+ signal = one_hot_feats(label + 1, self.n_classes + 1)
110
+ signal_pos = one_hot_feats(label_pos + 1, self.n_classes + 1)
111
+ else:
112
+ signal = feats
113
+ signal_pos = feats_pos
114
+
115
+ loss = 0
116
+
117
+ should_log_hist = (
118
+ (self.cfg.hist_freq is not None)
119
+ and (self.global_step % self.cfg.hist_freq == 0)
120
+ and (self.global_step > 0)
121
+ )
122
+ if self.cfg.use_salience:
123
+ salience = batch["mask"].to(torch.float32).squeeze(1)
124
+ salience_pos = batch["mask_pos"].to(torch.float32).squeeze(1)
125
+ else:
126
+ salience = None
127
+ salience_pos = None
128
+
129
+ if self.cfg.correspondence_weight > 0:
130
+ (
131
+ pos_intra_loss,
132
+ pos_intra_cd,
133
+ pos_inter_loss,
134
+ pos_inter_cd,
135
+ neg_inter_loss,
136
+ neg_inter_cd,
137
+ ) = self.contrastive_corr_loss_fn(
138
+ signal,
139
+ signal_pos,
140
+ salience,
141
+ salience_pos,
142
+ code,
143
+ code_pos,
144
+ )
145
+
146
+ if should_log_hist:
147
+ self.logger.experiment.add_histogram(
148
+ "intra_cd", pos_intra_cd, self.global_step
149
+ )
150
+ self.logger.experiment.add_histogram(
151
+ "inter_cd", pos_inter_cd, self.global_step
152
+ )
153
+ self.logger.experiment.add_histogram(
154
+ "neg_cd", neg_inter_cd, self.global_step
155
+ )
156
+ neg_inter_loss = neg_inter_loss.mean()
157
+ pos_intra_loss = pos_intra_loss.mean()
158
+ pos_inter_loss = pos_inter_loss.mean()
159
+ self.log("loss/pos_intra", pos_intra_loss, **log_args)
160
+ self.log("loss/pos_inter", pos_inter_loss, **log_args)
161
+ self.log("loss/neg_inter", neg_inter_loss, **log_args)
162
+ self.log("cd/pos_intra", pos_intra_cd.mean(), **log_args)
163
+ self.log("cd/pos_inter", pos_inter_cd.mean(), **log_args)
164
+ self.log("cd/neg_inter", neg_inter_cd.mean(), **log_args)
165
+
166
+ loss += (
167
+ self.cfg.pos_inter_weight * pos_inter_loss
168
+ + self.cfg.pos_intra_weight * pos_intra_loss
169
+ + self.cfg.neg_inter_weight * neg_inter_loss
170
+ ) * self.cfg.correspondence_weight
171
+
172
+ if self.cfg.rec_weight > 0:
173
+ rec_feats = self.decoder(code)
174
+ rec_loss = -(norm(rec_feats) * norm(feats)).sum(1).mean()
175
+ self.log("loss/rec", rec_loss, **log_args)
176
+ loss += self.cfg.rec_weight * rec_loss
177
+
178
+ if self.cfg.aug_alignment_weight > 0:
179
+ orig_feats_aug, orig_code_aug = self.net(img_aug)
180
+ downsampled_coord_aug = resize(
181
+ coord_aug.permute(0, 3, 1, 2), orig_code_aug.shape[2]
182
+ ).permute(0, 2, 3, 1)
183
+ aug_alignment = -torch.einsum(
184
+ "bkhw,bkhw->bhw",
185
+ norm(sample(code, downsampled_coord_aug)),
186
+ norm(orig_code_aug),
187
+ ).mean()
188
+ self.log("loss/aug_alignment", aug_alignment, **log_args)
189
+ loss += self.cfg.aug_alignment_weight * aug_alignment
190
+
191
+ if self.cfg.crf_weight > 0:
192
+ crf = self.crf_loss_fn(resize(img, 56), norm(resize(code, 56))).mean()
193
+ self.log("loss/crf", crf, **log_args)
194
+ loss += self.cfg.crf_weight * crf
195
+
196
+ flat_label = label.reshape(-1)
197
+ mask = (flat_label >= 0) & (flat_label < self.n_classes)
198
+
199
+ detached_code = torch.clone(code.detach())
200
+
201
+ # pdb.set_trace()
202
+
203
+ linear_logits = self.linear_probe(img, detached_code)
204
+ linear_logits = F.interpolate(
205
+ linear_logits, label.shape[-2:], mode="bilinear", align_corners=False
206
+ )
207
+ linear_logits = linear_logits.permute(0, 2, 3, 1).reshape(-1, self.n_classes)
208
+ linear_loss = self.linear_probe_loss_fn(
209
+ linear_logits[mask], flat_label[mask]
210
+ ).mean()
211
+ loss += linear_loss
212
+ self.log("loss/linear", linear_loss, **log_args)
213
+
214
+ cluster_loss, cluster_probs = self.cluster_probe(detached_code, None)
215
+ loss += cluster_loss
216
+ self.log("loss/cluster", cluster_loss, **log_args)
217
+ self.log("loss/total", loss, **log_args)
218
+
219
+ self.manual_backward(loss)
220
+ net_optim.step()
221
+ cluster_probe_optim.step()
222
+ linear_probe_optim.step()
223
+
224
+ if (
225
+ self.cfg.reset_probe_steps is not None
226
+ and self.global_step == self.cfg.reset_probe_steps
227
+ ):
228
+ print("RESETTING PROBES")
229
+ self.linear_probe.reset_parameters()
230
+ self.cluster_probe.reset_parameters()
231
+ self.trainer.optimizers[1] = torch.optim.Adam(
232
+ list(self.linear_probe.parameters()), lr=5e-3
233
+ )
234
+ self.trainer.optimizers[2] = torch.optim.Adam(
235
+ list(self.cluster_probe.parameters()), lr=5e-3
236
+ )
237
+
238
+ if self.global_step % 2000 == 0 and self.global_step > 0:
239
+ print("RESETTING TFEVENT FILE")
240
+ # Make a new tfevent file
241
+ self.logger.experiment.close()
242
+ self.logger.experiment._get_file_writer()
243
+
244
+ return loss
245
+
246
+ def on_train_start(self):
247
+ tb_metrics = {**self.linear_metrics.compute(), **self.cluster_metrics.compute()}
248
+ self.logger.log_hyperparams(self.cfg, tb_metrics)
249
+
250
+ def validation_step(self, batch, batch_idx):
251
+ img = batch["img"]
252
+ label = batch["label"]
253
+ self.net.eval()
254
+
255
+ with torch.no_grad():
256
+ feats, code = self.net(img)
257
+
258
+ # code = F.interpolate(code, label.shape[-2:], mode='bilinear', align_corners=False)
259
+ # linear_preds = self.linear_probe(code)
260
+ linear_preds = self.linear_probe(img, code)
261
+ linear_preds = linear_preds.argmax(1)
262
+ self.linear_metrics.update(linear_preds, label)
263
+
264
+ code = F.interpolate(
265
+ code, label.shape[-2:], mode="bilinear", align_corners=False
266
+ )
267
+ cluster_loss, cluster_preds = self.cluster_probe(code, None)
268
+ cluster_preds = cluster_preds.argmax(1)
269
+ self.cluster_metrics.update(cluster_preds, label)
270
+
271
+ return {
272
+ "img": img[: self.cfg.n_images].detach().cpu(),
273
+ "linear_preds": linear_preds[: self.cfg.n_images].detach().cpu(),
274
+ "cluster_preds": cluster_preds[: self.cfg.n_images].detach().cpu(),
275
+ "label": label[: self.cfg.n_images].detach().cpu(),
276
+ }
277
+
278
+ def validation_epoch_end(self, outputs) -> None:
279
+ super().validation_epoch_end(outputs)
280
+ with torch.no_grad():
281
+ tb_metrics = {
282
+ **self.linear_metrics.compute(),
283
+ **self.cluster_metrics.compute(),
284
+ }
285
+
286
+ if self.trainer.is_global_zero and not self.cfg.submitting_to_aml:
287
+ # output_num = 0
288
+ output_num = random.randint(0, len(outputs) - 1)
289
+ output = {k: v.detach().cpu() for k, v in outputs[output_num].items()}
290
+
291
+ # pdb.set_trace()
292
+ alpha = 0.4
293
+ n_rows = 6
294
+ fig, ax = plt.subplots(
295
+ n_rows,
296
+ self.cfg.n_images,
297
+ figsize=(self.cfg.n_images * 3, n_rows * 3),
298
+ )
299
+ for i in range(self.cfg.n_images):
300
+ try:
301
+ rbg_img = prep_for_plot(output["img"][i])
302
+ true_label = output["label"].squeeze()[i]
303
+ true_label[true_label == -1] = 7
304
+ except:
305
+ continue
306
+ # ax[0, i].imshow(prep_for_plot(output["img"][i]))
307
+ # ax[1, i].imshow(self.label_cmap[output["label"].squeeze()[i]])
308
+ # ax[2, i].imshow(self.label_cmap[output["linear_preds"][i]])
309
+ # ax[3, i].imshow(self.label_cmap[self.cluster_metrics.map_clusters(output["cluster_preds"][i])])
310
+ ax[0, i].imshow(rbg_img)
311
+
312
+ ax[1, i].imshow(rbg_img)
313
+ ax[1, i].imshow(true_label, alpha=alpha, cmap=cmap, norm=norm)
314
+
315
+ ax[2, i].imshow(rbg_img)
316
+ pred_label = output["linear_preds"][i]
317
+ ax[2, i].imshow(pred_label, alpha=alpha, cmap=cmap, norm=norm)
318
+
319
+ ax[3, i].imshow(rbg_img)
320
+ retouched_label = retouch_label(pred_label.numpy(), true_label)
321
+ ax[3, i].imshow(retouched_label, alpha=alpha, cmap=cmap, norm=norm)
322
+
323
+ ax[4, i].imshow(rbg_img)
324
+ pred_label = self.cluster_metrics.map_clusters(
325
+ output["cluster_preds"][i]
326
+ )
327
+ ax[4, i].imshow(pred_label, alpha=alpha, cmap=cmap, norm=norm)
328
+ # ax[3, i].imshow(map_clusters_with_label(true_label, pred_label), alpha=0.5, cmap=cmap, norm=norm)
329
+
330
+ ax[5, i].imshow(rbg_img)
331
+ retouched_label = retouch_label(pred_label.numpy(), true_label)
332
+ ax[5, i].imshow(retouched_label, alpha=alpha, cmap=cmap, norm=norm)
333
+
334
+ ax[0, 0].set_ylabel("Image", fontsize=16)
335
+ ax[1, 0].set_ylabel("Label", fontsize=16)
336
+ ax[2, 0].set_ylabel("UNet Probe", fontsize=16)
337
+ ax[3, 0].set_ylabel("Retouched UNet Probe", fontsize=16)
338
+ ax[4, 0].set_ylabel("Cluster Probe", fontsize=16)
339
+ ax[5, 0].set_ylabel("Retouched cluster Probe", fontsize=16)
340
+ remove_axes(ax)
341
+ plt.tight_layout()
342
+ add_plot(self.logger.experiment, "plot_labels", self.global_step)
343
+
344
+ if self.cfg.has_labels:
345
+ fig = plt.figure(figsize=(13, 10))
346
+ ax = fig.gca()
347
+ hist = (
348
+ self.cluster_metrics.histogram.detach().cpu().to(torch.float32)
349
+ )
350
+ hist /= torch.clamp_min(hist.sum(dim=0, keepdim=True), 1)
351
+ sns.heatmap(hist.t(), annot=False, fmt="g", ax=ax, cmap="Blues")
352
+ ax.set_xlabel("Predicted labels")
353
+ ax.set_ylabel("True labels")
354
+ names = get_class_labels(self.cfg.dataset_name)
355
+ if self.cfg.extra_clusters:
356
+ names = names + ["Extra"]
357
+ ax.set_xticks(np.arange(0, len(names)) + 0.5)
358
+ ax.set_yticks(np.arange(0, len(names)) + 0.5)
359
+ ax.xaxis.tick_top()
360
+ ax.xaxis.set_ticklabels(names, fontsize=14)
361
+ ax.yaxis.set_ticklabels(names, fontsize=14)
362
+ colors = [self.label_cmap[i] / 255.0 for i in range(len(names))]
363
+ [
364
+ t.set_color(colors[i])
365
+ for i, t in enumerate(ax.xaxis.get_ticklabels())
366
+ ]
367
+ [
368
+ t.set_color(colors[i])
369
+ for i, t in enumerate(ax.yaxis.get_ticklabels())
370
+ ]
371
+ # ax.yaxis.get_ticklabels()[-1].set_color(self.label_cmap[0] / 255.0)
372
+ # ax.xaxis.get_ticklabels()[-1].set_color(self.label_cmap[0] / 255.0)
373
+ plt.xticks(rotation=90)
374
+ plt.yticks(rotation=0)
375
+ ax.vlines(
376
+ np.arange(0, len(names) + 1),
377
+ color=[0.5, 0.5, 0.5],
378
+ *ax.get_xlim()
379
+ )
380
+ ax.hlines(
381
+ np.arange(0, len(names) + 1),
382
+ color=[0.5, 0.5, 0.5],
383
+ *ax.get_ylim()
384
+ )
385
+ plt.tight_layout()
386
+ add_plot(self.logger.experiment, "conf_matrix", self.global_step)
387
+
388
+ all_bars = torch.cat(
389
+ [
390
+ self.cluster_metrics.histogram.sum(0).cpu(),
391
+ self.cluster_metrics.histogram.sum(1).cpu(),
392
+ ],
393
+ axis=0,
394
+ )
395
+ ymin = max(all_bars.min() * 0.8, 1)
396
+ ymax = all_bars.max() * 1.2
397
+
398
+ fig, ax = plt.subplots(1, 2, figsize=(2 * 5, 1 * 4))
399
+ ax[0].bar(
400
+ range(self.n_classes + self.cfg.extra_clusters),
401
+ self.cluster_metrics.histogram.sum(0).cpu(),
402
+ tick_label=names,
403
+ color=colors,
404
+ )
405
+ ax[0].set_ylim(ymin, ymax)
406
+ ax[0].set_title("Label Frequency")
407
+ ax[0].set_yscale("log")
408
+ ax[0].tick_params(axis="x", labelrotation=90)
409
+
410
+ ax[1].bar(
411
+ range(self.n_classes + self.cfg.extra_clusters),
412
+ self.cluster_metrics.histogram.sum(1).cpu(),
413
+ tick_label=names,
414
+ color=colors,
415
+ )
416
+ ax[1].set_ylim(ymin, ymax)
417
+ ax[1].set_title("Cluster Frequency")
418
+ ax[1].set_yscale("log")
419
+ ax[1].tick_params(axis="x", labelrotation=90)
420
+
421
+ plt.tight_layout()
422
+ add_plot(
423
+ self.logger.experiment, "label frequency", self.global_step
424
+ )
425
+
426
+ if self.global_step > 2:
427
+ self.log_dict(tb_metrics)
428
+
429
+ if self.trainer.is_global_zero and self.cfg.azureml_logging:
430
+ from azureml.core.run import Run
431
+
432
+ run_logger = Run.get_context()
433
+ for metric, value in tb_metrics.items():
434
+ run_logger.log(metric, value)
435
+
436
+ self.linear_metrics.reset()
437
+ self.cluster_metrics.reset()
438
+
439
+ def configure_optimizers(self):
440
+ main_params = list(self.net.parameters())
441
+
442
+ if self.cfg.rec_weight > 0:
443
+ main_params.extend(self.decoder.parameters())
444
+
445
+ net_optim = torch.optim.Adam(main_params, lr=self.cfg.lr)
446
+ linear_probe_optim = torch.optim.Adam(
447
+ list(self.linear_probe.parameters()), lr=5e-3
448
+ )
449
+ cluster_probe_optim = torch.optim.Adam(
450
+ list(self.cluster_probe.parameters()), lr=5e-3
451
+ )
452
+
453
+ return net_optim, linear_probe_optim, cluster_probe_optim
biomap/modules.py ADDED
@@ -0,0 +1,472 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from utils import *
4
+ import torch.nn.functional as F
5
+ import dino.vision_transformer as vits
6
+
7
+ import pdb
8
+
9
+ class LambdaLayer(nn.Module):
10
+ def __init__(self, lambd):
11
+ super(LambdaLayer, self).__init__()
12
+ self.lambd = lambd
13
+
14
+ def forward(self, x):
15
+ return self.lambd(x)
16
+
17
+
18
+ class DinoFeaturizer(nn.Module):
19
+
20
+ def __init__(self, dim, cfg):
21
+ super().__init__()
22
+ self.cfg = cfg
23
+ self.dim = dim
24
+ patch_size = self.cfg.dino_patch_size
25
+ self.patch_size = patch_size
26
+ self.feat_type = self.cfg.dino_feat_type
27
+ arch = self.cfg.model_type
28
+ self.model = vits.__dict__[arch](
29
+ patch_size=patch_size,
30
+ num_classes=0)
31
+ for p in self.model.parameters():
32
+ p.requires_grad = False
33
+ # pdb.set_trace()
34
+ self.model=self.model.cpu()
35
+ self.model.eval()
36
+ self.dropout = torch.nn.Dropout2d(p=.1)
37
+
38
+ if arch == "vit_small" and patch_size == 16:
39
+ url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
40
+ elif arch == "vit_small" and patch_size == 8:
41
+ url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth"
42
+ elif arch == "vit_base" and patch_size == 16:
43
+ url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
44
+ elif arch == "vit_base" and patch_size == 8:
45
+ url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
46
+ else:
47
+ raise ValueError("Unknown arch and patch size")
48
+
49
+ if cfg.pretrained_weights is not None:
50
+ state_dict = torch.load(cfg.pretrained_weights, map_location="cpu")
51
+ state_dict = state_dict["teacher"]
52
+ # remove `module.` prefix
53
+ state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
54
+ # remove `backbone.` prefix induced by multicrop wrapper
55
+ state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
56
+
57
+ # state_dict = {k.replace("projection_head", "mlp"): v for k, v in state_dict.items()}
58
+ # state_dict = {k.replace("prototypes", "last_layer"): v for k, v in state_dict.items()}
59
+
60
+ msg = self.model.load_state_dict(state_dict, strict=False)
61
+ print('Pretrained weights found at {} and loaded with msg: {}'.format(cfg.pretrained_weights, msg))
62
+ else:
63
+ print("Since no pretrained weights have been provided, we load the reference pretrained DINO weights.")
64
+ state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
65
+ self.model.load_state_dict(state_dict, strict=True)
66
+
67
+ if arch == "vit_small":
68
+ self.n_feats = 384
69
+ else:
70
+ self.n_feats = 768
71
+ self.cluster1 = self.make_clusterer(self.n_feats)
72
+ self.proj_type = cfg.projection_type
73
+ if self.proj_type == "nonlinear":
74
+ self.cluster2 = self.make_nonlinear_clusterer(self.n_feats)
75
+
76
+ def make_clusterer(self, in_channels):
77
+ return torch.nn.Sequential(
78
+ torch.nn.Conv2d(in_channels, self.dim, (1, 1))) # ,
79
+
80
+ def make_nonlinear_clusterer(self, in_channels):
81
+ return torch.nn.Sequential(
82
+ torch.nn.Conv2d(in_channels, in_channels, (1, 1)),
83
+ torch.nn.ReLU(),
84
+ torch.nn.Conv2d(in_channels, self.dim, (1, 1)))
85
+
86
+ def forward(self, img, n=1, return_class_feat=False):
87
+ self.model.eval()
88
+ with torch.no_grad():
89
+ assert (img.shape[2] % self.patch_size == 0)
90
+ assert (img.shape[3] % self.patch_size == 0)
91
+
92
+ # get selected layer activations
93
+ feat, attn, qkv = self.model.get_intermediate_feat(img, n=n)
94
+ feat, attn, qkv = feat[0], attn[0], qkv[0]
95
+
96
+ feat_h = img.shape[2] // self.patch_size
97
+ feat_w = img.shape[3] // self.patch_size
98
+
99
+ if self.feat_type == "feat":
100
+ image_feat = feat[:, 1:, :].reshape(feat.shape[0], feat_h, feat_w, -1).permute(0, 3, 1, 2)
101
+ elif self.feat_type == "KK":
102
+ image_k = qkv[1, :, :, 1:, :].reshape(feat.shape[0], 6, feat_h, feat_w, -1)
103
+ B, H, I, J, D = image_k.shape
104
+ image_feat = image_k.permute(0, 1, 4, 2, 3).reshape(B, H * D, I, J)
105
+ else:
106
+ raise ValueError("Unknown feat type:{}".format(self.feat_type))
107
+
108
+ if return_class_feat:
109
+ return feat[:, :1, :].reshape(feat.shape[0], 1, 1, -1).permute(0, 3, 1, 2)
110
+
111
+ if self.proj_type is not None:
112
+ code = self.cluster1(self.dropout(image_feat))
113
+ if self.proj_type == "nonlinear":
114
+ code += self.cluster2(self.dropout(image_feat))
115
+ else:
116
+ code = image_feat
117
+
118
+ if self.cfg.dropout:
119
+ return self.dropout(image_feat), code
120
+ else:
121
+ return image_feat, code
122
+
123
+
124
+ class ResizeAndClassify(nn.Module):
125
+
126
+ def __init__(self, dim: int, size: int, n_classes: int):
127
+ super(ResizeAndClassify, self).__init__()
128
+ self.size = size
129
+ self.predictor = torch.nn.Sequential(
130
+ torch.nn.Conv2d(dim, n_classes, (1, 1)),
131
+ torch.nn.LogSoftmax(1))
132
+
133
+ def forward(self, x):
134
+ return F.interpolate(self.predictor.forward(x), self.size, mode="bilinear", align_corners=False)
135
+
136
+
137
+ class ClusterLookup(nn.Module):
138
+
139
+ def __init__(self, dim: int, n_classes: int):
140
+ super(ClusterLookup, self).__init__()
141
+ self.n_classes = n_classes
142
+ self.dim = dim
143
+ self.clusters = torch.nn.Parameter(torch.randn(n_classes, dim))
144
+
145
+ def reset_parameters(self):
146
+ with torch.no_grad():
147
+ self.clusters.copy_(torch.randn(self.n_classes, self.dim))
148
+
149
+ def forward(self, x, alpha, log_probs=False):
150
+ normed_clusters = F.normalize(self.clusters, dim=1)
151
+ normed_features = F.normalize(x, dim=1)
152
+ inner_products = torch.einsum("bchw,nc->bnhw", normed_features, normed_clusters)
153
+
154
+ if alpha is None:
155
+ cluster_probs = F.one_hot(torch.argmax(inner_products, dim=1), self.clusters.shape[0]) \
156
+ .permute(0, 3, 1, 2).to(torch.float32)
157
+ else:
158
+ cluster_probs = nn.functional.softmax(inner_products * alpha, dim=1)
159
+
160
+ cluster_loss = -(cluster_probs * inner_products).sum(1).mean()
161
+ if log_probs:
162
+ return nn.functional.log_softmax(inner_products * alpha, dim=1)
163
+ else:
164
+ return cluster_loss, cluster_probs
165
+
166
+
167
+ class FeaturePyramidNet(nn.Module):
168
+
169
+ @staticmethod
170
+ def _helper(x):
171
+ # TODO remove this hard coded 56
172
+ return F.interpolate(x, 56, mode="bilinear", align_corners=False).unsqueeze(-1)
173
+
174
+ def make_clusterer(self, in_channels):
175
+ return torch.nn.Sequential(
176
+ torch.nn.Conv2d(in_channels, self.dim, (1, 1)),
177
+ LambdaLayer(FeaturePyramidNet._helper))
178
+
179
+ def make_nonlinear_clusterer(self, in_channels):
180
+ return torch.nn.Sequential(
181
+ torch.nn.Conv2d(in_channels, in_channels, (1, 1)),
182
+ torch.nn.ReLU(),
183
+ torch.nn.Conv2d(in_channels, in_channels, (1, 1)),
184
+ torch.nn.ReLU(),
185
+ torch.nn.Conv2d(in_channels, self.dim, (1, 1)),
186
+ LambdaLayer(FeaturePyramidNet._helper))
187
+
188
+ def __init__(self, granularity, cut_model, dim, continuous):
189
+ super(FeaturePyramidNet, self).__init__()
190
+ self.layer_nums = [5, 6, 7]
191
+ self.spatial_resolutions = [7, 14, 28, 56]
192
+ self.feat_channels = [2048, 1024, 512, 3]
193
+ self.extra_channels = [128, 64, 32, 32]
194
+ self.granularity = granularity
195
+ self.encoder = NetWithActivations(cut_model, self.layer_nums)
196
+ self.dim = dim
197
+ self.continuous = continuous
198
+ self.n_feats = self.dim
199
+
200
+ self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
201
+
202
+ assert granularity in {1, 2, 3, 4}
203
+ self.cluster1 = self.make_clusterer(self.feat_channels[0])
204
+ self.cluster1_nl = self.make_nonlinear_clusterer(self.feat_channels[0])
205
+
206
+ if granularity >= 2:
207
+ # self.conv1 = DoubleConv(self.feat_channels[0], self.extra_channels[0])
208
+ # self.conv2 = DoubleConv(self.extra_channels[0] + self.feat_channels[1], self.extra_channels[1])
209
+ self.conv2 = DoubleConv(self.feat_channels[0] + self.feat_channels[1], self.extra_channels[1])
210
+ self.cluster2 = self.make_clusterer(self.extra_channels[1])
211
+ if granularity >= 3:
212
+ self.conv3 = DoubleConv(self.extra_channels[1] + self.feat_channels[2], self.extra_channels[2])
213
+ self.cluster3 = self.make_clusterer(self.extra_channels[2])
214
+ if granularity >= 4:
215
+ self.conv4 = DoubleConv(self.extra_channels[2] + self.feat_channels[3], self.extra_channels[3])
216
+ self.cluster4 = self.make_clusterer(self.extra_channels[3])
217
+
218
+ def c(self, x, y):
219
+ return torch.cat([x, y], dim=1)
220
+
221
+ def forward(self, x):
222
+ with torch.no_grad():
223
+ feats = self.encoder(x)
224
+ low_res_feats = feats[self.layer_nums[-1]]
225
+
226
+ all_clusters = []
227
+
228
+ # all_clusters.append(self.cluster1(low_res_feats) + self.cluster1_nl(low_res_feats))
229
+ all_clusters.append(self.cluster1(low_res_feats))
230
+
231
+ if self.granularity >= 2:
232
+ # f1 = self.conv1(low_res_feats)
233
+ # f1_up = self.up(f1)
234
+ f1_up = self.up(low_res_feats)
235
+ f2 = self.conv2(self.c(f1_up, feats[self.layer_nums[-2]]))
236
+ all_clusters.append(self.cluster2(f2))
237
+ if self.granularity >= 3:
238
+ f2_up = self.up(f2)
239
+ f3 = self.conv3(self.c(f2_up, feats[self.layer_nums[-3]]))
240
+ all_clusters.append(self.cluster3(f3))
241
+ if self.granularity >= 4:
242
+ f3_up = self.up(f3)
243
+ final_size = self.spatial_resolutions[-1]
244
+ f4 = self.conv4(self.c(f3_up, F.interpolate(
245
+ x, (final_size, final_size), mode="bilinear", align_corners=False)))
246
+ all_clusters.append(self.cluster4(f4))
247
+
248
+ avg_code = torch.cat(all_clusters, 4).mean(4)
249
+
250
+ if self.continuous:
251
+ clusters = avg_code
252
+ else:
253
+ clusters = torch.log_softmax(avg_code, 1)
254
+
255
+ return low_res_feats, clusters
256
+
257
+
258
+ class DoubleConv(nn.Module):
259
+ """(convolution => [BN] => ReLU) * 2"""
260
+
261
+ def __init__(self, in_channels, out_channels, mid_channels=None):
262
+ super().__init__()
263
+ if not mid_channels:
264
+ mid_channels = out_channels
265
+ self.double_conv = nn.Sequential(
266
+ nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
267
+ nn.BatchNorm2d(mid_channels),
268
+ nn.ReLU(),
269
+ nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
270
+ nn.BatchNorm2d(out_channels),
271
+ nn.ReLU()
272
+ )
273
+
274
+ def forward(self, x):
275
+ return self.double_conv(x)
276
+
277
+
278
+ def norm(t):
279
+ return F.normalize(t, dim=1, eps=1e-10)
280
+
281
+
282
+ def average_norm(t):
283
+ return t / t.square().sum(1, keepdim=True).sqrt().mean()
284
+
285
+
286
+ def tensor_correlation(a, b):
287
+ return torch.einsum("nchw,ncij->nhwij", a, b)
288
+
289
+
290
+ def sample(t: torch.Tensor, coords: torch.Tensor):
291
+ return F.grid_sample(t, coords.permute(0, 2, 1, 3), padding_mode='border', align_corners=True)
292
+
293
+
294
+ @torch.jit.script
295
+ def super_perm(size: int, device: torch.device):
296
+ perm = torch.randperm(size, device=device, dtype=torch.long)
297
+ perm[perm == torch.arange(size, device=device)] += 1
298
+ return perm % size
299
+
300
+
301
+ def sample_nonzero_locations(t, target_size):
302
+ nonzeros = torch.nonzero(t)
303
+ coords = torch.zeros(target_size, dtype=nonzeros.dtype, device=nonzeros.device)
304
+ n = target_size[1] * target_size[2]
305
+ for i in range(t.shape[0]):
306
+ selected_nonzeros = nonzeros[nonzeros[:, 0] == i]
307
+ if selected_nonzeros.shape[0] == 0:
308
+ selected_coords = torch.randint(t.shape[1], size=(n, 2), device=nonzeros.device)
309
+ else:
310
+ selected_coords = selected_nonzeros[torch.randint(len(selected_nonzeros), size=(n,)), 1:]
311
+ coords[i, :, :, :] = selected_coords.reshape(target_size[1], target_size[2], 2)
312
+ coords = coords.to(torch.float32) / t.shape[1]
313
+ coords = coords * 2 - 1
314
+ return torch.flip(coords, dims=[-1])
315
+
316
+
317
+ class ContrastiveCorrelationLoss(nn.Module):
318
+
319
+ def __init__(self, cfg, ):
320
+ super(ContrastiveCorrelationLoss, self).__init__()
321
+ self.cfg = cfg
322
+
323
+ def standard_scale(self, t):
324
+ t1 = t - t.mean()
325
+ t2 = t1 / t1.std()
326
+ return t2
327
+
328
+ def helper(self, f1, f2, c1, c2, shift):
329
+ with torch.no_grad():
330
+ # Comes straight from backbone which is currently frozen. this saves mem.
331
+ fd = tensor_correlation(norm(f1), norm(f2))
332
+
333
+ if self.cfg.pointwise:
334
+ old_mean = fd.mean()
335
+ fd -= fd.mean([3, 4], keepdim=True)
336
+ fd = fd - fd.mean() + old_mean
337
+
338
+ cd = tensor_correlation(norm(c1), norm(c2))
339
+
340
+ if self.cfg.zero_clamp:
341
+ min_val = 0.0
342
+ else:
343
+ min_val = -9999.0
344
+
345
+ if self.cfg.stabalize:
346
+ loss = - cd.clamp(min_val, .8) * (fd - shift)
347
+ else:
348
+ loss = - cd.clamp(min_val) * (fd - shift)
349
+
350
+ return loss, cd
351
+
352
+ def forward(self,
353
+ orig_feats: torch.Tensor, orig_feats_pos: torch.Tensor,
354
+ orig_salience: torch.Tensor, orig_salience_pos: torch.Tensor,
355
+ orig_code: torch.Tensor, orig_code_pos: torch.Tensor,
356
+ ):
357
+
358
+ coord_shape = [orig_feats.shape[0], self.cfg.feature_samples, self.cfg.feature_samples, 2]
359
+
360
+ if self.cfg.use_salience:
361
+ coords1_nonzero = sample_nonzero_locations(orig_salience, coord_shape)
362
+ coords2_nonzero = sample_nonzero_locations(orig_salience_pos, coord_shape)
363
+ coords1_reg = torch.rand(coord_shape, device=orig_feats.device) * 2 - 1
364
+ coords2_reg = torch.rand(coord_shape, device=orig_feats.device) * 2 - 1
365
+ mask = (torch.rand(coord_shape[:-1], device=orig_feats.device) > .1).unsqueeze(-1).to(torch.float32)
366
+ coords1 = coords1_nonzero * mask + coords1_reg * (1 - mask)
367
+ coords2 = coords2_nonzero * mask + coords2_reg * (1 - mask)
368
+ else:
369
+ coords1 = torch.rand(coord_shape, device=orig_feats.device) * 2 - 1
370
+ coords2 = torch.rand(coord_shape, device=orig_feats.device) * 2 - 1
371
+
372
+ feats = sample(orig_feats, coords1)
373
+ code = sample(orig_code, coords1)
374
+
375
+ feats_pos = sample(orig_feats_pos, coords2)
376
+ code_pos = sample(orig_code_pos, coords2)
377
+
378
+ pos_intra_loss, pos_intra_cd = self.helper(
379
+ feats, feats, code, code, self.cfg.pos_intra_shift)
380
+ pos_inter_loss, pos_inter_cd = self.helper(
381
+ feats, feats_pos, code, code_pos, self.cfg.pos_inter_shift)
382
+
383
+ neg_losses = []
384
+ neg_cds = []
385
+ for i in range(self.cfg.neg_samples):
386
+ perm_neg = super_perm(orig_feats.shape[0], orig_feats.device)
387
+ feats_neg = sample(orig_feats[perm_neg], coords2)
388
+ code_neg = sample(orig_code[perm_neg], coords2)
389
+ neg_inter_loss, neg_inter_cd = self.helper(
390
+ feats, feats_neg, code, code_neg, self.cfg.neg_inter_shift)
391
+ neg_losses.append(neg_inter_loss)
392
+ neg_cds.append(neg_inter_cd)
393
+ neg_inter_loss = torch.cat(neg_losses, axis=0)
394
+ neg_inter_cd = torch.cat(neg_cds, axis=0)
395
+
396
+ return (pos_intra_loss.mean(),
397
+ pos_intra_cd,
398
+ pos_inter_loss.mean(),
399
+ pos_inter_cd,
400
+ neg_inter_loss,
401
+ neg_inter_cd)
402
+
403
+
404
+ class Decoder(nn.Module):
405
+ def __init__(self, code_channels, feat_channels):
406
+ super().__init__()
407
+ self.linear = torch.nn.Conv2d(code_channels, feat_channels, (1, 1))
408
+ self.nonlinear = torch.nn.Sequential(
409
+ torch.nn.Conv2d(code_channels, code_channels, (1, 1)),
410
+ torch.nn.ReLU(),
411
+ torch.nn.Conv2d(code_channels, code_channels, (1, 1)),
412
+ torch.nn.ReLU(),
413
+ torch.nn.Conv2d(code_channels, feat_channels, (1, 1)))
414
+
415
+ def forward(self, x):
416
+ return self.linear(x) + self.nonlinear(x)
417
+
418
+
419
+ class NetWithActivations(torch.nn.Module):
420
+ def __init__(self, model, layer_nums):
421
+ super(NetWithActivations, self).__init__()
422
+ self.layers = nn.ModuleList(model.children())
423
+ self.layer_nums = []
424
+ for l in layer_nums:
425
+ if l < 0:
426
+ self.layer_nums.append(len(self.layers) + l)
427
+ else:
428
+ self.layer_nums.append(l)
429
+ self.layer_nums = set(sorted(self.layer_nums))
430
+
431
+ def forward(self, x):
432
+ activations = {}
433
+ for ln, l in enumerate(self.layers):
434
+ x = l(x)
435
+ if ln in self.layer_nums:
436
+ activations[ln] = x
437
+ return activations
438
+
439
+
440
+ class ContrastiveCRFLoss(nn.Module):
441
+
442
+ def __init__(self, n_samples, alpha, beta, gamma, w1, w2, shift):
443
+ super(ContrastiveCRFLoss, self).__init__()
444
+ self.alpha = alpha
445
+ self.beta = beta
446
+ self.gamma = gamma
447
+ self.w1 = w1
448
+ self.w2 = w2
449
+ self.n_samples = n_samples
450
+ self.shift = shift
451
+
452
+ def forward(self, guidance, clusters):
453
+ device = clusters.device
454
+ assert (guidance.shape[0] == clusters.shape[0])
455
+ assert (guidance.shape[2:] == clusters.shape[2:])
456
+ h = guidance.shape[2]
457
+ w = guidance.shape[3]
458
+
459
+ coords = torch.cat([
460
+ torch.randint(0, h, size=[1, self.n_samples], device=device),
461
+ torch.randint(0, w, size=[1, self.n_samples], device=device)], 0)
462
+
463
+ selected_guidance = guidance[:, :, coords[0, :], coords[1, :]]
464
+ coord_diff = (coords.unsqueeze(-1) - coords.unsqueeze(1)).square().sum(0).unsqueeze(0)
465
+ guidance_diff = (selected_guidance.unsqueeze(-1) - selected_guidance.unsqueeze(2)).square().sum(1)
466
+
467
+ sim_kernel = self.w1 * torch.exp(- coord_diff / (2 * self.alpha) - guidance_diff / (2 * self.beta)) + \
468
+ self.w2 * torch.exp(- coord_diff / (2 * self.gamma)) - self.shift
469
+
470
+ selected_clusters = clusters[:, :, coords[0, :], coords[1, :]]
471
+ cluster_sims = torch.einsum("nka,nkb->nab", selected_clusters, selected_clusters)
472
+ return -(cluster_sims * sim_kernel)
biomap/output/img.png ADDED
biomap/output/img_6.png ADDED
biomap/output/label.png ADDED
biomap/output/labeled_img.png ADDED
biomap/plot_functions.py ADDED
@@ -0,0 +1,778 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+
3
+ import hydra
4
+ import matplotlib as mpl
5
+ from utils import prep_for_plot
6
+
7
+ import torch.multiprocessing
8
+ import torchvision.transforms as T
9
+ # import matplotlib.pyplot as plt
10
+ from model import LitUnsupervisedSegmenter
11
+ colors = ('red', 'palegreen', 'green', 'steelblue', 'blue', 'yellow', 'lightgrey')
12
+ class_names = ('Buildings', 'Cultivation', 'Natural green', 'Wetland', 'Water', 'Infrastructure', 'Background')
13
+ cmap = mpl.colors.ListedColormap(colors)
14
+ #from train_segmentation import LitUnsupervisedSegmenter, cmap
15
+
16
+ from utils_gee import extract_img, transform_ee_img
17
+
18
+ import plotly.graph_objects as go
19
+ import plotly.express as px
20
+ import numpy as np
21
+ from plotly.subplots import make_subplots
22
+
23
+ import os
24
+ os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
25
+
26
+
27
+ colors = ('red', 'palegreen', 'green', 'steelblue', 'blue', 'yellow', 'lightgrey')
28
+ class_names = ('Buildings', 'Cultivation', 'Natural green', 'Wetland', 'Water', 'Infrastructure', 'Background')
29
+ scores_init = [2,3,4,3,1,4,0]
30
+
31
+ # Import model configs
32
+ hydra.initialize(config_path="configs", job_name="corine")
33
+ cfg = hydra.compose(config_name="my_train_config.yml")
34
+
35
+ nbclasses = cfg.dir_dataset_n_classes
36
+
37
+ # Load Model
38
+ model_path = "checkpoint/model/model.pt"
39
+ saved_state_dict = torch.load(model_path,map_location=torch.device('cpu'))
40
+
41
+ model = LitUnsupervisedSegmenter(nbclasses, cfg)
42
+ model.load_state_dict(saved_state_dict)
43
+
44
+ from PIL import Image
45
+
46
+ import hydra
47
+
48
+ from utils import prep_for_plot
49
+
50
+ import torch.multiprocessing
51
+ import torchvision.transforms as T
52
+ # import matplotlib.pyplot as plt
53
+
54
+ from model import LitUnsupervisedSegmenter
55
+
56
+ from utils_gee import extract_img, transform_ee_img
57
+
58
+ import plotly.graph_objects as go
59
+ import plotly.express as px
60
+ import numpy as np
61
+ from plotly.subplots import make_subplots
62
+
63
+ import os
64
+ os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
65
+
66
+
67
+ colors = ('red', 'palegreen', 'green', 'steelblue', 'blue', 'yellow', 'lightgrey')
68
+ cmap = mpl.colors.ListedColormap(colors)
69
+ class_names = ('Buildings', 'Cultivation', 'Natural green', 'Wetland', 'Water', 'Infrastructure', 'Background')
70
+ scores_init = [2,3,4,3,1,4,0]
71
+
72
+ # Import model configs
73
+ #hydra.initialize(config_path="configs", job_name="corine")
74
+ cfg = hydra.compose(config_name="my_train_config.yml")
75
+
76
+ nbclasses = cfg.dir_dataset_n_classes
77
+
78
+ # Load Model
79
+ model_path = "checkpoint/model/model.pt"
80
+ saved_state_dict = torch.load(model_path,map_location=torch.device('cpu'))
81
+
82
+ model = LitUnsupervisedSegmenter(nbclasses, cfg)
83
+ model.load_state_dict(saved_state_dict)
84
+
85
+
86
+ #normalize img
87
+ preprocess = T.Compose([
88
+ T.ToPILImage(),
89
+ T.Resize((320,320)),
90
+ # T.CenterCrop(224),
91
+ T.ToTensor(),
92
+ T.Normalize(
93
+ mean=[0.485, 0.456, 0.406],
94
+ std=[0.229, 0.224, 0.225]
95
+ )
96
+ ])
97
+
98
+ # Function that look for img on EE and segment it
99
+ # -- 3 ways possible to avoid cloudy environment -- monthly / bi-monthly / yearly meaned img
100
+
101
+ def segment_loc(location, month, year, how = "month", month_end = '12', year_end = None) :
102
+ if how == 'month':
103
+ img = extract_img(location, year +'-'+ month +'-01', year +'-'+ month +'-28')
104
+ elif how == 'year' :
105
+ if year_end == None :
106
+ img = extract_img(location, year +'-'+ month +'-01', year +'-'+ month_end +'-28', width = 0.04 , len = 0.04)
107
+ else :
108
+ img = extract_img(location, year +'-'+ month +'-01', year_end +'-'+ month_end +'-28', width = 0.04 , len = 0.04)
109
+
110
+
111
+ img_test= transform_ee_img(img, max = 0.25)
112
+
113
+ # Preprocess opened img
114
+ x = preprocess(img_test)
115
+ x = torch.unsqueeze(x, dim=0).cpu()
116
+ # model=model.cpu()
117
+
118
+ with torch.no_grad():
119
+ feats, code = model.net(x)
120
+ linear_preds = model.linear_probe(x, code)
121
+ linear_preds = linear_preds.argmax(1)
122
+ outputs = {
123
+ 'img': x[:model.cfg.n_images].detach().cpu(),
124
+ 'linear_preds': linear_preds[:model.cfg.n_images].detach().cpu()
125
+ }
126
+ return outputs
127
+
128
+
129
+ # Function that look for all img on EE and extract all segments with the date as first output arg
130
+
131
+ def segment_group(location, start_date, end_date, how = 'month') :
132
+ outputs = []
133
+ st_month = int(start_date[5:7])
134
+ end_month = int(end_date[5:7])
135
+
136
+ st_year = int(start_date[0:4])
137
+ end_year = int(end_date[0:4])
138
+
139
+
140
+
141
+ for year in range(st_year, end_year+1) :
142
+
143
+ if year != end_year :
144
+ last = 12
145
+ else :
146
+ last = end_month
147
+
148
+ if year != st_year:
149
+ start = 1
150
+ else :
151
+ start = st_month
152
+
153
+ if how == 'month' :
154
+ for month in range(start, last + 1):
155
+ month_str = f"{month:0>2d}"
156
+ year_str = str(year)
157
+
158
+ outputs.append((year_str + '-' + month_str, segment_loc(location, month_str, year_str)))
159
+
160
+ elif how == 'year' :
161
+ outputs.append((str(year) + '-' + f"{start:0>2d}", segment_loc(location, f"{start:0>2d}", str(year), how = 'year', month_end=f"{last:0>2d}")))
162
+
163
+ elif how == '2months' :
164
+ for month in range(start, last + 1):
165
+ month_str = f"{month:0>2d}"
166
+ year_str = str(year)
167
+ month_end = (month) % 12 +1
168
+ if month_end < month :
169
+ year_end = year +1
170
+ else :
171
+ year_end = year
172
+ month_end= f"{month_end:0>2d}"
173
+ year_end = str(year_end)
174
+
175
+ outputs.append((year_str + '-' + month_str, segment_loc(location, month_str, year_str,how = 'year', month_end=month_end, year_end=year_end)))
176
+
177
+
178
+ return outputs
179
+
180
+
181
+ # Function that transforms an output to PIL images
182
+
183
+ def transform_to_pil(outputs,alpha=0.3):
184
+ # Transform img with torch
185
+ img = torch.moveaxis(prep_for_plot(outputs['img'][0]),-1,0)
186
+ img=T.ToPILImage()(img)
187
+
188
+ # Transform label by saving it then open it
189
+ # label = outputs['linear_preds'][0]
190
+ # plt.imsave('label.png',label,cmap=cmap)
191
+ # label = Image.open('label.png')
192
+
193
+ cmaplist = np.array([np.array(cmap(i)) for i in range(cmap.N)])
194
+ labels = np.array(outputs['linear_preds'][0])-1
195
+ label = T.ToPILImage()((cmaplist[labels]*255).astype(np.uint8))
196
+
197
+
198
+ # Overlay labels with img wit alpha
199
+ background = img.convert("RGBA")
200
+ overlay = label.convert("RGBA")
201
+
202
+ labeled_img = Image.blend(background, overlay, alpha)
203
+
204
+ return img, label, labeled_img
205
+
206
+
207
+
208
+ # Function that extract labeled_img(PIL) and nb_values(number of pixels for each class) and the score for each observation
209
+
210
+ def values_from_output(output):
211
+ imgs = transform_to_pil(output,alpha = 0.3)
212
+
213
+ img = imgs[0]
214
+ img = np.array(img.convert('RGB'))
215
+
216
+ labeled_img = imgs[2]
217
+ labeled_img = np.array(labeled_img.convert('RGB'))
218
+
219
+ nb_values = []
220
+ for i in range(7):
221
+ nb_values.append(np.count_nonzero(output['linear_preds'][0] == i+1))
222
+
223
+ score = sum(x * y for x, y in zip(scores_init, nb_values)) / sum(nb_values) / max(scores_init)
224
+
225
+ return img, labeled_img, nb_values, score
226
+
227
+
228
+ # Function that extract from outputs (from segment_group function) all dates/ all images
229
+ def values_from_outputs(outputs) :
230
+ months = []
231
+ imgs = []
232
+ imgs_label = []
233
+ nb_values = []
234
+ scores = []
235
+
236
+ for output in outputs:
237
+ img, labeled_img, nb_value, score = values_from_output(output[1])
238
+ months.append(output[0])
239
+ imgs.append(img)
240
+ imgs_label.append(labeled_img)
241
+ nb_values.append(nb_value)
242
+ scores.append(score)
243
+
244
+ return months, imgs, imgs_label, nb_values, scores
245
+
246
+
247
+
248
+ def plot_imgs_labels(months, imgs, imgs_label, nb_values, scores) :
249
+
250
+ fig2 = px.imshow(np.array(imgs), animation_frame=0, binary_string=True)
251
+ fig3 = px.imshow(np.array(imgs_label), animation_frame=0, binary_string=True)
252
+
253
+ # Scores
254
+ scatters = []
255
+ temp = []
256
+ for score in scores :
257
+ temp_score = []
258
+ temp_date = []
259
+ score = scores[i]
260
+ temp.append(score)
261
+ text_temp = ["" for i in temp]
262
+ text_temp[-1] = str(round(score,2))
263
+ scatters.append(go.Scatter(x=text_temp, y=temp, mode="lines+markers+text", marker_color="black", text = text_temp, textposition="top center"))
264
+
265
+
266
+ # Scores
267
+ fig = make_subplots(
268
+ rows=1, cols=4,
269
+ # specs=[[{"rowspan": 2}, {"rowspan": 2}, {"type": "pie"}, None]]
270
+ # row_heights=[0.8, 0.2],
271
+ column_widths = [0.6, 0.6,0.3, 0.3],
272
+ subplot_titles=("Localisation visualization", "labeled visualisation", "Segments repartition", "Biodiversity scores")
273
+ )
274
+
275
+ fig.add_trace(fig2["frames"][0]["data"][0], row=1, col=1)
276
+ fig.add_trace(fig3["frames"][0]["data"][0], row=1, col=2)
277
+
278
+ fig.add_trace(go.Pie(labels = class_names,
279
+ values = nb_values[0],
280
+ marker_colors = colors,
281
+ name="Segment repartition",
282
+ textposition='inside',
283
+ texttemplate = "%{percent:.0%}",
284
+ textfont_size=14
285
+ ),
286
+ row=1, col=3)
287
+
288
+
289
+ fig.add_trace(scatters[0], row=1, col=4)
290
+ # fig.add_annotation(text='score:' + str(scores[0]),
291
+ # showarrow=False,
292
+ # row=2, col=2)
293
+
294
+
295
+ number_frames = len(imgs)
296
+ frames = [dict(
297
+ name = k,
298
+ data = [ fig2["frames"][k]["data"][0],
299
+ fig3["frames"][k]["data"][0],
300
+ go.Pie(labels = class_names,
301
+ values = nb_values[k],
302
+ marker_colors = colors,
303
+ name="Segment repartition",
304
+ textposition='inside',
305
+ texttemplate = "%{percent:.0%}",
306
+ textfont_size=14
307
+ ),
308
+ scatters[k]
309
+ ],
310
+ traces=[0, 1,2,3] # the elements of the list [0,1,2] give info on the traces in fig.data
311
+ # that are updated by the above three go.Scatter instances
312
+ ) for k in range(number_frames)]
313
+
314
+ updatemenus = [dict(type='buttons',
315
+ buttons=[dict(label='Play',
316
+ method='animate',
317
+ args=[[f'{k}' for k in range(number_frames)],
318
+ dict(frame=dict(duration=500, redraw=False),
319
+ transition=dict(duration=0),
320
+ easing='linear',
321
+ fromcurrent=True,
322
+ mode='immediate'
323
+ )])],
324
+ direction= 'left',
325
+ pad=dict(r= 10, t=85),
326
+ showactive =True, x= 0.1, y= 0.13, xanchor= 'right', yanchor= 'top')
327
+ ]
328
+
329
+ sliders = [{'yanchor': 'top',
330
+ 'xanchor': 'left',
331
+ 'currentvalue': {'font': {'size': 16}, 'prefix': 'Frame: ', 'visible': False, 'xanchor': 'right'},
332
+ 'transition': {'duration': 500.0, 'easing': 'linear'},
333
+ 'pad': {'b': 10, 't': 50},
334
+ 'len': 0.9, 'x': 0.1, 'y': 0,
335
+ 'steps': [{'args': [[k], {'frame': {'duration': 500.0, 'easing': 'linear', 'redraw': False},
336
+ 'transition': {'duration': 0, 'easing': 'linear'}}],
337
+ 'label': months[k], 'method': 'animate'} for k in range(number_frames)
338
+ ]}]
339
+
340
+
341
+ fig.update(frames=frames)
342
+
343
+ for i,fr in enumerate(fig["frames"]):
344
+ fr.update(
345
+ layout={
346
+ "xaxis": {
347
+ "range": [0,imgs[0].shape[1]+i/100000]
348
+ },
349
+ "yaxis": {
350
+ "range": [imgs[0].shape[0]+i/100000,0]
351
+ },
352
+ })
353
+
354
+ fr.update(layout_title_text= months[i])
355
+
356
+
357
+ fig.update(layout_title_text= 'tot')
358
+ fig.update(
359
+ layout={
360
+ "xaxis": {
361
+ "range": [0,imgs[0].shape[1]+i/100000],
362
+ 'showgrid': False, # thin lines in the background
363
+ 'zeroline': False, # thick line at x=0
364
+ 'visible': False, # numbers below
365
+ },
366
+
367
+ "yaxis": {
368
+ "range": [imgs[0].shape[0]+i/100000,0],
369
+ 'showgrid': False, # thin lines in the background
370
+ 'zeroline': False, # thick line at y=0
371
+ 'visible': False,},
372
+
373
+ "xaxis3": {
374
+ "range": [0,len(scores)+1],
375
+ 'autorange': False, # thin lines in the background
376
+ 'showgrid': False, # thin lines in the background
377
+ 'zeroline': False, # thick line at y=0
378
+ 'visible': False
379
+ },
380
+
381
+ "yaxis3": {
382
+ "range": [0,1.5],
383
+ 'autorange': False,
384
+ 'showgrid': False, # thin lines in the background
385
+ 'zeroline': False, # thick line at y=0
386
+ 'visible': False # thin lines in the background
387
+ }
388
+ },
389
+ legend=dict(
390
+ yanchor="bottom",
391
+ y=0.99,
392
+ xanchor="center",
393
+ x=0.01
394
+ )
395
+ )
396
+
397
+
398
+ fig.update_layout(updatemenus=updatemenus,
399
+ sliders=sliders)
400
+
401
+ fig.update_layout(margin=dict(b=0, r=0))
402
+
403
+ # fig.show() #in jupyter notebook
404
+
405
+ return fig
406
+
407
+
408
+
409
+ # Last function (global one)
410
+ # how = 'month' or '2months' or 'year'
411
+
412
+ def segment_region(location, start_date, end_date, how = 'month'):
413
+
414
+ #extract the outputs for each image
415
+ outputs = segment_group(location, start_date, end_date, how = how)
416
+
417
+ #extract the intersting values from image
418
+ months, imgs, imgs_label, nb_values, scores = values_from_outputs(outputs)
419
+
420
+ #Create the figure
421
+ fig = plot_imgs_labels(months, imgs, imgs_label, nb_values, scores)
422
+
423
+ return fig
424
+ #normalize img
425
+ preprocess = T.Compose([
426
+ T.ToPILImage(),
427
+ T.Resize((320,320)),
428
+ # T.CenterCrop(224),
429
+ T.ToTensor(),
430
+ T.Normalize(
431
+ mean=[0.485, 0.456, 0.406],
432
+ std=[0.229, 0.224, 0.225]
433
+ )
434
+ ])
435
+
436
+ # Function that look for img on EE and segment it
437
+ # -- 3 ways possible to avoid cloudy environment -- monthly / bi-monthly / yearly meaned img
438
+
439
+ def segment_loc(location, month, year, how = "month", month_end = '12', year_end = None) :
440
+ if how == 'month':
441
+ img = extract_img(location, year +'-'+ month +'-01', year +'-'+ month +'-28')
442
+ elif how == 'year' :
443
+ if year_end == None :
444
+ img = extract_img(location, year +'-'+ month +'-01', year +'-'+ month_end +'-28', width = 0.04 , len = 0.04)
445
+ else :
446
+ img = extract_img(location, year +'-'+ month +'-01', year_end +'-'+ month_end +'-28', width = 0.04 , len = 0.04)
447
+
448
+
449
+ img_test= transform_ee_img(img, max = 0.25)
450
+
451
+ # Preprocess opened img
452
+ x = preprocess(img_test)
453
+ x = torch.unsqueeze(x, dim=0).cpu()
454
+ # model=model.cpu()
455
+
456
+ with torch.no_grad():
457
+ feats, code = model.net(x)
458
+ linear_preds = model.linear_probe(x, code)
459
+ linear_preds = linear_preds.argmax(1)
460
+ outputs = {
461
+ 'img': x[:model.cfg.n_images].detach().cpu(),
462
+ 'linear_preds': linear_preds[:model.cfg.n_images].detach().cpu()
463
+ }
464
+ return outputs
465
+
466
+
467
+ # Function that look for all img on EE and extract all segments with the date as first output arg
468
+
469
+ def segment_group(location, start_date, end_date, how = 'month') :
470
+ outputs = []
471
+ st_month = int(start_date[5:7])
472
+ end_month = int(end_date[5:7])
473
+
474
+ st_year = int(start_date[0:4])
475
+ end_year = int(end_date[0:4])
476
+
477
+
478
+
479
+ for year in range(st_year, end_year+1) :
480
+
481
+ if year != end_year :
482
+ last = 12
483
+ else :
484
+ last = end_month
485
+
486
+ if year != st_year:
487
+ start = 1
488
+ else :
489
+ start = st_month
490
+
491
+ if how == 'month' :
492
+ for month in range(start, last + 1):
493
+ month_str = f"{month:0>2d}"
494
+ year_str = str(year)
495
+
496
+ outputs.append((year_str + '-' + month_str, segment_loc(location, month_str, year_str)))
497
+
498
+ elif how == 'year' :
499
+ outputs.append((str(year) + '-' + f"{start:0>2d}", segment_loc(location, f"{start:0>2d}", str(year), how = 'year', month_end=f"{last:0>2d}")))
500
+
501
+ elif how == '2months' :
502
+ for month in range(start, last + 1):
503
+ month_str = f"{month:0>2d}"
504
+ year_str = str(year)
505
+ month_end = (month) % 12 +1
506
+ if month_end < month :
507
+ year_end = year +1
508
+ else :
509
+ year_end = year
510
+ month_end= f"{month_end:0>2d}"
511
+ year_end = str(year_end)
512
+
513
+ outputs.append((year_str + '-' + month_str, segment_loc(location, month_str, year_str,how = 'year', month_end=month_end, year_end=year_end)))
514
+
515
+
516
+ return outputs
517
+
518
+
519
+ # Function that transforms an output to PIL images
520
+
521
+ def transform_to_pil(outputs,alpha=0.3):
522
+ # Transform img with torch
523
+ img = torch.moveaxis(prep_for_plot(outputs['img'][0]),-1,0)
524
+ img=T.ToPILImage()(img)
525
+
526
+ # Transform label by saving it then open it
527
+ # label = outputs['linear_preds'][0]
528
+ # plt.imsave('label.png',label,cmap=cmap)
529
+ # label = Image.open('label.png')
530
+
531
+ cmaplist = np.array([np.array(cmap(i)) for i in range(cmap.N)])
532
+ labels = np.array(outputs['linear_preds'][0])-1
533
+ label = T.ToPILImage()((cmaplist[labels]*255).astype(np.uint8))
534
+
535
+
536
+ # Overlay labels with img wit alpha
537
+ background = img.convert("RGBA")
538
+ overlay = label.convert("RGBA")
539
+
540
+ labeled_img = Image.blend(background, overlay, alpha)
541
+
542
+ return img, label, labeled_img
543
+
544
+ def values_from_output(output):
545
+ imgs = transform_to_pil(output,alpha = 0.3)
546
+
547
+ img = imgs[0]
548
+ img = np.array(img.convert('RGB'))
549
+
550
+ labeled_img = imgs[2]
551
+ labeled_img = np.array(labeled_img.convert('RGB'))
552
+
553
+ nb_values = []
554
+ for i in range(7):
555
+ nb_values.append(np.count_nonzero(output['linear_preds'][0] == i+1))
556
+
557
+ score = sum(x * y for x, y in zip(scores_init, nb_values)) / sum(nb_values) / max(scores_init)
558
+
559
+ return img, labeled_img, nb_values, score
560
+
561
+
562
+ # Function that extract labeled_img(PIL) and nb_values(number of pixels for each class) and the score for each observation
563
+
564
+
565
+
566
+ # Function that extract from outputs (from segment_group function) all dates/ all images
567
+ def values_from_outputs(outputs) :
568
+ months = []
569
+ imgs = []
570
+ imgs_label = []
571
+ nb_values = []
572
+ scores = []
573
+
574
+ for output in outputs:
575
+ img, labeled_img, nb_value, score = values_from_output(output[1])
576
+ months.append(output[0])
577
+ imgs.append(img)
578
+ imgs_label.append(labeled_img)
579
+ nb_values.append(nb_value)
580
+ scores.append(score)
581
+
582
+ return months, imgs, imgs_label, nb_values, scores
583
+
584
+
585
+
586
+ def plot_imgs_labels(months, imgs, imgs_label, nb_values, scores) :
587
+
588
+ fig2 = px.imshow(np.array(imgs), animation_frame=0, binary_string=True)
589
+ fig3 = px.imshow(np.array(imgs_label), animation_frame=0, binary_string=True)
590
+
591
+ # Scores
592
+ scatters = []
593
+ temp = []
594
+ for score in scores :
595
+ temp_score = []
596
+ temp_date = []
597
+ #score = scores[i]
598
+ temp.append(score)
599
+ n = len(temp)
600
+ text_temp = ["" for i in temp]
601
+ text_temp[-1] = str(round(score,2))
602
+ scatters.append(go.Scatter(x=[0,1], y=temp, mode="lines+markers+text", marker_color="black", text = text_temp, textposition="top center"))
603
+ print(text_temp)
604
+
605
+ # Scores
606
+ fig = make_subplots(
607
+ rows=1, cols=4,
608
+ specs=[[{"type": "image"},{"type": "image"}, {"type": "pie"}, {"type": "scatter"}]],
609
+ subplot_titles=("Localisation visualization", "Labeled visualisation", "Segments repartition", "Biodiversity scores")
610
+ )
611
+
612
+ fig.add_trace(fig2["frames"][0]["data"][0], row=1, col=1)
613
+ fig.add_trace(fig3["frames"][0]["data"][0], row=1, col=2)
614
+
615
+ fig.add_trace(go.Pie(labels = class_names,
616
+ values = nb_values[0],
617
+ marker_colors = colors,
618
+ name="Segment repartition",
619
+ textposition='inside',
620
+ texttemplate = "%{percent:.0%}",
621
+ textfont_size=14
622
+ ),
623
+ row=1, col=3)
624
+
625
+
626
+ fig.add_trace(scatters[0], row=1, col=4)
627
+ fig.update_traces(showlegend=False, selector=dict(type='scatter'))
628
+ #fig.update_traces(, selector=dict(type='scatter'))
629
+ # fig.add_annotation(text='score:' + str(scores[0]),
630
+ # showarrow=False,
631
+ # row=2, col=2)
632
+
633
+
634
+ number_frames = len(imgs)
635
+ frames = [dict(
636
+ name = k,
637
+ data = [ fig2["frames"][k]["data"][0],
638
+ fig3["frames"][k]["data"][0],
639
+ go.Pie(labels = class_names,
640
+ values = nb_values[k],
641
+ marker_colors = colors,
642
+ name="Segment repartition",
643
+ textposition='inside',
644
+ texttemplate = "%{percent:.0%}",
645
+ textfont_size=14
646
+ ),
647
+ scatters[k]
648
+ ],
649
+ traces=[0, 1,2,3] # the elements of the list [0,1,2] give info on the traces in fig.data
650
+ # that are updated by the above three go.Scatter instances
651
+ ) for k in range(number_frames)]
652
+
653
+ updatemenus = [dict(type='buttons',
654
+ buttons=[dict(label='Play',
655
+ method='animate',
656
+ args=[[f'{k}' for k in range(number_frames)],
657
+ dict(frame=dict(duration=500, redraw=False),
658
+ transition=dict(duration=0),
659
+ easing='linear',
660
+ fromcurrent=True,
661
+ mode='immediate'
662
+ )])],
663
+ direction= 'left',
664
+ pad=dict(r= 10, t=85),
665
+ showactive =True, x= 0.1, y= 0.13, xanchor= 'right', yanchor= 'top')
666
+ ]
667
+
668
+ sliders = [{'yanchor': 'top',
669
+ 'xanchor': 'left',
670
+ 'currentvalue': {'font': {'size': 16}, 'prefix': 'Frame: ', 'visible': False, 'xanchor': 'right'},
671
+ 'transition': {'duration': 500.0, 'easing': 'linear'},
672
+ 'pad': {'b': 10, 't': 50},
673
+ 'len': 0.9, 'x': 0.1, 'y': 0,
674
+ 'steps': [{'args': [[k], {'frame': {'duration': 500.0, 'easing': 'linear', 'redraw': False},
675
+ 'transition': {'duration': 0, 'easing': 'linear'}}],
676
+ 'label': months[k], 'method': 'animate'} for k in range(number_frames)
677
+ ]}]
678
+
679
+
680
+ fig.update(frames=frames)
681
+
682
+ for i,fr in enumerate(fig["frames"]):
683
+ fr.update(
684
+ layout={
685
+ "xaxis": {
686
+ "range": [0,imgs[0].shape[1]+i/100000]
687
+ },
688
+ "yaxis": {
689
+ "range": [imgs[0].shape[0]+i/100000,0]
690
+ },
691
+ })
692
+
693
+ fr.update(layout_title_text= months[i])
694
+
695
+
696
+ fig.update(layout_title_text= months[0])
697
+ fig.update(
698
+ layout={
699
+ "xaxis": {
700
+ "range": [0,imgs[0].shape[1]+i/100000],
701
+ 'showgrid': False, # thin lines in the background
702
+ 'zeroline': False, # thick line at x=0
703
+ 'visible': False, # numbers below
704
+ },
705
+
706
+ "yaxis": {
707
+ "range": [imgs[0].shape[0]+i/100000,0],
708
+ 'showgrid': False, # thin lines in the background
709
+ 'zeroline': False, # thick line at y=0
710
+ 'visible': False,},
711
+
712
+ "xaxis2": {
713
+ "range": [0,imgs[0].shape[1]+i/100000],
714
+ 'showgrid': False, # thin lines in the background
715
+ 'zeroline': False, # thick line at x=0
716
+ 'visible': False, # numbers below
717
+ },
718
+
719
+ "yaxis2": {
720
+ "range": [imgs[0].shape[0]+i/100000,0],
721
+ 'showgrid': False, # thin lines in the background
722
+ 'zeroline': False, # thick line at y=0
723
+ 'visible': False,},
724
+
725
+
726
+ "xaxis3": {
727
+ "range": [0,len(scores)+1],
728
+ 'autorange': False, # thin lines in the background
729
+ 'showgrid': False, # thin lines in the background
730
+ 'zeroline': False, # thick line at y=0
731
+ 'visible': False
732
+ },
733
+
734
+ "yaxis3": {
735
+ "range": [0,1.5],
736
+ 'autorange': False,
737
+ 'showgrid': False, # thin lines in the background
738
+ 'zeroline': False, # thick line at y=0
739
+ 'visible': False # thin lines in the background
740
+ }
741
+ }
742
+ )
743
+
744
+
745
+ fig.update_layout(updatemenus=updatemenus,
746
+ sliders=sliders,
747
+ legend=dict(
748
+ yanchor= 'top',
749
+ xanchor= 'left',
750
+ orientation="h")
751
+ )
752
+
753
+
754
+ fig.update_layout(margin=dict(b=0, r=0))
755
+
756
+ # fig.show() #in jupyter notebook
757
+
758
+ return fig
759
+
760
+
761
+
762
+ # Last function (global one)
763
+ # how = 'month' or '2months' or 'year'
764
+
765
+ def segment_region(latitude, longitude, start_date, end_date, how = 'month'):
766
+ location = [float(latitude),float(longitude)]
767
+ how = how[0]
768
+ #extract the outputs for each image
769
+ outputs = segment_group(location, start_date, end_date, how = how)
770
+
771
+ #extract the intersting values from image
772
+ months, imgs, imgs_label, nb_values, scores = values_from_outputs(outputs)
773
+
774
+
775
+ #Create the figure
776
+ fig = plot_imgs_labels(months, imgs, imgs_label, nb_values, scores)
777
+
778
+ return fig
biomap/train.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils import *
2
+ from modules import *
3
+ from data import *
4
+ from torch.utils.data import DataLoader
5
+ import torch.nn.functional as F
6
+ from datetime import datetime
7
+ import hydra
8
+ from omegaconf import DictConfig, OmegaConf
9
+ import pytorch_lightning as pl
10
+ from pytorch_lightning import Trainer
11
+ from pytorch_lightning.loggers import TensorBoardLogger
12
+ from pytorch_lightning.utilities.seed import seed_everything
13
+ import torch.multiprocessing
14
+ import seaborn as sns
15
+ from pytorch_lightning.callbacks import ModelCheckpoint
16
+ import sys
17
+ import pdb
18
+ import matplotlib as mpl
19
+ from skimage import measure
20
+ from scipy.stats import mode as statsmode
21
+ from collections import OrderedDict
22
+ import unet
23
+ import pdb
24
+
25
+ torch.multiprocessing.set_sharing_strategy("file_system")
26
+ colors = ("red", "palegreen", "green", "steelblue", "blue", "yellow", "lightgrey")
27
+ class_names = (
28
+ "Buildings",
29
+ "Cultivation",
30
+ "Natural green",
31
+ "Wetland",
32
+ "Water",
33
+ "Infrastructure",
34
+ "Background",
35
+ )
36
+ bounds = list(np.arange(len(class_names) + 1) + 1)
37
+ cmap = mpl.colors.ListedColormap(colors)
38
+ norm = mpl.colors.BoundaryNorm(bounds, cmap.N)
39
+
40
+
41
+ def retouch_label(pred_label, true_label):
42
+ retouched_label = pred_label + 0
43
+ blobs = measure.label(retouched_label)
44
+ for idx in np.unique(blobs):
45
+ # most frequent label class in this blob
46
+ retouched_label[blobs == idx] = statsmode(true_label[blobs == idx])[0][0]
47
+ return retouched_label
48
+
49
+
50
+ def get_class_labels(dataset_name):
51
+ if dataset_name.startswith("cityscapes"):
52
+ return [
53
+ "road",
54
+ "sidewalk",
55
+ "parking",
56
+ "rail track",
57
+ "building",
58
+ "wall",
59
+ "fence",
60
+ "guard rail",
61
+ "bridge",
62
+ "tunnel",
63
+ "pole",
64
+ "polegroup",
65
+ "traffic light",
66
+ "traffic sign",
67
+ "vegetation",
68
+ "terrain",
69
+ "sky",
70
+ "person",
71
+ "rider",
72
+ "car",
73
+ "truck",
74
+ "bus",
75
+ "caravan",
76
+ "trailer",
77
+ "train",
78
+ "motorcycle",
79
+ "bicycle",
80
+ ]
81
+ elif dataset_name == "cocostuff27":
82
+ return [
83
+ "electronic",
84
+ "appliance",
85
+ "food",
86
+ "furniture",
87
+ "indoor",
88
+ "kitchen",
89
+ "accessory",
90
+ "animal",
91
+ "outdoor",
92
+ "person",
93
+ "sports",
94
+ "vehicle",
95
+ "ceiling",
96
+ "floor",
97
+ "food",
98
+ "furniture",
99
+ "rawmaterial",
100
+ "textile",
101
+ "wall",
102
+ "window",
103
+ "building",
104
+ "ground",
105
+ "plant",
106
+ "sky",
107
+ "solid",
108
+ "structural",
109
+ "water",
110
+ ]
111
+ elif dataset_name == "voc":
112
+ return [
113
+ "background",
114
+ "aeroplane",
115
+ "bicycle",
116
+ "bird",
117
+ "boat",
118
+ "bottle",
119
+ "bus",
120
+ "car",
121
+ "cat",
122
+ "chair",
123
+ "cow",
124
+ "diningtable",
125
+ "dog",
126
+ "horse",
127
+ "motorbike",
128
+ "person",
129
+ "pottedplant",
130
+ "sheep",
131
+ "sofa",
132
+ "train",
133
+ "tvmonitor",
134
+ ]
135
+ elif dataset_name == "potsdam":
136
+ return ["roads and cars", "buildings and clutter", "trees and vegetation"]
137
+ else:
138
+ raise ValueError("Unknown Dataset {}".format(dataset_name))
139
+
140
+
141
+ @hydra.main(config_path="configs", config_name="train_config.yml")
142
+ def my_app(cfg: DictConfig) -> None:
143
+ OmegaConf.set_struct(cfg, False)
144
+ print(OmegaConf.to_yaml(cfg))
145
+ pytorch_data_dir = cfg.pytorch_data_dir
146
+ data_dir = join(cfg.output_root, "data")
147
+ log_dir = join(cfg.output_root, "logs")
148
+ checkpoint_dir = join(cfg.output_root, "checkpoints")
149
+
150
+ prefix = "{}/{}_{}".format(cfg.log_dir, cfg.dataset_name, cfg.experiment_name)
151
+ name = "{}_date_{}".format(prefix, datetime.now().strftime("%b%d_%H-%M-%S"))
152
+ cfg.full_name = prefix
153
+
154
+ os.makedirs(data_dir, exist_ok=True)
155
+ os.makedirs(log_dir, exist_ok=True)
156
+ os.makedirs(checkpoint_dir, exist_ok=True)
157
+
158
+ seed_everything(seed=0)
159
+
160
+ print(data_dir)
161
+ print(cfg.output_root)
162
+
163
+ geometric_transforms = T.Compose(
164
+ [T.RandomHorizontalFlip(), T.RandomResizedCrop(size=cfg.res, scale=(0.8, 1.0))]
165
+ )
166
+ photometric_transforms = T.Compose(
167
+ [
168
+ T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
169
+ T.RandomGrayscale(0.2),
170
+ T.RandomApply([T.GaussianBlur((5, 5))]),
171
+ ]
172
+ )
173
+
174
+ sys.stdout.flush()
175
+
176
+ train_dataset = ContrastiveSegDataset(
177
+ pytorch_data_dir=pytorch_data_dir,
178
+ dataset_name=cfg.dataset_name,
179
+ crop_type=cfg.crop_type,
180
+ image_set="train",
181
+ transform=get_transform(cfg.res, False, cfg.loader_crop_type),
182
+ target_transform=get_transform(cfg.res, True, cfg.loader_crop_type),
183
+ cfg=cfg,
184
+ aug_geometric_transform=geometric_transforms,
185
+ aug_photometric_transform=photometric_transforms,
186
+ num_neighbors=cfg.num_neighbors,
187
+ mask=True,
188
+ pos_images=True,
189
+ pos_labels=True,
190
+ )
191
+
192
+ if cfg.dataset_name == "voc":
193
+ val_loader_crop = None
194
+ else:
195
+ val_loader_crop = "center"
196
+
197
+ val_dataset = ContrastiveSegDataset(
198
+ pytorch_data_dir=pytorch_data_dir,
199
+ dataset_name=cfg.dataset_name,
200
+ crop_type=None,
201
+ image_set="val",
202
+ transform=get_transform(320, False, val_loader_crop),
203
+ target_transform=get_transform(320, True, val_loader_crop),
204
+ mask=True,
205
+ cfg=cfg,
206
+ )
207
+
208
+ # val_dataset = MaterializedDataset(val_dataset)
209
+ train_loader = DataLoader(
210
+ train_dataset,
211
+ cfg.batch_size,
212
+ shuffle=True,
213
+ num_workers=cfg.num_workers,
214
+ pin_memory=True,
215
+ )
216
+
217
+ if cfg.submitting_to_aml:
218
+ val_batch_size = 16
219
+ else:
220
+ val_batch_size = cfg.batch_size
221
+
222
+ val_loader = DataLoader(
223
+ val_dataset,
224
+ val_batch_size,
225
+ shuffle=False,
226
+ num_workers=cfg.num_workers,
227
+ pin_memory=True,
228
+ )
229
+
230
+ model = LitUnsupervisedSegmenter(train_dataset.n_classes, cfg)
231
+
232
+ tb_logger = TensorBoardLogger(join(log_dir, name), default_hp_metric=False)
233
+
234
+ if cfg.submitting_to_aml:
235
+ gpu_args = dict(gpus=1, val_check_interval=250)
236
+
237
+ if gpu_args["val_check_interval"] > len(train_loader):
238
+ gpu_args.pop("val_check_interval")
239
+
240
+ else:
241
+ gpu_args = dict(gpus=-1, accelerator="ddp", val_check_interval=cfg.val_freq)
242
+ # gpu_args = dict(gpus=1, accelerator='ddp', val_check_interval=cfg.val_freq)
243
+
244
+ if gpu_args["val_check_interval"] > len(train_loader) // 4:
245
+ gpu_args.pop("val_check_interval")
246
+
247
+ trainer = Trainer(
248
+ log_every_n_steps=cfg.scalar_log_freq,
249
+ logger=tb_logger,
250
+ max_steps=cfg.max_steps,
251
+ callbacks=[
252
+ ModelCheckpoint(
253
+ dirpath=join(checkpoint_dir, name),
254
+ every_n_train_steps=400,
255
+ save_top_k=2,
256
+ monitor="test/cluster/mIoU",
257
+ mode="max",
258
+ )
259
+ ],
260
+ **gpu_args
261
+ )
262
+ trainer.fit(model, train_loader, val_loader)
263
+
264
+
265
+ if __name__ == "__main__":
266
+ prep_args()
267
+ my_app()
biomap/unet.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ import pandas as pd
5
+ import matplotlib.pyplot as plt
6
+ from torch.utils.data import DataLoader
7
+ import torch.nn as nn
8
+ from collections import defaultdict
9
+ import torchvision
10
+ import torch.nn.functional as F
11
+ from torch.utils.data.sampler import Sampler
12
+
13
+ class Block(nn.Module):
14
+ def __init__(self, in_ch, out_ch, padding='same'):
15
+ super().__init__()
16
+ self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=padding)
17
+ self.relu = nn.ReLU()
18
+ self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=padding)
19
+
20
+ def forward(self, x):
21
+ return self.conv2(self.relu(self.conv1(x)))
22
+
23
+
24
+ class Encoder(nn.Module):
25
+ def __init__(self, chs=(3,32,64,128,256)):
26
+ super().__init__()
27
+ self.enc_blocks = nn.ModuleList([Block(chs[i], chs[i+1]) for i in range(len(chs)-1)])
28
+ self.pool = nn.MaxPool2d(2)
29
+
30
+ def forward(self, x):
31
+ ftrs = []
32
+ for block in self.enc_blocks:
33
+ x = block(x)
34
+ ftrs.append(x)
35
+ x = self.pool(x)
36
+ return ftrs
37
+
38
+
39
+ class Decoder(nn.Module):
40
+ def __init__(self, chs=(256,128, 64, 32), aux_ch=70):
41
+ super().__init__()
42
+ upchs = tuple([chs[i]+aux_ch if i == 0 else chs[i] for i in range(len(chs))])
43
+ self.chs = chs
44
+ self.upchs = upchs
45
+ self.upconvs = nn.ModuleList([nn.ConvTranspose2d(upchs[i], upchs[i+1], 2, 2) for i in range(len(upchs)-1)])
46
+ self.dec_blocks = nn.ModuleList([Block(chs[i], chs[i+1]) for i in range(len(chs)-1)])
47
+
48
+ def forward(self, x, encoder_features):
49
+ for i in range(len(self.chs)-1):
50
+ # pdb.set_trace()
51
+ x = self.upconvs[i](x)
52
+ enc_ftrs = self.crop(encoder_features[i], x)
53
+ x = torch.cat([x, enc_ftrs], dim=1)
54
+ x = self.dec_blocks[i](x)
55
+ return x
56
+
57
+ def crop(self, enc_ftrs, x):
58
+ _, _, H, W = x.shape
59
+ enc_ftrs = torchvision.transforms.CenterCrop([H, W])(enc_ftrs)
60
+ return enc_ftrs
61
+
62
+
63
+ class AuxUNet(nn.Module):
64
+ # UNet with auxiliary feature at the bottom
65
+ def __init__(self, enc_chs=(3,32,64,128,256), dec_chs=(256,128, 64, 32), aux_ch=70, num_class=7, retain_dim=False, out_sz=(224,224)):
66
+ super().__init__()
67
+ self.encoder = Encoder(enc_chs)
68
+ self.decoder = Decoder(dec_chs, aux_ch)
69
+ self.head = nn.Conv2d(dec_chs[-1], num_class, 1)
70
+ self.retain_dim = retain_dim
71
+
72
+ def forward(self, x, aux):
73
+ # aux: auxiliary feature at the bottom
74
+ enc_ftrs = self.encoder(x)
75
+ enc_ftrs[-1] = torch.cat((enc_ftrs[-1], aux), 1)
76
+ out = self.decoder(enc_ftrs[::-1][0], enc_ftrs[::-1][1:])
77
+ out = self.head(out)
78
+ if self.retain_dim:
79
+ out = F.interpolate(out, out_sz)
80
+ return out
biomap/utils.py ADDED
@@ -0,0 +1,390 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import os
3
+ from os.path import join
4
+ import io
5
+
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import torch.multiprocessing
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import wget
12
+ from PIL import Image
13
+ from scipy.optimize import linear_sum_assignment
14
+ from torch._six import string_classes
15
+ from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format
16
+ from torchmetrics import Metric
17
+ from torchvision import models
18
+ from torchvision import transforms as T
19
+ from torch.utils.tensorboard.summary import hparams
20
+ import matplotlib as mpl
21
+ torch.multiprocessing.set_sharing_strategy("file_system")
22
+ colors = ("red", "palegreen", "green", "steelblue", "blue", "yellow", "lightgrey")
23
+ class_names = (
24
+ "Buildings",
25
+ "Cultivation",
26
+ "Natural green",
27
+ "Wetland",
28
+ "Water",
29
+ "Infrastructure",
30
+ "Background",
31
+ )
32
+ bounds = list(np.arange(len(class_names) + 1) + 1)
33
+ cmap = mpl.colors.ListedColormap(colors)
34
+ norm = mpl.colors.BoundaryNorm(bounds, cmap.N)
35
+
36
+ def compute_biodiv_score(image):
37
+ """Compute the biodiversity score of an image
38
+
39
+ Args:
40
+ image (_type_): _description_
41
+
42
+ Returns:
43
+ biodiversity_score: the biodiversity score associated to the landscape of the image
44
+ """
45
+ pix = np.array(image.getdata())
46
+ return np.mean(pix)
47
+
48
+ import cv2
49
+ def create_video(array_images, output_path="output.mp4"):
50
+ height, width, layers = array_images[0].shape
51
+ size = (width,height)
52
+
53
+ fourcc = cv2.VideoWriter_fourcc(*'VP90')
54
+ out = cv2.VideoWriter('output.mp4', fourcc, 2, size)
55
+
56
+ for i in range(len(array_images)):
57
+ out.write(array_images[i])
58
+ out.release()
59
+ return out
60
+
61
+
62
+
63
+ def transform_to_pil(outputs, alpha=0.3):
64
+ """Turn an ouput into a PIL
65
+
66
+ Args:
67
+ outputs (_type_): _description_
68
+ alpha (float, optional): _description_. Defaults to 0.3.
69
+
70
+ Returns:
71
+ _type_: _description_
72
+ """
73
+
74
+ # Transform img with torch
75
+ img = torch.moveaxis(prep_for_plot(outputs["img"][0]), -1, 0)
76
+ img = T.ToPILImage()(img)
77
+ # Transform label by saving it then open it
78
+ label = outputs["linear_preds"][0].numpy()
79
+ # image_label = Image.fromarray(label, mode="P")
80
+ plt.imsave("output/label.png", label, cmap=cmap)
81
+ image_label = Image.open("output/label.png")
82
+ # Overlay labels with img wit alpha
83
+ background = img.convert("RGBA")
84
+ overlay = image_label.convert("RGBA")
85
+ labeled_img = Image.blend(background, overlay, alpha)
86
+ labeled_img = labeled_img.convert("RGB")
87
+ return img, image_label, labeled_img
88
+
89
+
90
+ def prep_for_plot(img, rescale=True, resize=None):
91
+ if resize is not None:
92
+ img = F.interpolate(img.unsqueeze(0), resize, mode="bilinear")
93
+ else:
94
+ img = img.unsqueeze(0)
95
+
96
+ plot_img = unnorm(img).squeeze(0).cpu().permute(1, 2, 0)
97
+ if rescale:
98
+ plot_img = (plot_img - plot_img.min()) / (plot_img.max() - plot_img.min())
99
+ return plot_img
100
+
101
+
102
+ def add_plot(writer, name, step):
103
+ buf = io.BytesIO()
104
+ plt.savefig(buf, format='jpeg', dpi=100)
105
+ buf.seek(0)
106
+ image = Image.open(buf)
107
+ image = T.ToTensor()(image)
108
+ writer.add_image(name, image, step)
109
+ plt.clf()
110
+ plt.close()
111
+
112
+
113
+ @torch.jit.script
114
+ def shuffle(x):
115
+ return x[torch.randperm(x.shape[0])]
116
+
117
+
118
+ def add_hparams_fixed(writer, hparam_dict, metric_dict, global_step):
119
+ exp, ssi, sei = hparams(hparam_dict, metric_dict)
120
+ writer.file_writer.add_summary(exp)
121
+ writer.file_writer.add_summary(ssi)
122
+ writer.file_writer.add_summary(sei)
123
+ for k, v in metric_dict.items():
124
+ writer.add_scalar(k, v, global_step)
125
+
126
+
127
+ @torch.jit.script
128
+ def resize(classes: torch.Tensor, size: int):
129
+ return F.interpolate(classes, (size, size), mode="bilinear", align_corners=False)
130
+
131
+
132
+ def one_hot_feats(labels, n_classes):
133
+ return F.one_hot(labels, n_classes).permute(0, 3, 1, 2).to(torch.float32)
134
+
135
+
136
+ def load_model(model_type, data_dir):
137
+ if model_type == "robust_resnet50":
138
+ model = models.resnet50(pretrained=False)
139
+ model_file = join(data_dir, 'imagenet_l2_3_0.pt')
140
+ if not os.path.exists(model_file):
141
+ wget.download("http://6.869.csail.mit.edu/fa19/psets19/pset6/imagenet_l2_3_0.pt",
142
+ model_file)
143
+ model_weights = torch.load(model_file)
144
+ model_weights_modified = {name.split('model.')[1]: value for name, value in model_weights['model'].items() if
145
+ 'model' in name}
146
+ model.load_state_dict(model_weights_modified)
147
+ model = nn.Sequential(*list(model.children())[:-1])
148
+ elif model_type == "densecl":
149
+ model = models.resnet50(pretrained=False)
150
+ model_file = join(data_dir, 'densecl_r50_coco_1600ep.pth')
151
+ if not os.path.exists(model_file):
152
+ wget.download("https://cloudstor.aarnet.edu.au/plus/s/3GapXiWuVAzdKwJ/download",
153
+ model_file)
154
+ model_weights = torch.load(model_file)
155
+ # model_weights_modified = {name.split('model.')[1]: value for name, value in model_weights['model'].items() if
156
+ # 'model' in name}
157
+ model.load_state_dict(model_weights['state_dict'], strict=False)
158
+ model = nn.Sequential(*list(model.children())[:-1])
159
+ elif model_type == "resnet50":
160
+ model = models.resnet50(pretrained=True)
161
+ model = nn.Sequential(*list(model.children())[:-1])
162
+ elif model_type == "mocov2":
163
+ model = models.resnet50(pretrained=False)
164
+ model_file = join(data_dir, 'moco_v2_800ep_pretrain.pth.tar')
165
+ if not os.path.exists(model_file):
166
+ wget.download("https://dl.fbaipublicfiles.com/moco/moco_checkpoints/"
167
+ "moco_v2_800ep/moco_v2_800ep_pretrain.pth.tar", model_file)
168
+ checkpoint = torch.load(model_file)
169
+ # rename moco pre-trained keys
170
+ state_dict = checkpoint['state_dict']
171
+ for k in list(state_dict.keys()):
172
+ # retain only encoder_q up to before the embedding layer
173
+ if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
174
+ # remove prefix
175
+ state_dict[k[len("module.encoder_q."):]] = state_dict[k]
176
+ # delete renamed or unused k
177
+ del state_dict[k]
178
+ msg = model.load_state_dict(state_dict, strict=False)
179
+ assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
180
+ model = nn.Sequential(*list(model.children())[:-1])
181
+ elif model_type == "densenet121":
182
+ model = models.densenet121(pretrained=True)
183
+ model = nn.Sequential(*list(model.children())[:-1] + [nn.AdaptiveAvgPool2d((1, 1))])
184
+ elif model_type == "vgg11":
185
+ model = models.vgg11(pretrained=True)
186
+ model = nn.Sequential(*list(model.children())[:-1] + [nn.AdaptiveAvgPool2d((1, 1))])
187
+ else:
188
+ raise ValueError("No model: {} found".format(model_type))
189
+
190
+ model.eval()
191
+ model.cuda()
192
+ return model
193
+
194
+
195
+ class UnNormalize(object):
196
+ def __init__(self, mean, std):
197
+ self.mean = mean
198
+ self.std = std
199
+
200
+ def __call__(self, image):
201
+ image2 = torch.clone(image)
202
+ for t, m, s in zip(image2, self.mean, self.std):
203
+ t.mul_(s).add_(m)
204
+ return image2
205
+
206
+
207
+ normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
208
+ unnorm = UnNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
209
+
210
+
211
+ class ToTargetTensor(object):
212
+ def __call__(self, target):
213
+ return torch.as_tensor(np.array(target), dtype=torch.int64).unsqueeze(0)
214
+
215
+
216
+ def prep_args():
217
+ import sys
218
+
219
+ old_args = sys.argv
220
+ new_args = [old_args.pop(0)]
221
+ while len(old_args) > 0:
222
+ arg = old_args.pop(0)
223
+ if len(arg.split("=")) == 2:
224
+ new_args.append(arg)
225
+ elif arg.startswith("--"):
226
+ new_args.append(arg[2:] + "=" + old_args.pop(0))
227
+ else:
228
+ raise ValueError("Unexpected arg style {}".format(arg))
229
+ sys.argv = new_args
230
+
231
+
232
+ def get_transform(res, is_label, crop_type):
233
+ if crop_type == "center":
234
+ cropper = T.CenterCrop(res)
235
+ elif crop_type == "random":
236
+ cropper = T.RandomCrop(res)
237
+ elif crop_type is None:
238
+ cropper = T.Lambda(lambda x: x)
239
+ res = (res, res)
240
+ else:
241
+ raise ValueError("Unknown Cropper {}".format(crop_type))
242
+ if is_label:
243
+ return T.Compose([T.Resize(res, Image.NEAREST),
244
+ cropper,
245
+ ToTargetTensor()])
246
+ else:
247
+ return T.Compose([T.Resize(res, Image.NEAREST),
248
+ cropper,
249
+ T.ToTensor(),
250
+ normalize])
251
+
252
+
253
+ def _remove_axes(ax):
254
+ ax.xaxis.set_major_formatter(plt.NullFormatter())
255
+ ax.yaxis.set_major_formatter(plt.NullFormatter())
256
+ ax.set_xticks([])
257
+ ax.set_yticks([])
258
+
259
+
260
+ def remove_axes(axes):
261
+ if len(axes.shape) == 2:
262
+ for ax1 in axes:
263
+ for ax in ax1:
264
+ _remove_axes(ax)
265
+ else:
266
+ for ax in axes:
267
+ _remove_axes(ax)
268
+
269
+
270
+ class UnsupervisedMetrics(Metric):
271
+ def __init__(self, prefix: str, n_classes: int, extra_clusters: int, compute_hungarian: bool,
272
+ dist_sync_on_step=True):
273
+ # call `self.add_state`for every internal state that is needed for the metrics computations
274
+ # dist_reduce_fx indicates the function that should be used to reduce
275
+ # state from multiple processes
276
+ super().__init__(dist_sync_on_step=dist_sync_on_step)
277
+
278
+ self.n_classes = n_classes
279
+ self.extra_clusters = extra_clusters
280
+ self.compute_hungarian = compute_hungarian
281
+ self.prefix = prefix
282
+ self.add_state("stats",
283
+ default=torch.zeros(n_classes + self.extra_clusters, n_classes, dtype=torch.int64),
284
+ dist_reduce_fx="sum")
285
+
286
+ def update(self, preds: torch.Tensor, target: torch.Tensor):
287
+ with torch.no_grad():
288
+ actual = target.reshape(-1)
289
+ preds = preds.reshape(-1)
290
+ mask = (actual >= 0) & (actual < self.n_classes) & (preds >= 0) & (preds < self.n_classes)
291
+ actual = actual[mask]
292
+ preds = preds[mask]
293
+ self.stats += torch.bincount(
294
+ (self.n_classes + self.extra_clusters) * actual + preds,
295
+ minlength=self.n_classes * (self.n_classes + self.extra_clusters)) \
296
+ .reshape(self.n_classes, self.n_classes + self.extra_clusters).t().to(self.stats.device)
297
+
298
+ def map_clusters(self, clusters):
299
+ if self.extra_clusters == 0:
300
+ return torch.tensor(self.assignments[1])[clusters]
301
+ else:
302
+ missing = sorted(list(set(range(self.n_classes + self.extra_clusters)) - set(self.assignments[0])))
303
+ cluster_to_class = self.assignments[1]
304
+ for missing_entry in missing:
305
+ if missing_entry == cluster_to_class.shape[0]:
306
+ cluster_to_class = np.append(cluster_to_class, -1)
307
+ else:
308
+ cluster_to_class = np.insert(cluster_to_class, missing_entry + 1, -1)
309
+ cluster_to_class = torch.tensor(cluster_to_class)
310
+ return cluster_to_class[clusters]
311
+
312
+ def compute(self):
313
+ if self.compute_hungarian:
314
+ self.assignments = linear_sum_assignment(self.stats.detach().cpu(), maximize=True)
315
+ # print(self.assignments)
316
+ if self.extra_clusters == 0:
317
+ self.histogram = self.stats[np.argsort(self.assignments[1]), :]
318
+ if self.extra_clusters > 0:
319
+ self.assignments_t = linear_sum_assignment(self.stats.detach().cpu().t(), maximize=True)
320
+ histogram = self.stats[self.assignments_t[1], :]
321
+ missing = list(set(range(self.n_classes + self.extra_clusters)) - set(self.assignments[0]))
322
+ new_row = self.stats[missing, :].sum(0, keepdim=True)
323
+ histogram = torch.cat([histogram, new_row], axis=0)
324
+ new_col = torch.zeros(self.n_classes + 1, 1, device=histogram.device)
325
+ self.histogram = torch.cat([histogram, new_col], axis=1)
326
+ else:
327
+ self.assignments = (torch.arange(self.n_classes).unsqueeze(1),
328
+ torch.arange(self.n_classes).unsqueeze(1))
329
+ self.histogram = self.stats
330
+
331
+ tp = torch.diag(self.histogram)
332
+ fp = torch.sum(self.histogram, dim=0) - tp
333
+ fn = torch.sum(self.histogram, dim=1) - tp
334
+
335
+ iou = tp / (tp + fp + fn)
336
+ prc = tp / (tp + fn)
337
+ opc = torch.sum(tp) / torch.sum(self.histogram)
338
+
339
+ metric_dict = {self.prefix + "mIoU": iou[~torch.isnan(iou)].mean().item(),
340
+ self.prefix + "Accuracy": opc.item()}
341
+ return {k: 100 * v for k, v in metric_dict.items()}
342
+
343
+
344
+ def flexible_collate(batch):
345
+ r"""Puts each data field into a tensor with outer dimension batch size"""
346
+
347
+ elem = batch[0]
348
+ elem_type = type(elem)
349
+ if isinstance(elem, torch.Tensor):
350
+ out = None
351
+ if torch.utils.data.get_worker_info() is not None:
352
+ # If we're in a background process, concatenate directly into a
353
+ # shared memory tensor to avoid an extra copy
354
+ numel = sum([x.numel() for x in batch])
355
+ storage = elem.storage()._new_shared(numel)
356
+ out = elem.new(storage)
357
+ try:
358
+ return torch.stack(batch, 0, out=out)
359
+ except RuntimeError:
360
+ return batch
361
+ elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
362
+ and elem_type.__name__ != 'string_':
363
+ if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
364
+ # array of string classes and object
365
+ if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
366
+ raise TypeError(default_collate_err_msg_format.format(elem.dtype))
367
+
368
+ return flexible_collate([torch.as_tensor(b) for b in batch])
369
+ elif elem.shape == (): # scalars
370
+ return torch.as_tensor(batch)
371
+ elif isinstance(elem, float):
372
+ return torch.tensor(batch, dtype=torch.float64)
373
+ elif isinstance(elem, int):
374
+ return torch.tensor(batch)
375
+ elif isinstance(elem, string_classes):
376
+ return batch
377
+ elif isinstance(elem, collections.abc.Mapping):
378
+ return {key: flexible_collate([d[key] for d in batch]) for key in elem}
379
+ elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
380
+ return elem_type(*(flexible_collate(samples) for samples in zip(*batch)))
381
+ elif isinstance(elem, collections.abc.Sequence):
382
+ # check to make sure that the elements in batch have consistent size
383
+ it = iter(batch)
384
+ elem_size = len(next(it))
385
+ if not all(len(elem) == elem_size for elem in it):
386
+ raise RuntimeError('each element in list of batch should be of equal size')
387
+ transposed = zip(*batch)
388
+ return [flexible_collate(samples) for samples in transposed]
389
+
390
+ raise TypeError(default_collate_err_msg_format.format(elem_type))
biomap/utils_gee.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import requests
3
+ import ee
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+
7
+ #Initialize
8
+ service_account = '[email protected]'
9
+ credentials = ee.ServiceAccountCredentials(service_account, '.private-key.json')
10
+ ee.Initialize(credentials)
11
+
12
+ #delete clouds
13
+ def maskS2clouds(image):
14
+ qa = image.select('QA60');
15
+
16
+ # // Bits 10 and 11 are clouds and cirrus, respectively.
17
+ cloudBitMask = 1 << 10;
18
+ cirrusBitMask = 1 << 11;
19
+
20
+ # // Both flags should be set to zero, indicating clear conditions.
21
+ mask = (qa.bitwiseAnd(cloudBitMask).eq(0))and(qa.bitwiseAnd(cirrusBitMask).eq(0))
22
+
23
+ return image.updateMask(mask).divide(10000);
24
+
25
+
26
+ #find ee_img
27
+ def extract_ee_img(location,start_date,end_date, width = 0.01 , len = 0.01) :
28
+ """Extract the earth engine image
29
+
30
+ Args:
31
+ location (list[float]):
32
+ start_date (str): the start date for finding an image
33
+ end_date (str): the end date for finding an image
34
+ width (float, optional): _description_. Defaults to 0.01.
35
+ len (float, optional): _description_. Defaults to 0.01.
36
+
37
+ Returns:
38
+ _type_: _description_
39
+ """
40
+ # define the polygone
41
+ polygone =[[[float(location[0])-0.01,float(location[1])+0.01],
42
+ [float(location[0])-0.01,float(location[1])-0.01],
43
+ [float(location[0])+0.01,float(location[1])-0.01],
44
+ [float(location[0])+0.01,float(location[1])+0.01],
45
+ ]]
46
+
47
+ #define the ee geometry
48
+ geometry = ee.Geometry.Polygon(polygone, None, False);
49
+
50
+ #extract the dataset
51
+ dataset = ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED')\
52
+ .filterDate(start_date, end_date)\
53
+ .filter(ee.Filter.lt('CLOUDY_PIXEL_PERCENTAGE',1))\
54
+ .map(maskS2clouds)
55
+ return dataset.mean(), geometry
56
+
57
+
58
+
59
+ # Get URL
60
+ def get_url(ee_img, geometry, scale=5):
61
+ """Get the url of a dataset and a geometry
62
+
63
+ Args:
64
+ ee_img (ee.ImageCollection: meta data on the image
65
+ geometry (ee.Geometry.Polygon): geometry of the desired landscape
66
+ scale (int, optional): _description_. Defaults to 5.
67
+
68
+ Returns:
69
+ str: the url to use to ask the server
70
+ """
71
+ region = geometry
72
+
73
+ # collectionList = ee_img.toList(ee_img.size())
74
+ # collectionSize = collectionList.size().getInfo()
75
+ # for i in xrange(collectionSize):
76
+ # ee.batch.Export.image.toDrive(
77
+ # image = ee.Image(collectionList.get(i)).clip(rectangle),
78
+ # fileNamePrefix = 'foo' + str(i + 1),
79
+ # dimensions = '128x128').start()
80
+
81
+ url = ee_img.getDownloadURL({
82
+ # 'min': 0.0,
83
+ # 'max': 0.3,
84
+ 'bands': ['B4', 'B3', 'B2'],
85
+ 'region' : region,
86
+ 'scale' : scale,
87
+ 'format' : 'NPY'
88
+ })
89
+
90
+ return url
91
+
92
+ def extract_np_from_url(url):
93
+ """extract a numpy array based on a url
94
+
95
+ Args:
96
+ url (str): _description_
97
+
98
+ Returns:
99
+ numpyarray: response from earth engine as numpy
100
+ """
101
+ #get the response from url
102
+ response = requests.get(url)
103
+
104
+ #transform it into numpy
105
+ data = np.load(io.BytesIO(response.content))
106
+
107
+ #transform numpy of tuples to 3D numpy
108
+ temp1 = []
109
+
110
+ for x in data:
111
+ temp2 = []
112
+ for y in x :
113
+ temp2.append([z for z in y])
114
+ temp1.append(temp2)
115
+
116
+ data = np.array(temp1)
117
+
118
+ return data
119
+
120
+ #Fonction globale
121
+ def extract_img(location,start_date,end_date, width = 0.01 , len = 0.01,scale=5):
122
+ """Extract an image of the landscape at the selected longitude and latitude with the selected width and length
123
+
124
+ Args:
125
+ location (list[float]): [latitude of the center of the landscape, longitude of the center of the landscape]
126
+ start_date (str): the start date
127
+ end_date (str): _description_
128
+ width (float, optional): _description_. Defaults to 0.01.
129
+ len (float, optional): _description_. Defaults to 0.01.
130
+ scale (int, optional): _description_. Defaults to 5.
131
+
132
+ Returns:
133
+ img: image as numpy array
134
+ """
135
+ ee_img, geometry = extract_ee_img(location, width,start_date,end_date , len)
136
+ url = get_url(ee_img, geometry, scale)
137
+ img = extract_np_from_url(url)
138
+
139
+ return img
140
+
141
+ # transform img from numpy to PIL
142
+ def transform_ee_img(img, min = 0, max=0.3):
143
+ """Transform an img from numpy to PIL
144
+
145
+ Args:
146
+ img (numpy array): the original image as a numpy array
147
+ min (int, optional): _description_. Defaults to 0.
148
+ max (float, optional): _description_. Defaults to 0.3.
149
+
150
+ Returns:
151
+ img_test: a PIL image
152
+ """
153
+ img_test=img
154
+ img_test=np.minimum(img_test*255/max,np.ones(img.shape)*255)
155
+ img_test=np.uint8((np.rint(img_test)).astype(int))
156
+ plt.imshow(img_test)
157
+ return img_test
poetry.lock ADDED
@@ -0,0 +1,1625 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [[package]]
2
+ name = "absl-py"
3
+ version = "1.4.0"
4
+ description = "Abseil Python Common Libraries, see https://github.com/abseil/abseil-py."
5
+ category = "main"
6
+ optional = false
7
+ python-versions = ">=3.6"
8
+
9
+ [[package]]
10
+ name = "aiofiles"
11
+ version = "23.1.0"
12
+ description = "File support for asyncio."
13
+ category = "main"
14
+ optional = false
15
+ python-versions = ">=3.7,<4.0"
16
+
17
+ [[package]]
18
+ name = "aiohttp"
19
+ version = "3.8.4"
20
+ description = "Async http client/server framework (asyncio)"
21
+ category = "main"
22
+ optional = false
23
+ python-versions = ">=3.6"
24
+
25
+ [package.dependencies]
26
+ aiosignal = ">=1.1.2"
27
+ async-timeout = ">=4.0.0a3,<5.0"
28
+ attrs = ">=17.3.0"
29
+ charset-normalizer = ">=2.0,<4.0"
30
+ frozenlist = ">=1.1.1"
31
+ multidict = ">=4.5,<7.0"
32
+ yarl = ">=1.0,<2.0"
33
+
34
+ [package.extras]
35
+ speedups = ["aiodns", "brotli", "cchardet"]
36
+
37
+ [[package]]
38
+ name = "aiosignal"
39
+ version = "1.3.1"
40
+ description = "aiosignal: a list of registered asynchronous callbacks"
41
+ category = "main"
42
+ optional = false
43
+ python-versions = ">=3.7"
44
+
45
+ [package.dependencies]
46
+ frozenlist = ">=1.1.0"
47
+
48
+ [[package]]
49
+ name = "altair"
50
+ version = "4.2.2"
51
+ description = "Altair: A declarative statistical visualization library for Python."
52
+ category = "main"
53
+ optional = false
54
+ python-versions = ">=3.7"
55
+
56
+ [package.dependencies]
57
+ entrypoints = "*"
58
+ jinja2 = "*"
59
+ jsonschema = ">=3.0"
60
+ numpy = "*"
61
+ pandas = ">=0.18"
62
+ toolz = "*"
63
+
64
+ [package.extras]
65
+ dev = ["black", "docutils", "ipython", "flake8", "pytest", "sphinx", "mistune (<2.0.0)", "m2r", "vega-datasets", "recommonmark"]
66
+
67
+ [[package]]
68
+ name = "antlr4-python3-runtime"
69
+ version = "4.9.3"
70
+ description = "ANTLR 4.9.3 runtime for Python 3.7"
71
+ category = "main"
72
+ optional = false
73
+ python-versions = "*"
74
+
75
+ [[package]]
76
+ name = "anyio"
77
+ version = "3.6.2"
78
+ description = "High level compatibility layer for multiple asynchronous event loop implementations"
79
+ category = "main"
80
+ optional = false
81
+ python-versions = ">=3.6.2"
82
+
83
+ [package.dependencies]
84
+ idna = ">=2.8"
85
+ sniffio = ">=1.1"
86
+
87
+ [package.extras]
88
+ doc = ["packaging", "sphinx-rtd-theme", "sphinx-autodoc-typehints (>=1.2.0)"]
89
+ test = ["coverage[toml] (>=4.5)", "hypothesis (>=4.0)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "contextlib2", "uvloop (<0.15)", "mock (>=4)", "uvloop (>=0.15)"]
90
+ trio = ["trio (>=0.16,<0.22)"]
91
+
92
+ [[package]]
93
+ name = "async-timeout"
94
+ version = "4.0.2"
95
+ description = "Timeout context manager for asyncio programs"
96
+ category = "main"
97
+ optional = false
98
+ python-versions = ">=3.6"
99
+
100
+ [[package]]
101
+ name = "attrs"
102
+ version = "19.3.0"
103
+ description = "Classes Without Boilerplate"
104
+ category = "main"
105
+ optional = false
106
+ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
107
+
108
+ [package.extras]
109
+ azure-pipelines = ["coverage", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "zope.interface", "pytest-azurepipelines"]
110
+ dev = ["coverage", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "zope.interface", "sphinx", "pre-commit"]
111
+ docs = ["sphinx", "zope.interface"]
112
+ tests = ["coverage", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "zope.interface"]
113
+
114
+ [[package]]
115
+ name = "cachetools"
116
+ version = "5.3.0"
117
+ description = "Extensible memoizing collections and decorators"
118
+ category = "main"
119
+ optional = false
120
+ python-versions = "~=3.7"
121
+
122
+ [[package]]
123
+ name = "certifi"
124
+ version = "2022.12.7"
125
+ description = "Python package for providing Mozilla's CA Bundle."
126
+ category = "main"
127
+ optional = false
128
+ python-versions = ">=3.6"
129
+
130
+ [[package]]
131
+ name = "charset-normalizer"
132
+ version = "3.1.0"
133
+ description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet."
134
+ category = "main"
135
+ optional = false
136
+ python-versions = ">=3.7.0"
137
+
138
+ [[package]]
139
+ name = "click"
140
+ version = "8.1.3"
141
+ description = "Composable command line interface toolkit"
142
+ category = "main"
143
+ optional = false
144
+ python-versions = ">=3.7"
145
+
146
+ [package.dependencies]
147
+ colorama = {version = "*", markers = "platform_system == \"Windows\""}
148
+
149
+ [[package]]
150
+ name = "colorama"
151
+ version = "0.4.6"
152
+ description = "Cross-platform colored terminal text."
153
+ category = "main"
154
+ optional = false
155
+ python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
156
+
157
+ [[package]]
158
+ name = "contourpy"
159
+ version = "1.0.7"
160
+ description = "Python library for calculating contours of 2D quadrilateral grids"
161
+ category = "main"
162
+ optional = false
163
+ python-versions = ">=3.8"
164
+
165
+ [package.dependencies]
166
+ numpy = ">=1.16"
167
+
168
+ [package.extras]
169
+ bokeh = ["bokeh", "chromedriver", "selenium"]
170
+ docs = ["furo", "sphinx-copybutton"]
171
+ mypy = ["contourpy", "docutils-stubs", "mypy (==0.991)", "types-pillow"]
172
+ test = ["matplotlib", "pillow", "pytest"]
173
+ test-no-images = ["pytest"]
174
+
175
+ [[package]]
176
+ name = "cycler"
177
+ version = "0.11.0"
178
+ description = "Composable style cycles"
179
+ category = "main"
180
+ optional = false
181
+ python-versions = ">=3.6"
182
+
183
+ [[package]]
184
+ name = "earthengine-api"
185
+ version = "0.1.338"
186
+ description = "Earth Engine Python API"
187
+ category = "main"
188
+ optional = false
189
+ python-versions = "*"
190
+
191
+ [package.dependencies]
192
+ google-api-python-client = ">=1.12.1"
193
+ google-auth = ">=1.4.1"
194
+ google-auth-httplib2 = ">=0.0.3"
195
+ google-cloud-storage = "*"
196
+ httplib2 = ">=0.9.2,<1dev"
197
+ requests = "*"
198
+
199
+ [[package]]
200
+ name = "ee-extra"
201
+ version = "0.0.15"
202
+ description = "A ninja Python package behind rgee, rgeeExtra and eemont."
203
+ category = "main"
204
+ optional = false
205
+ python-versions = "*"
206
+
207
+ [package.dependencies]
208
+ earthengine-api = "*"
209
+
210
+ [[package]]
211
+ name = "entrypoints"
212
+ version = "0.4"
213
+ description = "Discover and load entry points from installed packages."
214
+ category = "main"
215
+ optional = false
216
+ python-versions = ">=3.6"
217
+
218
+ [[package]]
219
+ name = "fastapi"
220
+ version = "0.95.1"
221
+ description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production"
222
+ category = "main"
223
+ optional = false
224
+ python-versions = ">=3.7"
225
+
226
+ [package.dependencies]
227
+ pydantic = ">=1.6.2,<1.7 || >1.7,<1.7.1 || >1.7.1,<1.7.2 || >1.7.2,<1.7.3 || >1.7.3,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0"
228
+ starlette = ">=0.26.1,<0.27.0"
229
+
230
+ [package.extras]
231
+ all = ["email-validator (>=1.1.1)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "python-multipart (>=0.0.5)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"]
232
+ dev = ["pre-commit (>=2.17.0,<3.0.0)", "ruff (==0.0.138)", "uvicorn[standard] (>=0.12.0,<0.21.0)"]
233
+ doc = ["mdx-include (>=1.4.1,<2.0.0)", "mkdocs-markdownextradata-plugin (>=0.1.7,<0.3.0)", "mkdocs-material (>=8.1.4,<9.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "pyyaml (>=5.3.1,<7.0.0)", "typer-cli (>=0.0.13,<0.0.14)", "typer[all] (>=0.6.1,<0.8.0)"]
234
+ test = ["anyio[trio] (>=3.2.1,<4.0.0)", "black (==23.1.0)", "coverage[toml] (>=6.5.0,<8.0)", "databases[sqlite] (>=0.3.2,<0.7.0)", "email-validator (>=1.1.1,<2.0.0)", "flask (>=1.1.2,<3.0.0)", "httpx (>=0.23.0,<0.24.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.982)", "orjson (>=3.2.1,<4.0.0)", "passlib[bcrypt] (>=1.7.2,<2.0.0)", "peewee (>=3.13.3,<4.0.0)", "pytest (>=7.1.3,<8.0.0)", "python-jose[cryptography] (>=3.3.0,<4.0.0)", "python-multipart (>=0.0.5,<0.0.7)", "pyyaml (>=5.3.1,<7.0.0)", "ruff (==0.0.138)", "sqlalchemy (>=1.3.18,<1.4.43)", "types-orjson (==3.6.2)", "types-ujson (==5.7.0.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0,<6.0.0)"]
235
+
236
+ [[package]]
237
+ name = "ffmpy"
238
+ version = "0.3.0"
239
+ description = "A simple Python wrapper for ffmpeg"
240
+ category = "main"
241
+ optional = false
242
+ python-versions = "*"
243
+
244
+ [[package]]
245
+ name = "filelock"
246
+ version = "3.11.0"
247
+ description = "A platform independent file lock."
248
+ category = "main"
249
+ optional = false
250
+ python-versions = ">=3.7"
251
+
252
+ [package.extras]
253
+ docs = ["furo (>=2023.3.27)", "sphinx-autodoc-typehints (>=1.22,!=1.23.4)", "sphinx (>=6.1.3)"]
254
+ testing = ["covdefaults (>=2.3)", "coverage (>=7.2.2)", "diff-cover (>=7.5)", "pytest-cov (>=4)", "pytest-mock (>=3.10)", "pytest-timeout (>=2.1)", "pytest (>=7.2.2)"]
255
+
256
+ [[package]]
257
+ name = "fonttools"
258
+ version = "4.39.3"
259
+ description = "Tools to manipulate font files"
260
+ category = "main"
261
+ optional = false
262
+ python-versions = ">=3.8"
263
+
264
+ [package.extras]
265
+ all = ["fs (>=2.2.0,<3)", "lxml (>=4.0,<5)", "zopfli (>=0.1.4)", "lz4 (>=1.7.4.2)", "matplotlib", "sympy", "skia-pathops (>=0.5.0)", "uharfbuzz (>=0.23.0)", "brotlicffi (>=0.8.0)", "scipy", "brotli (>=1.0.1)", "munkres", "unicodedata2 (>=15.0.0)", "xattr"]
266
+ graphite = ["lz4 (>=1.7.4.2)"]
267
+ interpolatable = ["scipy", "munkres"]
268
+ lxml = ["lxml (>=4.0,<5)"]
269
+ pathops = ["skia-pathops (>=0.5.0)"]
270
+ plot = ["matplotlib"]
271
+ repacker = ["uharfbuzz (>=0.23.0)"]
272
+ symfont = ["sympy"]
273
+ type1 = ["xattr"]
274
+ ufo = ["fs (>=2.2.0,<3)"]
275
+ unicode = ["unicodedata2 (>=15.0.0)"]
276
+ woff = ["zopfli (>=0.1.4)", "brotlicffi (>=0.8.0)", "brotli (>=1.0.1)"]
277
+
278
+ [[package]]
279
+ name = "frozenlist"
280
+ version = "1.3.3"
281
+ description = "A list-like structure which implements collections.abc.MutableSequence"
282
+ category = "main"
283
+ optional = false
284
+ python-versions = ">=3.7"
285
+
286
+ [[package]]
287
+ name = "fsspec"
288
+ version = "2023.4.0"
289
+ description = "File-system specification"
290
+ category = "main"
291
+ optional = false
292
+ python-versions = ">=3.8"
293
+
294
+ [package.dependencies]
295
+ aiohttp = {version = "<4.0.0a0 || >4.0.0a0,<4.0.0a1 || >4.0.0a1", optional = true, markers = "extra == \"http\""}
296
+ requests = {version = "*", optional = true, markers = "extra == \"http\""}
297
+
298
+ [package.extras]
299
+ abfs = ["adlfs"]
300
+ adl = ["adlfs"]
301
+ arrow = ["pyarrow (>=1)"]
302
+ dask = ["dask", "distributed"]
303
+ devel = ["pytest", "pytest-cov"]
304
+ dropbox = ["dropboxdrivefs", "requests", "dropbox"]
305
+ full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "dask", "distributed", "dropbox", "dropboxdrivefs", "fusepy", "gcsfs", "libarchive-c", "ocifs", "panel", "paramiko", "pyarrow (>=1)", "pygit2", "requests", "s3fs", "smbprotocol", "tqdm"]
306
+ fuse = ["fusepy"]
307
+ gcs = ["gcsfs"]
308
+ git = ["pygit2"]
309
+ github = ["requests"]
310
+ gs = ["gcsfs"]
311
+ gui = ["panel"]
312
+ hdfs = ["pyarrow (>=1)"]
313
+ http = ["requests", "aiohttp (!=4.0.0a0,!=4.0.0a1)"]
314
+ libarchive = ["libarchive-c"]
315
+ oci = ["ocifs"]
316
+ s3 = ["s3fs"]
317
+ sftp = ["paramiko"]
318
+ smb = ["smbprotocol"]
319
+ ssh = ["paramiko"]
320
+ tqdm = ["tqdm"]
321
+
322
+ [[package]]
323
+ name = "google-api-core"
324
+ version = "2.11.0"
325
+ description = "Google API client core library"
326
+ category = "main"
327
+ optional = false
328
+ python-versions = ">=3.7"
329
+
330
+ [package.dependencies]
331
+ google-auth = ">=2.14.1,<3.0dev"
332
+ googleapis-common-protos = ">=1.56.2,<2.0dev"
333
+ protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev"
334
+ requests = ">=2.18.0,<3.0.0dev"
335
+
336
+ [package.extras]
337
+ grpc = ["grpcio (>=1.33.2,<2.0dev)", "grpcio-status (>=1.33.2,<2.0dev)", "grpcio (>=1.49.1,<2.0dev)", "grpcio-status (>=1.49.1,<2.0dev)"]
338
+ grpcgcp = ["grpcio-gcp (>=0.2.2,<1.0dev)"]
339
+ grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0dev)"]
340
+
341
+ [[package]]
342
+ name = "google-api-python-client"
343
+ version = "2.85.0"
344
+ description = "Google API Client Library for Python"
345
+ category = "main"
346
+ optional = false
347
+ python-versions = ">=3.7"
348
+
349
+ [package.dependencies]
350
+ google-api-core = ">=1.31.5,<2.0.0 || >2.3.0,<3.0.0dev"
351
+ google-auth = ">=1.19.0,<3.0.0dev"
352
+ google-auth-httplib2 = ">=0.1.0"
353
+ httplib2 = ">=0.15.0,<1dev"
354
+ uritemplate = ">=3.0.1,<5"
355
+
356
+ [[package]]
357
+ name = "google-auth"
358
+ version = "2.17.3"
359
+ description = "Google Authentication Library"
360
+ category = "main"
361
+ optional = false
362
+ python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*"
363
+
364
+ [package.dependencies]
365
+ cachetools = ">=2.0.0,<6.0"
366
+ pyasn1-modules = ">=0.2.1"
367
+ rsa = {version = ">=3.1.4,<5", markers = "python_version >= \"3.6\""}
368
+ six = ">=1.9.0"
369
+
370
+ [package.extras]
371
+ aiohttp = ["requests (>=2.20.0,<3.0.0dev)", "aiohttp (>=3.6.2,<4.0.0dev)"]
372
+ enterprise_cert = ["cryptography (==36.0.2)", "pyopenssl (==22.0.0)"]
373
+ pyopenssl = ["pyopenssl (>=20.0.0)", "cryptography (>=38.0.3)"]
374
+ reauth = ["pyu2f (>=0.1.5)"]
375
+ requests = ["requests (>=2.20.0,<3.0.0dev)"]
376
+
377
+ [[package]]
378
+ name = "google-auth-httplib2"
379
+ version = "0.1.0"
380
+ description = "Google Authentication Library: httplib2 transport"
381
+ category = "main"
382
+ optional = false
383
+ python-versions = "*"
384
+
385
+ [package.dependencies]
386
+ google-auth = "*"
387
+ httplib2 = ">=0.15.0"
388
+ six = "*"
389
+
390
+ [[package]]
391
+ name = "google-auth-oauthlib"
392
+ version = "0.4.6"
393
+ description = "Google Authentication Library"
394
+ category = "main"
395
+ optional = false
396
+ python-versions = ">=3.6"
397
+
398
+ [package.dependencies]
399
+ google-auth = ">=1.0.0"
400
+ requests-oauthlib = ">=0.7.0"
401
+
402
+ [package.extras]
403
+ tool = ["click (>=6.0.0)"]
404
+
405
+ [[package]]
406
+ name = "google-cloud-core"
407
+ version = "2.3.2"
408
+ description = "Google Cloud API client core library"
409
+ category = "main"
410
+ optional = false
411
+ python-versions = ">=3.7"
412
+
413
+ [package.dependencies]
414
+ google-api-core = ">=1.31.6,<2.0.0 || >2.3.0,<3.0.0dev"
415
+ google-auth = ">=1.25.0,<3.0dev"
416
+
417
+ [package.extras]
418
+ grpc = ["grpcio (>=1.38.0,<2.0dev)"]
419
+
420
+ [[package]]
421
+ name = "google-cloud-storage"
422
+ version = "2.8.0"
423
+ description = "Google Cloud Storage API client library"
424
+ category = "main"
425
+ optional = false
426
+ python-versions = ">=3.7"
427
+
428
+ [package.dependencies]
429
+ google-api-core = ">=1.31.5,<2.0.0 || >2.3.0,<3.0.0dev"
430
+ google-auth = ">=1.25.0,<3.0dev"
431
+ google-cloud-core = ">=2.3.0,<3.0dev"
432
+ google-resumable-media = ">=2.3.2"
433
+ requests = ">=2.18.0,<3.0.0dev"
434
+
435
+ [package.extras]
436
+ protobuf = ["protobuf (<5.0.0dev)"]
437
+
438
+ [[package]]
439
+ name = "google-crc32c"
440
+ version = "1.5.0"
441
+ description = "A python wrapper of the C library 'Google CRC32C'"
442
+ category = "main"
443
+ optional = false
444
+ python-versions = ">=3.7"
445
+
446
+ [package.extras]
447
+ testing = ["pytest"]
448
+
449
+ [[package]]
450
+ name = "google-resumable-media"
451
+ version = "2.4.1"
452
+ description = "Utilities for Google Media Downloads and Resumable Uploads"
453
+ category = "main"
454
+ optional = false
455
+ python-versions = ">= 3.7"
456
+
457
+ [package.dependencies]
458
+ google-crc32c = ">=1.0,<2.0dev"
459
+
460
+ [package.extras]
461
+ aiohttp = ["aiohttp (>=3.6.2,<4.0.0dev)"]
462
+ requests = ["requests (>=2.18.0,<3.0.0dev)"]
463
+
464
+ [[package]]
465
+ name = "googleapis-common-protos"
466
+ version = "1.59.0"
467
+ description = "Common protobufs used in Google APIs"
468
+ category = "main"
469
+ optional = false
470
+ python-versions = ">=3.7"
471
+
472
+ [package.dependencies]
473
+ protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev"
474
+
475
+ [package.extras]
476
+ grpc = ["grpcio (>=1.44.0,<2.0.0dev)"]
477
+
478
+ [[package]]
479
+ name = "gradio"
480
+ version = "3.27.0"
481
+ description = "Python library for easily interacting with trained machine learning models"
482
+ category = "main"
483
+ optional = false
484
+ python-versions = ">=3.7"
485
+
486
+ [package.dependencies]
487
+ aiofiles = "*"
488
+ aiohttp = "*"
489
+ altair = ">=4.2.0"
490
+ fastapi = "*"
491
+ ffmpy = "*"
492
+ gradio-client = ">=0.1.3"
493
+ httpx = "*"
494
+ huggingface-hub = ">=0.13.0"
495
+ jinja2 = "*"
496
+ markdown-it-py = {version = ">=2.0.0", extras = ["linkify"]}
497
+ markupsafe = "*"
498
+ matplotlib = "*"
499
+ mdit-py-plugins = "<=0.3.3"
500
+ numpy = "*"
501
+ orjson = "*"
502
+ pandas = "*"
503
+ pillow = "*"
504
+ pydantic = "*"
505
+ pydub = "*"
506
+ python-multipart = "*"
507
+ pyyaml = "*"
508
+ requests = "*"
509
+ semantic-version = "*"
510
+ typing-extensions = "*"
511
+ uvicorn = "*"
512
+ websockets = ">=10.0"
513
+
514
+ [[package]]
515
+ name = "gradio-client"
516
+ version = "0.1.3"
517
+ description = "Python library for easily interacting with trained machine learning models"
518
+ category = "main"
519
+ optional = false
520
+ python-versions = ">=3.7"
521
+
522
+ [package.dependencies]
523
+ fsspec = "*"
524
+ httpx = "*"
525
+ huggingface-hub = ">=0.13.0"
526
+ packaging = "*"
527
+ requests = "*"
528
+ typing-extensions = "*"
529
+ websockets = "*"
530
+
531
+ [[package]]
532
+ name = "grpcio"
533
+ version = "1.53.0"
534
+ description = "HTTP/2-based RPC framework"
535
+ category = "main"
536
+ optional = false
537
+ python-versions = ">=3.7"
538
+
539
+ [package.extras]
540
+ protobuf = ["grpcio-tools (>=1.53.0)"]
541
+
542
+ [[package]]
543
+ name = "h11"
544
+ version = "0.14.0"
545
+ description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1"
546
+ category = "main"
547
+ optional = false
548
+ python-versions = ">=3.7"
549
+
550
+ [[package]]
551
+ name = "httpcore"
552
+ version = "0.17.0"
553
+ description = "A minimal low-level HTTP client."
554
+ category = "main"
555
+ optional = false
556
+ python-versions = ">=3.7"
557
+
558
+ [package.dependencies]
559
+ anyio = ">=3.0,<5.0"
560
+ certifi = "*"
561
+ h11 = ">=0.13,<0.15"
562
+ sniffio = ">=1.0.0,<2.0.0"
563
+
564
+ [package.extras]
565
+ http2 = ["h2 (>=3,<5)"]
566
+ socks = ["socksio (>=1.0.0,<2.0.0)"]
567
+
568
+ [[package]]
569
+ name = "httplib2"
570
+ version = "0.22.0"
571
+ description = "A comprehensive HTTP client library."
572
+ category = "main"
573
+ optional = false
574
+ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
575
+
576
+ [package.dependencies]
577
+ pyparsing = {version = ">=2.4.2,<3.0.0 || >3.0.0,<3.0.1 || >3.0.1,<3.0.2 || >3.0.2,<3.0.3 || >3.0.3,<4", markers = "python_version > \"3.0\""}
578
+
579
+ [[package]]
580
+ name = "httpx"
581
+ version = "0.24.0"
582
+ description = "The next generation HTTP client."
583
+ category = "main"
584
+ optional = false
585
+ python-versions = ">=3.7"
586
+
587
+ [package.dependencies]
588
+ certifi = "*"
589
+ httpcore = ">=0.15.0,<0.18.0"
590
+ idna = "*"
591
+ sniffio = "*"
592
+
593
+ [package.extras]
594
+ brotli = ["brotli", "brotlicffi"]
595
+ cli = ["click (>=8.0.0,<9.0.0)", "pygments (>=2.0.0,<3.0.0)", "rich (>=10,<14)"]
596
+ http2 = ["h2 (>=3,<5)"]
597
+ socks = ["socksio (>=1.0.0,<2.0.0)"]
598
+
599
+ [[package]]
600
+ name = "huggingface-hub"
601
+ version = "0.13.4"
602
+ description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
603
+ category = "main"
604
+ optional = false
605
+ python-versions = ">=3.7.0"
606
+
607
+ [package.dependencies]
608
+ filelock = "*"
609
+ packaging = ">=20.9"
610
+ pyyaml = ">=5.1"
611
+ requests = "*"
612
+ tqdm = ">=4.42.1"
613
+ typing-extensions = ">=3.7.4.3"
614
+
615
+ [package.extras]
616
+ all = ["InquirerPy (==0.3.4)", "jedi", "jinja2", "pytest", "pytest-cov", "pytest-env", "pytest-xdist", "soundfile", "pillow", "black (>=23.1,<24.0)", "ruff (>=0.0.241)", "mypy (==0.982)", "types-pyyaml", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"]
617
+ cli = ["InquirerPy (==0.3.4)"]
618
+ dev = ["InquirerPy (==0.3.4)", "jedi", "jinja2", "pytest", "pytest-cov", "pytest-env", "pytest-xdist", "soundfile", "pillow", "black (>=23.1,<24.0)", "ruff (>=0.0.241)", "mypy (==0.982)", "types-pyyaml", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"]
619
+ fastai = ["toml", "fastai (>=2.4)", "fastcore (>=1.3.27)"]
620
+ quality = ["black (>=23.1,<24.0)", "ruff (>=0.0.241)", "mypy (==0.982)"]
621
+ tensorflow = ["tensorflow", "pydot", "graphviz"]
622
+ testing = ["InquirerPy (==0.3.4)", "jedi", "jinja2", "pytest", "pytest-cov", "pytest-env", "pytest-xdist", "soundfile", "pillow"]
623
+ torch = ["torch"]
624
+ typing = ["types-pyyaml", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"]
625
+
626
+ [[package]]
627
+ name = "hydra-client"
628
+ version = "0.5.1"
629
+ description = "Client library for ORY Hydra (OAuth2 and OpenID Connect provider)"
630
+ category = "main"
631
+ optional = false
632
+ python-versions = ">=3.7,<4.0"
633
+
634
+ [package.dependencies]
635
+ attrs = ">=19.2,<20.0"
636
+ python-dateutil = ">=2.8,<3.0"
637
+ requests = ">=2.21,<3.0"
638
+ requests-oauthlib = ">=1.0,<2.0"
639
+
640
+ [[package]]
641
+ name = "hydra-core"
642
+ version = "1.3.1"
643
+ description = "A framework for elegantly configuring complex applications"
644
+ category = "main"
645
+ optional = false
646
+ python-versions = "*"
647
+
648
+ [package.dependencies]
649
+ antlr4-python3-runtime = ">=4.9.0,<4.10.0"
650
+ omegaconf = ">=2.2,<2.4"
651
+ packaging = "*"
652
+
653
+ [[package]]
654
+ name = "idna"
655
+ version = "3.4"
656
+ description = "Internationalized Domain Names in Applications (IDNA)"
657
+ category = "main"
658
+ optional = false
659
+ python-versions = ">=3.5"
660
+
661
+ [[package]]
662
+ name = "jinja2"
663
+ version = "3.1.2"
664
+ description = "A very fast and expressive template engine."
665
+ category = "main"
666
+ optional = false
667
+ python-versions = ">=3.7"
668
+
669
+ [package.dependencies]
670
+ MarkupSafe = ">=2.0"
671
+
672
+ [package.extras]
673
+ i18n = ["Babel (>=2.7)"]
674
+
675
+ [[package]]
676
+ name = "jsonschema"
677
+ version = "4.17.3"
678
+ description = "An implementation of JSON Schema validation for Python"
679
+ category = "main"
680
+ optional = false
681
+ python-versions = ">=3.7"
682
+
683
+ [package.dependencies]
684
+ attrs = ">=17.4.0"
685
+ pyrsistent = ">=0.14.0,<0.17.0 || >0.17.0,<0.17.1 || >0.17.1,<0.17.2 || >0.17.2"
686
+
687
+ [package.extras]
688
+ format = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3987", "uri-template", "webcolors (>=1.11)"]
689
+ format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3986-validator (>0.1.0)", "uri-template", "webcolors (>=1.11)"]
690
+
691
+ [[package]]
692
+ name = "kiwisolver"
693
+ version = "1.4.4"
694
+ description = "A fast implementation of the Cassowary constraint solver"
695
+ category = "main"
696
+ optional = false
697
+ python-versions = ">=3.7"
698
+
699
+ [[package]]
700
+ name = "lightning-utilities"
701
+ version = "0.8.0"
702
+ description = "PyTorch Lightning Sample project."
703
+ category = "main"
704
+ optional = false
705
+ python-versions = ">=3.7"
706
+
707
+ [package.dependencies]
708
+ packaging = ">=17.1"
709
+ typing-extensions = "*"
710
+
711
+ [package.extras]
712
+ cli = ["fire"]
713
+ docs = ["sphinx (>=4.0,<5.0)"]
714
+ test = ["coverage (==6.5.0)"]
715
+ typing = ["mypy (>=1.0.0)"]
716
+
717
+ [[package]]
718
+ name = "linkify-it-py"
719
+ version = "2.0.0"
720
+ description = "Links recognition library with FULL unicode support."
721
+ category = "main"
722
+ optional = false
723
+ python-versions = ">=3.6"
724
+
725
+ [package.dependencies]
726
+ uc-micro-py = "*"
727
+
728
+ [package.extras]
729
+ benchmark = ["pytest", "pytest-benchmark"]
730
+ dev = ["pre-commit", "isort", "flake8", "black"]
731
+ doc = ["sphinx", "sphinx-book-theme", "myst-parser"]
732
+ test = ["coverage", "pytest", "pytest-cov"]
733
+
734
+ [[package]]
735
+ name = "markdown"
736
+ version = "3.4.3"
737
+ description = "Python implementation of John Gruber's Markdown."
738
+ category = "main"
739
+ optional = false
740
+ python-versions = ">=3.7"
741
+
742
+ [package.extras]
743
+ testing = ["coverage", "pyyaml"]
744
+
745
+ [[package]]
746
+ name = "markdown-it-py"
747
+ version = "2.2.0"
748
+ description = "Python port of markdown-it. Markdown parsing, done right!"
749
+ category = "main"
750
+ optional = false
751
+ python-versions = ">=3.7"
752
+
753
+ [package.dependencies]
754
+ linkify-it-py = {version = ">=1,<3", optional = true, markers = "extra == \"linkify\""}
755
+ mdurl = ">=0.1,<1.0"
756
+
757
+ [package.extras]
758
+ benchmarking = ["psutil", "pytest", "pytest-benchmark"]
759
+ code_style = ["pre-commit (>=3.0,<4.0)"]
760
+ compare = ["commonmark (>=0.9,<1.0)", "markdown (>=3.4,<4.0)", "mistletoe (>=1.0,<2.0)", "mistune (>=2.0,<3.0)", "panflute (>=2.3,<3.0)"]
761
+ linkify = ["linkify-it-py (>=1,<3)"]
762
+ plugins = ["mdit-py-plugins"]
763
+ profiling = ["gprof2dot"]
764
+ rtd = ["attrs", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx-book-theme"]
765
+ testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"]
766
+
767
+ [[package]]
768
+ name = "markupsafe"
769
+ version = "2.1.2"
770
+ description = "Safely add untrusted strings to HTML/XML markup."
771
+ category = "main"
772
+ optional = false
773
+ python-versions = ">=3.7"
774
+
775
+ [[package]]
776
+ name = "matplotlib"
777
+ version = "3.7.1"
778
+ description = "Python plotting package"
779
+ category = "main"
780
+ optional = false
781
+ python-versions = ">=3.8"
782
+
783
+ [package.dependencies]
784
+ contourpy = ">=1.0.1"
785
+ cycler = ">=0.10"
786
+ fonttools = ">=4.22.0"
787
+ kiwisolver = ">=1.0.1"
788
+ numpy = ">=1.20"
789
+ packaging = ">=20.0"
790
+ pillow = ">=6.2.0"
791
+ pyparsing = ">=2.3.1"
792
+ python-dateutil = ">=2.7"
793
+ setuptools_scm = ">=7"
794
+
795
+ [[package]]
796
+ name = "mdit-py-plugins"
797
+ version = "0.3.3"
798
+ description = "Collection of plugins for markdown-it-py"
799
+ category = "main"
800
+ optional = false
801
+ python-versions = ">=3.7"
802
+
803
+ [package.dependencies]
804
+ markdown-it-py = ">=1.0.0,<3.0.0"
805
+
806
+ [package.extras]
807
+ code_style = ["pre-commit"]
808
+ rtd = ["attrs", "myst-parser (>=0.16.1,<0.17.0)", "sphinx-book-theme (>=0.1.0,<0.2.0)"]
809
+ testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"]
810
+
811
+ [[package]]
812
+ name = "mdurl"
813
+ version = "0.1.2"
814
+ description = "Markdown URL utilities"
815
+ category = "main"
816
+ optional = false
817
+ python-versions = ">=3.7"
818
+
819
+ [[package]]
820
+ name = "multidict"
821
+ version = "6.0.4"
822
+ description = "multidict implementation"
823
+ category = "main"
824
+ optional = false
825
+ python-versions = ">=3.7"
826
+
827
+ [[package]]
828
+ name = "numpy"
829
+ version = "1.24.2"
830
+ description = "Fundamental package for array computing in Python"
831
+ category = "main"
832
+ optional = false
833
+ python-versions = ">=3.8"
834
+
835
+ [[package]]
836
+ name = "nvidia-cublas-cu11"
837
+ version = "11.10.3.66"
838
+ description = "CUBLAS native runtime libraries"
839
+ category = "main"
840
+ optional = false
841
+ python-versions = ">=3"
842
+
843
+ [[package]]
844
+ name = "nvidia-cuda-nvrtc-cu11"
845
+ version = "11.7.99"
846
+ description = "NVRTC native runtime libraries"
847
+ category = "main"
848
+ optional = false
849
+ python-versions = ">=3"
850
+
851
+ [[package]]
852
+ name = "nvidia-cuda-runtime-cu11"
853
+ version = "11.7.99"
854
+ description = "CUDA Runtime native Libraries"
855
+ category = "main"
856
+ optional = false
857
+ python-versions = ">=3"
858
+
859
+ [[package]]
860
+ name = "nvidia-cudnn-cu11"
861
+ version = "8.5.0.96"
862
+ description = "cuDNN runtime libraries"
863
+ category = "main"
864
+ optional = false
865
+ python-versions = ">=3"
866
+
867
+ [[package]]
868
+ name = "oauthlib"
869
+ version = "3.2.2"
870
+ description = "A generic, spec-compliant, thorough implementation of the OAuth request-signing logic"
871
+ category = "main"
872
+ optional = false
873
+ python-versions = ">=3.6"
874
+
875
+ [package.extras]
876
+ rsa = ["cryptography (>=3.0.0)"]
877
+ signals = ["blinker (>=1.4.0)"]
878
+ signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"]
879
+
880
+ [[package]]
881
+ name = "omegaconf"
882
+ version = "2.3.0"
883
+ description = "A flexible configuration library"
884
+ category = "main"
885
+ optional = false
886
+ python-versions = ">=3.6"
887
+
888
+ [package.dependencies]
889
+ antlr4-python3-runtime = ">=4.9.0,<4.10.0"
890
+ PyYAML = ">=5.1.0"
891
+
892
+ [[package]]
893
+ name = "opencv-python"
894
+ version = "4.7.0.72"
895
+ description = "Wrapper package for OpenCV python bindings."
896
+ category = "main"
897
+ optional = false
898
+ python-versions = ">=3.6"
899
+
900
+ [package.dependencies]
901
+ numpy = [
902
+ {version = ">=1.21.2", markers = "python_version >= \"3.10\""},
903
+ {version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\""},
904
+ {version = ">=1.22.0", markers = "python_version >= \"3.11\""},
905
+ {version = ">=1.19.3", markers = "python_version >= \"3.6\" and platform_system == \"Linux\" and platform_machine == \"aarch64\" or python_version >= \"3.9\""},
906
+ {version = ">=1.17.0", markers = "python_version >= \"3.7\""},
907
+ {version = ">=1.17.3", markers = "python_version >= \"3.8\""},
908
+ ]
909
+
910
+ [[package]]
911
+ name = "orjson"
912
+ version = "3.8.10"
913
+ description = "Fast, correct Python JSON library supporting dataclasses, datetimes, and numpy"
914
+ category = "main"
915
+ optional = false
916
+ python-versions = ">= 3.7"
917
+
918
+ [[package]]
919
+ name = "packaging"
920
+ version = "23.1"
921
+ description = "Core utilities for Python packages"
922
+ category = "main"
923
+ optional = false
924
+ python-versions = ">=3.7"
925
+
926
+ [[package]]
927
+ name = "pandas"
928
+ version = "2.0.0"
929
+ description = "Powerful data structures for data analysis, time series, and statistics"
930
+ category = "main"
931
+ optional = false
932
+ python-versions = ">=3.8"
933
+
934
+ [package.dependencies]
935
+ numpy = [
936
+ {version = ">=1.21.0", markers = "python_version >= \"3.10\""},
937
+ {version = ">=1.23.2", markers = "python_version >= \"3.11\""},
938
+ ]
939
+ python-dateutil = ">=2.8.2"
940
+ pytz = ">=2020.1"
941
+ tzdata = ">=2022.1"
942
+
943
+ [package.extras]
944
+ all = ["beautifulsoup4 (>=4.9.3)", "bottleneck (>=1.3.2)", "brotlipy (>=0.7.0)", "fastparquet (>=0.6.3)", "fsspec (>=2021.07.0)", "gcsfs (>=2021.07.0)", "html5lib (>=1.1)", "hypothesis (>=6.34.2)", "jinja2 (>=3.0.0)", "lxml (>=4.6.3)", "matplotlib (>=3.6.1)", "numba (>=0.53.1)", "numexpr (>=2.7.3)", "odfpy (>=1.4.1)", "openpyxl (>=3.0.7)", "pandas-gbq (>=0.15.0)", "psycopg2 (>=2.8.6)", "pyarrow (>=7.0.0)", "pymysql (>=1.0.2)", "PyQt5 (>=5.15.1)", "pyreadstat (>=1.1.2)", "pytest (>=7.0.0)", "pytest-xdist (>=2.2.0)", "pytest-asyncio (>=0.17.0)", "python-snappy (>=0.6.0)", "pyxlsb (>=1.0.8)", "qtpy (>=2.2.0)", "scipy (>=1.7.1)", "s3fs (>=2021.08.0)", "SQLAlchemy (>=1.4.16)", "tables (>=3.6.1)", "tabulate (>=0.8.9)", "xarray (>=0.21.0)", "xlrd (>=2.0.1)", "xlsxwriter (>=1.4.3)", "zstandard (>=0.15.2)"]
945
+ aws = ["s3fs (>=2021.08.0)"]
946
+ clipboard = ["PyQt5 (>=5.15.1)", "qtpy (>=2.2.0)"]
947
+ compression = ["brotlipy (>=0.7.0)", "python-snappy (>=0.6.0)", "zstandard (>=0.15.2)"]
948
+ computation = ["scipy (>=1.7.1)", "xarray (>=0.21.0)"]
949
+ excel = ["odfpy (>=1.4.1)", "openpyxl (>=3.0.7)", "pyxlsb (>=1.0.8)", "xlrd (>=2.0.1)", "xlsxwriter (>=1.4.3)"]
950
+ feather = ["pyarrow (>=7.0.0)"]
951
+ fss = ["fsspec (>=2021.07.0)"]
952
+ gcp = ["gcsfs (>=2021.07.0)", "pandas-gbq (>=0.15.0)"]
953
+ hdf5 = ["tables (>=3.6.1)"]
954
+ html = ["beautifulsoup4 (>=4.9.3)", "html5lib (>=1.1)", "lxml (>=4.6.3)"]
955
+ mysql = ["SQLAlchemy (>=1.4.16)", "pymysql (>=1.0.2)"]
956
+ output_formatting = ["jinja2 (>=3.0.0)", "tabulate (>=0.8.9)"]
957
+ parquet = ["pyarrow (>=7.0.0)"]
958
+ performance = ["bottleneck (>=1.3.2)", "numba (>=0.53.1)", "numexpr (>=2.7.1)"]
959
+ plot = ["matplotlib (>=3.6.1)"]
960
+ postgresql = ["SQLAlchemy (>=1.4.16)", "psycopg2 (>=2.8.6)"]
961
+ spss = ["pyreadstat (>=1.1.2)"]
962
+ sql-other = ["SQLAlchemy (>=1.4.16)"]
963
+ test = ["hypothesis (>=6.34.2)", "pytest (>=7.0.0)", "pytest-xdist (>=2.2.0)", "pytest-asyncio (>=0.17.0)"]
964
+ xml = ["lxml (>=4.6.3)"]
965
+
966
+ [[package]]
967
+ name = "pillow"
968
+ version = "8.4.0"
969
+ description = "Python Imaging Library (Fork)"
970
+ category = "main"
971
+ optional = false
972
+ python-versions = ">=3.6"
973
+
974
+ [[package]]
975
+ name = "plotly"
976
+ version = "5.14.1"
977
+ description = "An open-source, interactive data visualization library for Python"
978
+ category = "main"
979
+ optional = false
980
+ python-versions = ">=3.6"
981
+
982
+ [package.dependencies]
983
+ packaging = "*"
984
+ tenacity = ">=6.2.0"
985
+
986
+ [[package]]
987
+ name = "protobuf"
988
+ version = "3.20.3"
989
+ description = "Protocol Buffers"
990
+ category = "main"
991
+ optional = false
992
+ python-versions = ">=3.7"
993
+
994
+ [[package]]
995
+ name = "pyasn1"
996
+ version = "0.4.8"
997
+ description = "ASN.1 types and codecs"
998
+ category = "main"
999
+ optional = false
1000
+ python-versions = "*"
1001
+
1002
+ [[package]]
1003
+ name = "pyasn1-modules"
1004
+ version = "0.2.8"
1005
+ description = "A collection of ASN.1-based protocols modules."
1006
+ category = "main"
1007
+ optional = false
1008
+ python-versions = "*"
1009
+
1010
+ [package.dependencies]
1011
+ pyasn1 = ">=0.4.6,<0.5.0"
1012
+
1013
+ [[package]]
1014
+ name = "pydantic"
1015
+ version = "1.10.7"
1016
+ description = "Data validation and settings management using python type hints"
1017
+ category = "main"
1018
+ optional = false
1019
+ python-versions = ">=3.7"
1020
+
1021
+ [package.dependencies]
1022
+ typing-extensions = ">=4.2.0"
1023
+
1024
+ [package.extras]
1025
+ dotenv = ["python-dotenv (>=0.10.4)"]
1026
+ email = ["email-validator (>=1.0.3)"]
1027
+
1028
+ [[package]]
1029
+ name = "pydub"
1030
+ version = "0.25.1"
1031
+ description = "Manipulate audio with an simple and easy high level interface"
1032
+ category = "main"
1033
+ optional = false
1034
+ python-versions = "*"
1035
+
1036
+ [[package]]
1037
+ name = "pyparsing"
1038
+ version = "3.0.9"
1039
+ description = "pyparsing module - Classes and methods to define and execute parsing grammars"
1040
+ category = "main"
1041
+ optional = false
1042
+ python-versions = ">=3.6.8"
1043
+
1044
+ [package.extras]
1045
+ diagrams = ["railroad-diagrams", "jinja2"]
1046
+
1047
+ [[package]]
1048
+ name = "pyrsistent"
1049
+ version = "0.19.3"
1050
+ description = "Persistent/Functional/Immutable data structures"
1051
+ category = "main"
1052
+ optional = false
1053
+ python-versions = ">=3.7"
1054
+
1055
+ [[package]]
1056
+ name = "python-dateutil"
1057
+ version = "2.8.2"
1058
+ description = "Extensions to the standard Python datetime module"
1059
+ category = "main"
1060
+ optional = false
1061
+ python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
1062
+
1063
+ [package.dependencies]
1064
+ six = ">=1.5"
1065
+
1066
+ [[package]]
1067
+ name = "python-multipart"
1068
+ version = "0.0.6"
1069
+ description = "A streaming multipart parser for Python"
1070
+ category = "main"
1071
+ optional = false
1072
+ python-versions = ">=3.7"
1073
+
1074
+ [package.extras]
1075
+ dev = ["atomicwrites (==1.2.1)", "attrs (==19.2.0)", "coverage (==6.5.0)", "hatch", "invoke (==1.7.3)", "more-itertools (==4.3.0)", "pbr (==4.3.0)", "pluggy (==1.0.0)", "py (==1.11.0)", "pytest-cov (==4.0.0)", "pytest-timeout (==2.1.0)", "pytest (==7.2.0)", "pyyaml (==5.1)"]
1076
+
1077
+ [[package]]
1078
+ name = "pytorch-lightning"
1079
+ version = "1.9.0"
1080
+ description = "PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers. Scale your models. Write less boilerplate."
1081
+ category = "main"
1082
+ optional = false
1083
+ python-versions = ">=3.7"
1084
+
1085
+ [package.dependencies]
1086
+ fsspec = {version = ">2021.06.0", extras = ["http"]}
1087
+ lightning-utilities = ">=0.4.2"
1088
+ numpy = ">=1.17.2"
1089
+ packaging = ">=17.1"
1090
+ PyYAML = ">=5.4"
1091
+ torch = ">=1.10.0"
1092
+ torchmetrics = ">=0.7.0"
1093
+ tqdm = ">=4.57.0"
1094
+ typing-extensions = ">=4.0.0"
1095
+
1096
+ [package.extras]
1097
+ all = ["matplotlib (>3.1)", "omegaconf (>=2.0.5)", "hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.18.0)", "rich (>=10.14.0,!=10.15.0.a)", "tensorboardX (>=2.2)", "fairscale (>=0.4.5)", "deepspeed (>=0.6.0)", "horovod (>=0.21.2,!=0.24.0)", "torchvision (>=0.11.1)", "gym[classic_control] (>=0.17.0)", "ipython[all] (<8.7.1)", "hivemind (==1.1.5)"]
1098
+ deepspeed = ["deepspeed (>=0.6.0)"]
1099
+ dev = ["matplotlib (>3.1)", "omegaconf (>=2.0.5)", "hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.18.0)", "rich (>=10.14.0,!=10.15.0.a)", "tensorboardX (>=2.2)", "fairscale (>=0.4.5)", "deepspeed (>=0.6.0)", "horovod (>=0.21.2,!=0.24.0)", "torchvision (>=0.11.1)", "gym[classic_control] (>=0.17.0)", "ipython[all] (<8.7.1)", "coverage (==6.5.0)", "codecov (==2.1.12)", "pytest (==7.2.0)", "pytest-cov (==4.0.0)", "pytest-forked (==1.4.0)", "pytest-rerunfailures (==10.3)", "pre-commit (==2.20.0)", "cloudpickle (>=1.3)", "scikit-learn (>0.22.1)", "onnxruntime (<1.14.0)", "psutil (<5.9.5)", "pandas (>1.0)", "fastapi (<0.87.0)", "uvicorn (<0.19.1)", "tensorboard (>=2.9.1)", "protobuf (<=3.20.1)", "hivemind (==1.1.5)"]
1100
+ examples = ["torchvision (>=0.11.1)", "gym[classic_control] (>=0.17.0)", "ipython[all] (<8.7.1)"]
1101
+ extra = ["matplotlib (>3.1)", "omegaconf (>=2.0.5)", "hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.18.0)", "rich (>=10.14.0,!=10.15.0.a)", "tensorboardX (>=2.2)"]
1102
+ fairscale = ["fairscale (>=0.4.5)"]
1103
+ hivemind = ["hivemind (==1.1.5)"]
1104
+ horovod = ["horovod (>=0.21.2,!=0.24.0)"]
1105
+ strategies = ["fairscale (>=0.4.5)", "deepspeed (>=0.6.0)", "horovod (>=0.21.2,!=0.24.0)", "hivemind (==1.1.5)"]
1106
+ test = ["coverage (==6.5.0)", "codecov (==2.1.12)", "pytest (==7.2.0)", "pytest-cov (==4.0.0)", "pytest-forked (==1.4.0)", "pytest-rerunfailures (==10.3)", "pre-commit (==2.20.0)", "cloudpickle (>=1.3)", "scikit-learn (>0.22.1)", "onnxruntime (<1.14.0)", "psutil (<5.9.5)", "pandas (>1.0)", "fastapi (<0.87.0)", "uvicorn (<0.19.1)", "tensorboard (>=2.9.1)", "protobuf (<=3.20.1)"]
1107
+
1108
+ [[package]]
1109
+ name = "pytz"
1110
+ version = "2023.3"
1111
+ description = "World timezone definitions, modern and historical"
1112
+ category = "main"
1113
+ optional = false
1114
+ python-versions = "*"
1115
+
1116
+ [[package]]
1117
+ name = "pyyaml"
1118
+ version = "6.0"
1119
+ description = "YAML parser and emitter for Python"
1120
+ category = "main"
1121
+ optional = false
1122
+ python-versions = ">=3.6"
1123
+
1124
+ [[package]]
1125
+ name = "requests"
1126
+ version = "2.28.2"
1127
+ description = "Python HTTP for Humans."
1128
+ category = "main"
1129
+ optional = false
1130
+ python-versions = ">=3.7, <4"
1131
+
1132
+ [package.dependencies]
1133
+ certifi = ">=2017.4.17"
1134
+ charset-normalizer = ">=2,<4"
1135
+ idna = ">=2.5,<4"
1136
+ urllib3 = ">=1.21.1,<1.27"
1137
+
1138
+ [package.extras]
1139
+ socks = ["PySocks (>=1.5.6,!=1.5.7)"]
1140
+ use_chardet_on_py3 = ["chardet (>=3.0.2,<6)"]
1141
+
1142
+ [[package]]
1143
+ name = "requests-oauthlib"
1144
+ version = "1.3.1"
1145
+ description = "OAuthlib authentication support for Requests."
1146
+ category = "main"
1147
+ optional = false
1148
+ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
1149
+
1150
+ [package.dependencies]
1151
+ oauthlib = ">=3.0.0"
1152
+ requests = ">=2.0.0"
1153
+
1154
+ [package.extras]
1155
+ rsa = ["oauthlib[signedtoken] (>=3.0.0)"]
1156
+
1157
+ [[package]]
1158
+ name = "rsa"
1159
+ version = "4.9"
1160
+ description = "Pure-Python RSA implementation"
1161
+ category = "main"
1162
+ optional = false
1163
+ python-versions = ">=3.6,<4"
1164
+
1165
+ [package.dependencies]
1166
+ pyasn1 = ">=0.1.3"
1167
+
1168
+ [[package]]
1169
+ name = "scipy"
1170
+ version = "1.10.1"
1171
+ description = "Fundamental algorithms for scientific computing in Python"
1172
+ category = "main"
1173
+ optional = false
1174
+ python-versions = "<3.12,>=3.8"
1175
+
1176
+ [package.dependencies]
1177
+ numpy = ">=1.19.5,<1.27.0"
1178
+
1179
+ [package.extras]
1180
+ test = ["pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "asv", "mpmath", "gmpy2", "threadpoolctl", "scikit-umfpack", "pooch"]
1181
+ doc = ["sphinx (!=4.1.0)", "pydata-sphinx-theme (==0.9.0)", "sphinx-design (>=0.2.0)", "matplotlib (>2)", "numpydoc"]
1182
+ dev = ["mypy", "typing-extensions", "pycodestyle", "flake8", "rich-click", "click", "doit (>=0.36.0)", "pydevtool"]
1183
+
1184
+ [[package]]
1185
+ name = "seaborn"
1186
+ version = "0.12.2"
1187
+ description = "Statistical data visualization"
1188
+ category = "main"
1189
+ optional = false
1190
+ python-versions = ">=3.7"
1191
+
1192
+ [package.dependencies]
1193
+ matplotlib = ">=3.1,<3.6.1 || >3.6.1"
1194
+ numpy = ">=1.17,<1.24.0 || >1.24.0"
1195
+ pandas = ">=0.25"
1196
+
1197
+ [package.extras]
1198
+ dev = ["pytest", "pytest-cov", "pytest-xdist", "flake8", "mypy", "pandas-stubs", "pre-commit", "flit"]
1199
+ docs = ["numpydoc", "nbconvert", "ipykernel", "sphinx-copybutton", "sphinx-issues", "sphinx-design", "pyyaml", "pydata_sphinx_theme (==0.10.0rc2)"]
1200
+ stats = ["scipy (>=1.3)", "statsmodels (>=0.10)"]
1201
+
1202
+ [[package]]
1203
+ name = "semantic-version"
1204
+ version = "2.10.0"
1205
+ description = "A library implementing the 'SemVer' scheme."
1206
+ category = "main"
1207
+ optional = false
1208
+ python-versions = ">=2.7"
1209
+
1210
+ [package.extras]
1211
+ dev = ["Django (>=1.11)", "nose2", "tox", "check-manifest", "coverage", "flake8", "wheel", "zest.releaser", "readme-renderer (<25.0)", "colorama (<=0.4.1)"]
1212
+ doc = ["sphinx", "sphinx-rtd-theme"]
1213
+
1214
+ [[package]]
1215
+ name = "setuptools-scm"
1216
+ version = "7.1.0"
1217
+ description = "the blessed package to manage your versions by scm tags"
1218
+ category = "main"
1219
+ optional = false
1220
+ python-versions = ">=3.7"
1221
+
1222
+ [package.dependencies]
1223
+ packaging = ">=20.0"
1224
+ tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""}
1225
+ typing-extensions = "*"
1226
+
1227
+ [package.extras]
1228
+ test = ["pytest (>=6.2)", "virtualenv (>20)"]
1229
+ toml = ["setuptools (>=42)"]
1230
+
1231
+ [[package]]
1232
+ name = "six"
1233
+ version = "1.16.0"
1234
+ description = "Python 2 and 3 compatibility utilities"
1235
+ category = "main"
1236
+ optional = false
1237
+ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*"
1238
+
1239
+ [[package]]
1240
+ name = "sniffio"
1241
+ version = "1.3.0"
1242
+ description = "Sniff out which async library your code is running under"
1243
+ category = "main"
1244
+ optional = false
1245
+ python-versions = ">=3.7"
1246
+
1247
+ [[package]]
1248
+ name = "starlette"
1249
+ version = "0.26.1"
1250
+ description = "The little ASGI library that shines."
1251
+ category = "main"
1252
+ optional = false
1253
+ python-versions = ">=3.7"
1254
+
1255
+ [package.dependencies]
1256
+ anyio = ">=3.4.0,<5"
1257
+
1258
+ [package.extras]
1259
+ full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart", "pyyaml"]
1260
+
1261
+ [[package]]
1262
+ name = "tenacity"
1263
+ version = "8.2.2"
1264
+ description = "Retry code until it succeeds"
1265
+ category = "main"
1266
+ optional = false
1267
+ python-versions = ">=3.6"
1268
+
1269
+ [package.extras]
1270
+ doc = ["reno", "sphinx", "tornado (>=4.5)"]
1271
+
1272
+ [[package]]
1273
+ name = "tensorboard"
1274
+ version = "2.11.2"
1275
+ description = "TensorBoard lets you watch Tensors Flow"
1276
+ category = "main"
1277
+ optional = false
1278
+ python-versions = ">=3.7"
1279
+
1280
+ [package.dependencies]
1281
+ absl-py = ">=0.4"
1282
+ google-auth = ">=1.6.3,<3"
1283
+ google-auth-oauthlib = ">=0.4.1,<0.5"
1284
+ grpcio = ">=1.24.3"
1285
+ markdown = ">=2.6.8"
1286
+ numpy = ">=1.12.0"
1287
+ protobuf = ">=3.9.2,<4"
1288
+ requests = ">=2.21.0,<3"
1289
+ tensorboard-data-server = ">=0.6.0,<0.7.0"
1290
+ tensorboard-plugin-wit = ">=1.6.0"
1291
+ werkzeug = ">=1.0.1"
1292
+
1293
+ [[package]]
1294
+ name = "tensorboard-data-server"
1295
+ version = "0.6.1"
1296
+ description = "Fast data loading for TensorBoard"
1297
+ category = "main"
1298
+ optional = false
1299
+ python-versions = ">=3.6"
1300
+
1301
+ [[package]]
1302
+ name = "tensorboard-plugin-wit"
1303
+ version = "1.8.1"
1304
+ description = "What-If Tool TensorBoard plugin."
1305
+ category = "main"
1306
+ optional = false
1307
+ python-versions = "*"
1308
+
1309
+ [[package]]
1310
+ name = "tomli"
1311
+ version = "2.0.1"
1312
+ description = "A lil' TOML parser"
1313
+ category = "main"
1314
+ optional = false
1315
+ python-versions = ">=3.7"
1316
+
1317
+ [[package]]
1318
+ name = "toolz"
1319
+ version = "0.12.0"
1320
+ description = "List processing tools and functional utilities"
1321
+ category = "main"
1322
+ optional = false
1323
+ python-versions = ">=3.5"
1324
+
1325
+ [[package]]
1326
+ name = "torch"
1327
+ version = "1.13.1"
1328
+ description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
1329
+ category = "main"
1330
+ optional = false
1331
+ python-versions = ">=3.7.0"
1332
+
1333
+ [package.dependencies]
1334
+ nvidia-cublas-cu11 = {version = "11.10.3.66", markers = "platform_system == \"Linux\""}
1335
+ nvidia-cuda-nvrtc-cu11 = {version = "11.7.99", markers = "platform_system == \"Linux\""}
1336
+ nvidia-cuda-runtime-cu11 = {version = "11.7.99", markers = "platform_system == \"Linux\""}
1337
+ nvidia-cudnn-cu11 = {version = "8.5.0.96", markers = "platform_system == \"Linux\""}
1338
+ typing-extensions = "*"
1339
+
1340
+ [package.extras]
1341
+ opt-einsum = ["opt-einsum (>=3.3)"]
1342
+
1343
+ [[package]]
1344
+ name = "torchmetrics"
1345
+ version = "0.11.0"
1346
+ description = "PyTorch native Metrics"
1347
+ category = "main"
1348
+ optional = false
1349
+ python-versions = ">=3.7"
1350
+
1351
+ [package.dependencies]
1352
+ numpy = ">=1.17.2"
1353
+ packaging = "*"
1354
+ torch = ">=1.8.1"
1355
+
1356
+ [package.extras]
1357
+ all = ["pystoi", "torchvision (>=0.8)", "pycocotools", "scipy", "torch-fidelity", "torchvision", "lpips", "pytorch-lightning (>=1.5)", "transformers (>=4.10.0)", "regex (>=2021.9.24)", "nltk (>=3.6)", "tqdm (>=4.41.0)"]
1358
+ audio = ["pystoi"]
1359
+ detection = ["torchvision (>=0.8)", "pycocotools"]
1360
+ docs = ["sphinx-autodoc-typehints (>=1.0)", "nbsphinx (>=0.8)", "docutils (>=0.16)", "sphinx-togglebutton (>=0.2)", "pandoc (>=1.0)", "myst-parser", "sphinx-paramlinks (>=0.5.1)", "sphinxcontrib-fulltoc (>=1.0)", "sphinxcontrib-mockautodoc", "sphinx-copybutton (>=0.3)", "sphinx (>=4.0,<5.0)"]
1361
+ image = ["scipy", "torch-fidelity", "torchvision", "lpips"]
1362
+ integrate = ["pytorch-lightning (>=1.5)"]
1363
+ multimodal = ["transformers (>=4.10.0)"]
1364
+ test = ["types-protobuf", "rouge-score (>=0.0.4)", "bert-score (==0.3.10)", "requests", "mir-eval (>=0.6)", "jiwer (>=2.3.0)", "scikit-learn (>1.0,<1.1.1)", "check-manifest", "types-tabulate", "pytest-timeout", "types-emoji", "pycocotools", "coverage (>5.2)", "pytest (>=6.0.0,<7.0.0)", "types-six", "kornia (>=0.6.7)", "phmdoctest (>=1.1.1)", "pandas", "pytest-cov (>2.10)", "cloudpickle (>=1.3)", "pre-commit (>=1.0)", "scipy", "psutil", "mypy (==0.982)", "types-requests", "pytest-rerunfailures (>=10.0)", "types-pyyaml", "types-setuptools", "sacrebleu (>=2.0.0)", "netcal", "pytorch-msssim (==0.2.1)", "transformers (>4.4.0)", "fast-bss-eval (>=0.1.0)", "fire", "scikit-image (>0.17.1)", "dython", "torch-complex", "pytest-doctestplus (>=0.9.0)", "huggingface-hub (<0.7)", "pypesq (>1.2)"]
1365
+ text = ["regex (>=2021.9.24)", "nltk (>=3.6)", "tqdm (>=4.41.0)"]
1366
+
1367
+ [[package]]
1368
+ name = "torchvision"
1369
+ version = "0.14.1"
1370
+ description = "image and video datasets and models for torch deep learning"
1371
+ category = "main"
1372
+ optional = false
1373
+ python-versions = ">=3.7"
1374
+
1375
+ [package.dependencies]
1376
+ numpy = "*"
1377
+ pillow = ">=5.3.0,<8.3.0 || >=8.4.0"
1378
+ requests = "*"
1379
+ torch = "1.13.1"
1380
+ typing-extensions = "*"
1381
+
1382
+ [package.extras]
1383
+ scipy = ["scipy"]
1384
+
1385
+ [[package]]
1386
+ name = "tqdm"
1387
+ version = "4.65.0"
1388
+ description = "Fast, Extensible Progress Meter"
1389
+ category = "main"
1390
+ optional = false
1391
+ python-versions = ">=3.7"
1392
+
1393
+ [package.dependencies]
1394
+ colorama = {version = "*", markers = "platform_system == \"Windows\""}
1395
+
1396
+ [package.extras]
1397
+ dev = ["py-make (>=0.1.0)", "twine", "wheel"]
1398
+ notebook = ["ipywidgets (>=6)"]
1399
+ slack = ["slack-sdk"]
1400
+ telegram = ["requests"]
1401
+
1402
+ [[package]]
1403
+ name = "typing-extensions"
1404
+ version = "4.5.0"
1405
+ description = "Backported and Experimental Type Hints for Python 3.7+"
1406
+ category = "main"
1407
+ optional = false
1408
+ python-versions = ">=3.7"
1409
+
1410
+ [[package]]
1411
+ name = "tzdata"
1412
+ version = "2023.3"
1413
+ description = "Provider of IANA time zone data"
1414
+ category = "main"
1415
+ optional = false
1416
+ python-versions = ">=2"
1417
+
1418
+ [[package]]
1419
+ name = "uc-micro-py"
1420
+ version = "1.0.1"
1421
+ description = "Micro subset of unicode data files for linkify-it-py projects."
1422
+ category = "main"
1423
+ optional = false
1424
+ python-versions = ">=3.6"
1425
+
1426
+ [package.extras]
1427
+ test = ["coverage", "pytest", "pytest-cov"]
1428
+
1429
+ [[package]]
1430
+ name = "uritemplate"
1431
+ version = "4.1.1"
1432
+ description = "Implementation of RFC 6570 URI Templates"
1433
+ category = "main"
1434
+ optional = false
1435
+ python-versions = ">=3.6"
1436
+
1437
+ [[package]]
1438
+ name = "urllib3"
1439
+ version = "1.26.15"
1440
+ description = "HTTP library with thread-safe connection pooling, file post, and more."
1441
+ category = "main"
1442
+ optional = false
1443
+ python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*"
1444
+
1445
+ [package.extras]
1446
+ brotli = ["brotlicffi (>=0.8.0)", "brotli (>=1.0.9)", "brotlipy (>=0.6.0)"]
1447
+ secure = ["pyOpenSSL (>=0.14)", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "certifi", "urllib3-secure-extra", "ipaddress"]
1448
+ socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"]
1449
+
1450
+ [[package]]
1451
+ name = "uvicorn"
1452
+ version = "0.21.1"
1453
+ description = "The lightning-fast ASGI server."
1454
+ category = "main"
1455
+ optional = false
1456
+ python-versions = ">=3.7"
1457
+
1458
+ [package.dependencies]
1459
+ click = ">=7.0"
1460
+ h11 = ">=0.8"
1461
+
1462
+ [package.extras]
1463
+ standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.14.0,!=0.15.0,!=0.15.1)", "watchfiles (>=0.13)", "websockets (>=10.4)"]
1464
+
1465
+ [[package]]
1466
+ name = "websockets"
1467
+ version = "11.0.1"
1468
+ description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)"
1469
+ category = "main"
1470
+ optional = false
1471
+ python-versions = ">=3.7"
1472
+
1473
+ [[package]]
1474
+ name = "werkzeug"
1475
+ version = "2.2.3"
1476
+ description = "The comprehensive WSGI web application library."
1477
+ category = "main"
1478
+ optional = false
1479
+ python-versions = ">=3.7"
1480
+
1481
+ [package.dependencies]
1482
+ MarkupSafe = ">=2.1.1"
1483
+
1484
+ [package.extras]
1485
+ watchdog = ["watchdog"]
1486
+
1487
+ [[package]]
1488
+ name = "wget"
1489
+ version = "3.2"
1490
+ description = "pure python download utility"
1491
+ category = "main"
1492
+ optional = false
1493
+ python-versions = "*"
1494
+
1495
+ [[package]]
1496
+ name = "yarl"
1497
+ version = "1.8.2"
1498
+ description = "Yet another URL library"
1499
+ category = "main"
1500
+ optional = false
1501
+ python-versions = ">=3.7"
1502
+
1503
+ [package.dependencies]
1504
+ idna = ">=2.0"
1505
+ multidict = ">=4.0"
1506
+
1507
+ [metadata]
1508
+ lock-version = "1.1"
1509
+ python-versions = ">=3.10,<3.12"
1510
+ content-hash = "17cec1f61fed3b070c0b744eeecc9dbaed1ea06d758238ac84f108545ab14a21"
1511
+
1512
+ [metadata.files]
1513
+ absl-py = []
1514
+ aiofiles = []
1515
+ aiohttp = []
1516
+ aiosignal = []
1517
+ altair = []
1518
+ antlr4-python3-runtime = []
1519
+ anyio = []
1520
+ async-timeout = []
1521
+ attrs = []
1522
+ cachetools = []
1523
+ certifi = []
1524
+ charset-normalizer = []
1525
+ click = []
1526
+ colorama = []
1527
+ contourpy = []
1528
+ cycler = []
1529
+ earthengine-api = []
1530
+ ee-extra = []
1531
+ entrypoints = []
1532
+ fastapi = []
1533
+ ffmpy = []
1534
+ filelock = []
1535
+ fonttools = []
1536
+ frozenlist = []
1537
+ fsspec = []
1538
+ google-api-core = []
1539
+ google-api-python-client = []
1540
+ google-auth = []
1541
+ google-auth-httplib2 = []
1542
+ google-auth-oauthlib = []
1543
+ google-cloud-core = []
1544
+ google-cloud-storage = []
1545
+ google-crc32c = []
1546
+ google-resumable-media = []
1547
+ googleapis-common-protos = []
1548
+ gradio = []
1549
+ gradio-client = []
1550
+ grpcio = []
1551
+ h11 = []
1552
+ httpcore = []
1553
+ httplib2 = []
1554
+ httpx = []
1555
+ huggingface-hub = []
1556
+ hydra-client = []
1557
+ hydra-core = []
1558
+ idna = []
1559
+ jinja2 = []
1560
+ jsonschema = []
1561
+ kiwisolver = []
1562
+ lightning-utilities = []
1563
+ linkify-it-py = []
1564
+ markdown = []
1565
+ markdown-it-py = []
1566
+ markupsafe = []
1567
+ matplotlib = []
1568
+ mdit-py-plugins = []
1569
+ mdurl = []
1570
+ multidict = []
1571
+ numpy = []
1572
+ nvidia-cublas-cu11 = []
1573
+ nvidia-cuda-nvrtc-cu11 = []
1574
+ nvidia-cuda-runtime-cu11 = []
1575
+ nvidia-cudnn-cu11 = []
1576
+ oauthlib = []
1577
+ omegaconf = []
1578
+ opencv-python = []
1579
+ orjson = []
1580
+ packaging = []
1581
+ pandas = []
1582
+ pillow = []
1583
+ plotly = []
1584
+ protobuf = []
1585
+ pyasn1 = []
1586
+ pyasn1-modules = []
1587
+ pydantic = []
1588
+ pydub = []
1589
+ pyparsing = []
1590
+ pyrsistent = []
1591
+ python-dateutil = []
1592
+ python-multipart = []
1593
+ pytorch-lightning = []
1594
+ pytz = []
1595
+ pyyaml = []
1596
+ requests = []
1597
+ requests-oauthlib = []
1598
+ rsa = []
1599
+ scipy = []
1600
+ seaborn = []
1601
+ semantic-version = []
1602
+ setuptools-scm = []
1603
+ six = []
1604
+ sniffio = []
1605
+ starlette = []
1606
+ tenacity = []
1607
+ tensorboard = []
1608
+ tensorboard-data-server = []
1609
+ tensorboard-plugin-wit = []
1610
+ tomli = []
1611
+ toolz = []
1612
+ torch = []
1613
+ torchmetrics = []
1614
+ torchvision = []
1615
+ tqdm = []
1616
+ typing-extensions = []
1617
+ tzdata = []
1618
+ uc-micro-py = []
1619
+ uritemplate = []
1620
+ urllib3 = []
1621
+ uvicorn = []
1622
+ websockets = []
1623
+ werkzeug = []
1624
+ wget = []
1625
+ yarl = []
pyproject.toml ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [tool.poetry]
2
+ name = "cv_app"
3
+ version = "0.1.0"
4
+ description = ""
5
+ authors = ["Your Name <[email protected]>"]
6
+
7
+ [tool.poetry.dependencies]
8
+ python = ">=3.10,<3.12"
9
+ torch = "1.13.1"
10
+ tensorboard = "2.11.2"
11
+ pytorch-lightning = "1.9.0"
12
+ torchmetrics = "0.11.0"
13
+ Pillow = "8.4.0"
14
+ torchvision = "0.14.1"
15
+ matplotlib = "^3.7.1"
16
+ hydra-client = "0.5.1"
17
+ hydra-core = "1.3.1"
18
+ wget = "^3.2"
19
+ scipy = "^1.10.1"
20
+ seaborn = "^0.12.2"
21
+ earthengine-api = "0.1.338"
22
+ ee-extra = "0.0.15"
23
+ gradio = "^3.27.0"
24
+ opencv-python = "^4.7.0"
25
+ plotly = "^5.14.1"
26
+
27
+ [tool.poetry.dev-dependencies]
28
+
29
+ [build-system]
30
+ requires = ["poetry-core>=1.0.0"]
31
+ build-backend = "poetry.core.masonry.api"
requirements.txt ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ aiohttp==3.8.3
3
+ aiosignal==1.3.1
4
+ antlr4-python3-runtime==4.9.3
5
+ appdirs==1.4.4
6
+ argh==0.26.2
7
+ async-timeout==4.0.2
8
+ atomicwrites==1.4.0
9
+ attrs==19.3.0
10
+ backports.weakref==1.0.post1
11
+ bkcharts==0.2
12
+ black==19.10b0
13
+ boto==2.49.0
14
+ bqplot==0.12.36
15
+ branca==0.6.0
16
+ brotlipy==0.7.0
17
+ cachetools==5.3.0
18
+ certifi==2021.10.8
19
+ click==8.0.3
20
+ colour==0.1.5
21
+ comtypes==1.1.10
22
+ cycler==0.10.0
23
+ cytoolz==0.11.0
24
+ daal4py==2021.3.0
25
+ dask==2021.10.0
26
+ earthengine-api==0.1.338
27
+ ee-extra==0.0.15
28
+ eerepr==0.0.4
29
+ entrypoints==0.3
30
+ et-xmlfile==1.1.0
31
+ export==0.2.0
32
+ ffmpeg-python==0.2.0
33
+ folium==0.14.0
34
+ fonttools==4.25.0
35
+ frozenlist==1.3.3
36
+ gdown==4.6.0
37
+ geeadd==0.5.6
38
+ geemap==0.19.6
39
+ geocoder==1.38.1
40
+ geojson==3.0.0
41
+ google-api-core==2.11.0
42
+ google-api-python-client==2.74.0
43
+ google-auth==2.16.0
44
+ google-auth-httplib2==0.1.0
45
+ google-auth-oauthlib==0.4.6
46
+ google-cloud-core==2.3.2
47
+ google-cloud-storage==2.7.0
48
+ google-crc32c==1.5.0
49
+ google-resumable-media==2.4.1
50
+ googleapis-common-protos==1.58.0
51
+ grpcio==1.51.1
52
+ httplib2==0.21.0
53
+ hydra-client==0.5.1
54
+ hydra-core==1.3.1
55
+ inflection==0.5.1
56
+ ipyevents==2.0.1
57
+ ipyfilechooser==0.6.0
58
+ ipyleaflet==0.17.2
59
+ ipytree==0.2.2
60
+ lightning-utilities==0.6.0.post0
61
+ llvmlite==0.37.0
62
+ locket==0.2.1
63
+ logzero==1.7.0
64
+ Markdown==3.4.1
65
+ mccabe==0.6.1
66
+ mkl-fft==1.3.1
67
+ mkl-service==2.4.0
68
+ mpmath==1.2.1
69
+ multidict==6.0.4
70
+ munkres==1.1.4
71
+ mypy-extensions==0.4.3
72
+ nltk==3.6.5
73
+ oauthlib==3.2.2
74
+ omegaconf==2.3.0
75
+ pathspec==0.7.0
76
+ patsy==0.5.2
77
+ pep8==1.7.1
78
+ Pillow==8.4.0
79
+ pkginfo==1.7.1
80
+ plotly==5.13.0
81
+ ply==3.11
82
+ protobuf==3.20.3
83
+ pyasn1==0.4.8
84
+ pyasn1-modules==0.2.8
85
+ pycosat==0.6.3
86
+ PyCRS==1.0.2
87
+ pycurl==7.44.1
88
+ pyls-spyder==0.4.0
89
+ pyperclip==1.8.2
90
+ pyreadline==2.1
91
+ pyshp==2.3.1
92
+ pytest==6.2.4
93
+ python-box==6.1.0
94
+ python-lsp-jsonrpc==1.0.0
95
+ python-lsp-server==1.2.4
96
+ pytorch-lightning==1.9.0
97
+ pytz==2021.3
98
+ PyYAML==6.0
99
+ ratelim==0.1.6
100
+ requests-oauthlib==1.3.1
101
+ rsa==4.9
102
+ sankee==0.2.1
103
+ scikit-image==0.18.3
104
+ scooby==0.7.1
105
+ simplegeneric==0.8.1
106
+ Sphinx==4.2.0
107
+ statsmodels==0.12.2
108
+ tables==3.6.1
109
+ tenacity==8.1.0
110
+ tensorboard==2.11.2
111
+ tensorboard-data-server==0.6.1
112
+ tensorboard-plugin-wit==1.8.1
113
+ terminado==0.9.4
114
+ torch==1.13.1
115
+ torchaudio==0.13.1
116
+ torchmetrics==0.11.0
117
+ torchvision==0.14.1
118
+ traittypes==0.2.1
119
+ typing_extensions==4.4.0
120
+ unicodecsv==0.14.1
121
+ uritemplate==4.1.1
122
+ urllib3==1.26.7
123
+ webencodings==0.5.1
124
+ wget==3.2
125
+ whitebox==2.2.0
126
+ whiteboxgui==2.2.0
127
+ win-unicode-console==0.5
128
+ wincertstore==0.2
129
+ xlwt==1.3.0
130
+ xyzservices==2022.9.0
131
+ yarl==1.8.2
132
+ zict==2.0.0
133
+ zope.event==4.5.0