Spaces:
Runtime error
Runtime error
Commit
·
5c718d1
0
Parent(s):
first commit
Browse files- .gitignore +86 -0
- Dockerfile +15 -0
- README.md +40 -0
- biomap/.gitignore +6 -0
- biomap/.private-key.json +12 -0
- biomap/app.py +110 -0
- biomap/configs/my_train_config.yml +197 -0
- biomap/data.py +584 -0
- biomap/dataset_generator/__init__.py +6 -0
- biomap/dataset_generator/data_loader.py +356 -0
- biomap/dino/utils.py +619 -0
- biomap/dino/vision_transformer.py +314 -0
- biomap/helper.py +179 -0
- biomap/inference.py +261 -0
- biomap/label.png +0 -0
- biomap/model.py +453 -0
- biomap/modules.py +472 -0
- biomap/output/img.png +0 -0
- biomap/output/img_6.png +0 -0
- biomap/output/label.png +0 -0
- biomap/output/labeled_img.png +0 -0
- biomap/plot_functions.py +778 -0
- biomap/train.py +267 -0
- biomap/unet.py +80 -0
- biomap/utils.py +390 -0
- biomap/utils_gee.py +157 -0
- poetry.lock +1625 -0
- pyproject.toml +31 -0
- requirements.txt +133 -0
.gitignore
ADDED
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
|
5 |
+
# C extensions
|
6 |
+
*.so
|
7 |
+
|
8 |
+
# Distribution / packaging
|
9 |
+
.Python
|
10 |
+
env/
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
*.egg-info/
|
23 |
+
.installed.cfg
|
24 |
+
*.egg
|
25 |
+
|
26 |
+
# PyInstaller
|
27 |
+
# Usually these files are written by a python script from a template
|
28 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
29 |
+
*.manifest
|
30 |
+
*.spec
|
31 |
+
|
32 |
+
# Installer logs
|
33 |
+
pip-log.txt
|
34 |
+
pip-delete-this-directory.txt
|
35 |
+
|
36 |
+
# Unit test / coverage reports
|
37 |
+
htmlcov/
|
38 |
+
.tox/
|
39 |
+
.coverage
|
40 |
+
.coverage.*
|
41 |
+
.cache
|
42 |
+
nosetests.xml
|
43 |
+
coverage.xml
|
44 |
+
*.cover
|
45 |
+
|
46 |
+
# Translations
|
47 |
+
*.mo
|
48 |
+
*.pot
|
49 |
+
|
50 |
+
# Django stuff:
|
51 |
+
*.log
|
52 |
+
|
53 |
+
# Sphinx documentation
|
54 |
+
docs/_build/
|
55 |
+
|
56 |
+
# PyBuilder
|
57 |
+
target/
|
58 |
+
|
59 |
+
# DotEnv configuration
|
60 |
+
.env
|
61 |
+
|
62 |
+
# Database
|
63 |
+
*.db
|
64 |
+
*.rdb
|
65 |
+
|
66 |
+
# Pycharm
|
67 |
+
.idea
|
68 |
+
|
69 |
+
# VS Code
|
70 |
+
.vscode/
|
71 |
+
|
72 |
+
# Spyder
|
73 |
+
.spyproject/
|
74 |
+
|
75 |
+
# Jupyter NB Checkpoints
|
76 |
+
.ipynb_checkpoints/
|
77 |
+
|
78 |
+
# Mac OS-specific storage files
|
79 |
+
.DS_Store
|
80 |
+
|
81 |
+
# vim
|
82 |
+
*.swp
|
83 |
+
*.swo
|
84 |
+
|
85 |
+
# Mypy cache
|
86 |
+
.mypy_cache/
|
Dockerfile
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM python:3.9
|
2 |
+
COPY requirements.txt /app/requirements.txt
|
3 |
+
WORKDIR /app
|
4 |
+
RUN pip install seaborn
|
5 |
+
RUN pip install gradio
|
6 |
+
RUN pip install datetime
|
7 |
+
RUN pip install numpy
|
8 |
+
RUN pip install opencv-python
|
9 |
+
RUN apt-get update
|
10 |
+
RUN apt-get install ffmpeg libsm6 libxext6 -y
|
11 |
+
RUN pip install -r requirements.txt
|
12 |
+
COPY . /app
|
13 |
+
# EXPOSE 7860
|
14 |
+
CMD python app.py
|
15 |
+
# hello world
|
README.md
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Welcome to the project inno-satellite-images-segmentation-gan
|
2 |
+

|
3 |
+
|
4 |
+
- **Project name**: inno-satellite-images-segmentation-gan
|
5 |
+
- **Library name**: library
|
6 |
+
- **Authors**: Ekimetrics
|
7 |
+
- **Description**: Segmenting satellite images in a large scale is challenging because grondtruth labels are spurious for medium resolution images (Sentinel 2). We want to improve our algorithm either with data augmentation from a GAN, or to correct or adjust Corine labels.
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
## Project Structure
|
12 |
+
```
|
13 |
+
- library/ # Your python library
|
14 |
+
- data/
|
15 |
+
- raw/
|
16 |
+
- processed/
|
17 |
+
- docs/
|
18 |
+
- tests/ # Where goes each unitary test in your folder
|
19 |
+
- scripts/ # Where each automation script will go
|
20 |
+
- requirements.txt # Where you should put the libraries version used in your library
|
21 |
+
```
|
22 |
+
|
23 |
+
|
24 |
+
## Branch strategy
|
25 |
+
TBD
|
26 |
+
|
27 |
+
|
28 |
+
## Ethics checklist
|
29 |
+
TBD
|
30 |
+
|
31 |
+
|
32 |
+
|
33 |
+
## Starter package
|
34 |
+
This project has been created using the Ekimetrics Python Starter Package to enforce best coding practices, reusability and industrialization. <br>
|
35 |
+
If you have any questions please reach out to the inno team and [Théo Alves Da Costa](mailto:[email protected])
|
36 |
+
|
37 |
+
|
38 |
+
|
39 |
+
|
40 |
+
|
biomap/.gitignore
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#wsl
|
2 |
+
*.Zone.Identifier
|
3 |
+
|
4 |
+
#python
|
5 |
+
*__pycache__
|
6 |
+
|
biomap/.private-key.json
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"type": "service_account",
|
3 |
+
"project_id": "cvimg-377115",
|
4 |
+
"private_key_id": "a162152bd26f4bcc287c44b130109892b5517875",
|
5 |
+
"private_key": "-----BEGIN PRIVATE KEY-----\nMIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDCr/zwOTwyVVdF\nk1cKcdju9jeDWceVJ/zj73b5IS8IbQJbFGWSjaL6Ft1HfJ3rSdGbj+Xy26jY9OFJ\n4rIhpY0M0cCWpYo+gqS9p4JAL6lHqZvSnkRTglpx9QYOT8o9ibCWhMVVAPH71QZ/\n4BEfGC6s2+zdEn+cbCkGIqLLhZTq655kDOaGSycwV/bk+TOLI/An4gjMoEIimhsD\nS6TRmqZnQGoI6m6aj3xPZGVMkid3I37h+BOC64YjKeXnpAhTQ4LbpQz1BxvIKDt6\ncJ1FBOmdSvUC+dLEGMN3yijpJZ74nXUnSVbwYYt3K8Kz2PtmgDYvswER53NlrMoW\n1AF9ImDFAgMBAAECggEAASO7IXWBo90d6CiEdhQF9uFy2S5d9ol9N6EBl+JHOuwB\nwne4tSZJ3jT/Rus7wvX67tXI9dgf3fIgjsv92NLnwEn1Wq/xQppMm8iyK1DKU3vH\n8xrvf8iG048ojGqQXqf0ZEUoWd+/YDGZ2qNuZjmVgwZKwF2h2pcnQ25uIvWdYHrb\n3XhYLDAROVTTtyscYcl8UKmAZ35moVVBQxdakGunYg6o/s6rESRbc+gCyqHR5v+r\nCl3Z4XEKDdukIVI72Ybk0F8eZpQztN97uzK/zm9jl4NmAPXrnWLEwuJdwdm1cWUF\n/LTTuNPmRzCm7IGUpkx0AKEs6s0BRbJbwlZaj4QVJwKBgQDjb2rSO6kRLRHyv+4w\ny/OLmqOrMY7fpSCj0mH41GhiZUaoZqhDznmuhqEjo1kVipfuW94Fn5NsGmWpbmEC\nJlObUEg1umX/ceOJrtRdY3AQMSQXR6u7oc2mTgj3Opd0V1L1Lopj4Ijj43ARg/fU\nu4RnrCGHcXXzT2LCchY0ZhLg3wKBgQDbI6bzt/RNW8+IGKCvLLi41bxM/9r83GNO\nQI4a6yTT09N11okjP9h00JKYBgU3fYivA1aBloFB4kOYaBzomfWSZHEyyFWCr9y0\ndGyIDbfUaI/jFx2CaKomLnPDF5LA3IWHAsTRZ/c1JGhiOUseEq/TR0cJAo69kgf0\nkVmoGjo+2wKBgQCo7crkGJg9P8LDEbgz2mktWlETCR5cE2SpCczna62U2DChSI7W\nvng3H5x0whGbJHQxAV9pwdtYQksci/XWCO20wO7BqY+1KrydOZRXQVKtVDLAb+Wo\n2kfLrM6QA58XNP1TS5xTDyXeTsKg3+qmwhlYf8vvtGCttltenirMBL0k9QKBgFpL\nanNqDOQDPJQbcbo8dzDSAPDJS/Z86P5JY0R8N4SA99TKPV+k4w/fEUhK0sN2mmdi\nvLZQyZnYHXojDCZbqfBUKsB+A54B0LMadc3puSFwpDkyQRqG/fUVluWARRvqwapL\n3cVbTWU8RzaR3P3bPU+VQxPXVfGOxnBjo8m8ZNuZAoGBANTC20T9rZ9Won9FCbi3\nSMkGY59smx19CdytQ2rjGFOEeAVMVotP5viXFuKfv5g2E/kvyJUjuOoulA3dxddN\nQzXnOIT3dlyBmvXkHJHUIKiidyuX4JqQFdPTAmkt6KaTceRNb7VN1OqIk1AJ1SDb\nkGxerLg4WuGfSqOIV0Wk4cLI\n-----END PRIVATE KEY-----\n",
|
6 |
+
"client_email": "[email protected]",
|
7 |
+
"client_id": "115144831673857322488",
|
8 |
+
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
|
9 |
+
"token_uri": "https://oauth2.googleapis.com/token",
|
10 |
+
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
|
11 |
+
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/cvimg-355%40cvimg-377115.iam.gserviceaccount.com"
|
12 |
+
}
|
biomap/app.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from plot_functions import *
|
2 |
+
import hydra
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from model import LitUnsupervisedSegmenter
|
6 |
+
from helper import inference_on_location_and_month, inference_on_location
|
7 |
+
from plot_functions import segment_region
|
8 |
+
|
9 |
+
from functools import partial
|
10 |
+
import gradio as gr
|
11 |
+
import logging
|
12 |
+
|
13 |
+
import geopandas as gpd
|
14 |
+
mapbox_access_token = "pk.eyJ1IjoiamVyZW15LWVraW1ldHJpY3MiLCJhIjoiY2xrNjBwNGU2MDRhMjNqbWw0YTJrbnpvNCJ9.poVyIzhJuJmD6ffrL9lm2w"
|
15 |
+
geo_df = gpd.read_file(gpd.datasets.get_path('naturalearth_cities'))
|
16 |
+
|
17 |
+
def get_geomap(long, lat ):
|
18 |
+
|
19 |
+
|
20 |
+
fig = go.Figure(go.Scattermapbox(
|
21 |
+
lat=geo_df.geometry.y,
|
22 |
+
lon=geo_df.geometry.x,
|
23 |
+
mode='markers',
|
24 |
+
marker=go.scattermapbox.Marker(
|
25 |
+
size=14
|
26 |
+
),
|
27 |
+
text=geo_df.name,
|
28 |
+
))
|
29 |
+
|
30 |
+
fig.add_trace(go.Scattermapbox(lat=[lat],
|
31 |
+
lon=[long],
|
32 |
+
mode='markers',
|
33 |
+
marker=go.scattermapbox.Marker(
|
34 |
+
size=14
|
35 |
+
),
|
36 |
+
marker_color="green",
|
37 |
+
text=['Actual position']))
|
38 |
+
|
39 |
+
fig.update_layout(
|
40 |
+
showlegend=False,
|
41 |
+
hovermode='closest',
|
42 |
+
mapbox=dict(
|
43 |
+
accesstoken=mapbox_access_token,
|
44 |
+
center=go.layout.mapbox.Center(
|
45 |
+
lat=lat,
|
46 |
+
lon=long
|
47 |
+
),
|
48 |
+
zoom=3
|
49 |
+
)
|
50 |
+
)
|
51 |
+
|
52 |
+
return fig
|
53 |
+
|
54 |
+
|
55 |
+
if __name__ == "__main__":
|
56 |
+
|
57 |
+
|
58 |
+
logging.basicConfig(filename='example.log', encoding='utf-8', level=logging.INFO)
|
59 |
+
# Initialize hydra with configs
|
60 |
+
#hydra.initialize(config_path="configs", job_name="corine")
|
61 |
+
cfg = hydra.compose(config_name="my_train_config.yml")
|
62 |
+
logging.info(f"config : {cfg}")
|
63 |
+
# Load the model
|
64 |
+
|
65 |
+
nbclasses = cfg.dir_dataset_n_classes
|
66 |
+
model = LitUnsupervisedSegmenter(nbclasses, cfg)
|
67 |
+
logging.info(f"Model Initialiazed")
|
68 |
+
|
69 |
+
model_path = "checkpoint/model/model.pt"
|
70 |
+
saved_state_dict = torch.load(model_path, map_location=torch.device("cpu"))
|
71 |
+
logging.info(f"Model weights Loaded")
|
72 |
+
model.load_state_dict(saved_state_dict)
|
73 |
+
logging.info(f"Model Loaded")
|
74 |
+
# css=".VIDEO video{height: 100%;width:50%;margin:auto};.VIDEO{height: 50%;};.svelte-1vnmhm4{height:auto}"
|
75 |
+
with gr.Blocks() as demo:
|
76 |
+
gr.Markdown("Estimate Biodiversity in the world.")
|
77 |
+
with gr.Tab("Single Image"):
|
78 |
+
with gr.Row():
|
79 |
+
input_map = gr.Plot().style()
|
80 |
+
with gr.Column():
|
81 |
+
input_latitude = gr.Number(label="lattitude", value=2.98)
|
82 |
+
input_longitude = gr.Number(label="longitude", value=48.81)
|
83 |
+
input_date = gr.Textbox(label="start_date", value="2020-03-20")
|
84 |
+
|
85 |
+
single_button = gr.Button("Predict")
|
86 |
+
with gr.Row():
|
87 |
+
raw_image = gr.Image(label = "Localisation visualization")
|
88 |
+
output_image = gr.Image(label = "Labeled visualisation")
|
89 |
+
score_biodiv = gr.Number(label = "Biodiversity score")
|
90 |
+
|
91 |
+
with gr.Tab("TimeLapse"):
|
92 |
+
with gr.Row():
|
93 |
+
input_map_2 = gr.Plot().style()
|
94 |
+
with gr.Row():
|
95 |
+
timelapse_input_latitude = gr.Number(value=2.98, label="Latitude")
|
96 |
+
timelapse_input_longitude = gr.Number(value=48.81, label="Longitude")
|
97 |
+
timelapse_start_date = gr.Textbox(value='2020-05-01', label="Start Date")
|
98 |
+
timelapse_end_date = gr.Textbox(value='2020-06-30', label="End Date")
|
99 |
+
segmentation = gr.CheckboxGroup(choices=['month', 'year', '2months'], value=['month'], label="Select Segmentation Level:")
|
100 |
+
timelapse_button = gr.Button(value="Predict")
|
101 |
+
map = gr.Plot().style()
|
102 |
+
|
103 |
+
demo.load(get_geomap, [input_latitude, input_longitude], input_map)
|
104 |
+
single_button.click(get_geomap, [input_latitude, input_longitude], input_map)
|
105 |
+
single_button.click(partial(inference_on_location_and_month, model), inputs=[input_latitude, input_longitude, input_date], outputs=[raw_image, output_image,score_biodiv])
|
106 |
+
|
107 |
+
demo.load(get_geomap, [timelapse_input_latitude, timelapse_input_longitude], input_map_2)
|
108 |
+
timelapse_button.click(get_geomap, [timelapse_input_latitude, timelapse_input_longitude], input_map_2)
|
109 |
+
timelapse_button.click(segment_region, inputs=[timelapse_input_latitude, timelapse_input_longitude, timelapse_start_date, timelapse_end_date,segmentation], outputs=[map])
|
110 |
+
demo.launch(share=True)
|
biomap/configs/my_train_config.yml
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
output_root: '../'
|
2 |
+
pytorch_data_dir: '/home/duong_nguyen/pytorch-data'
|
3 |
+
experiment_name: "unet_7classes"
|
4 |
+
log_dir: "france"
|
5 |
+
# experiment_name: "unet"
|
6 |
+
# log_dir: "potsdam"
|
7 |
+
azureml_logging: False
|
8 |
+
submitting_to_aml: False
|
9 |
+
full_name: ~
|
10 |
+
|
11 |
+
# Loader params
|
12 |
+
num_workers: 24
|
13 |
+
max_steps: 80000
|
14 |
+
batch_size: 16
|
15 |
+
|
16 |
+
num_neighbors: 7
|
17 |
+
dataset_name: "directory"
|
18 |
+
# dataset_name: "potsdam"
|
19 |
+
|
20 |
+
# Used if dataset_name is "directory"
|
21 |
+
dir_dataset_name: "corine"
|
22 |
+
dir_dataset_n_classes: 7
|
23 |
+
|
24 |
+
has_labels: False
|
25 |
+
# crop_type: "five"
|
26 |
+
crop_type: ~
|
27 |
+
crop_ratio: .5
|
28 |
+
res: 224
|
29 |
+
loader_crop_type: "center"
|
30 |
+
|
31 |
+
# Model Params
|
32 |
+
extra_clusters: 0
|
33 |
+
use_true_labels: False
|
34 |
+
use_recalibrator: False
|
35 |
+
model_type: "vit_small"
|
36 |
+
arch: "dino"
|
37 |
+
use_fit_model: False
|
38 |
+
dino_feat_type: "feat"
|
39 |
+
projection_type: "nonlinear"
|
40 |
+
#projection_type: linear
|
41 |
+
dino_patch_size: 8
|
42 |
+
granularity: 1
|
43 |
+
continuous: True
|
44 |
+
dim: 70
|
45 |
+
dropout: True
|
46 |
+
zero_clamp: True
|
47 |
+
|
48 |
+
lr: 5e-4
|
49 |
+
pretrained_weights: ~
|
50 |
+
use_salience: False
|
51 |
+
stabalize: False
|
52 |
+
stop_at_zero: True
|
53 |
+
|
54 |
+
# Feature Contrastive params
|
55 |
+
pointwise: True
|
56 |
+
feature_samples: 11
|
57 |
+
neg_samples: 5
|
58 |
+
aug_alignment_weight: 0.0
|
59 |
+
|
60 |
+
correspondence_weight: 1.0
|
61 |
+
|
62 |
+
|
63 |
+
# # Corine vit small 24/11/22
|
64 |
+
neg_inter_weight: 0.63
|
65 |
+
pos_inter_weight: 0.25
|
66 |
+
pos_intra_weight: 0.67
|
67 |
+
neg_inter_shift: 0.46
|
68 |
+
pos_inter_shift: 0.02
|
69 |
+
pos_intra_shift: 0.08
|
70 |
+
|
71 |
+
# # Corine vit small 11/09/22
|
72 |
+
# neg_inter_weight: 0.63
|
73 |
+
# pos_inter_weight: 0.25
|
74 |
+
# pos_intra_weight: 0.67
|
75 |
+
# neg_inter_shift: 0.46
|
76 |
+
# pos_inter_shift: 0.24
|
77 |
+
# pos_intra_shift: 0.36
|
78 |
+
|
79 |
+
# # IAROA vit small 1/31/22
|
80 |
+
# neg_inter_weight: 0.63
|
81 |
+
# pos_inter_weight: 0.25
|
82 |
+
# pos_intra_weight: 0.67
|
83 |
+
# neg_inter_shift: 0.46
|
84 |
+
# pos_inter_shift: 0.12
|
85 |
+
# pos_intra_shift: 0.18
|
86 |
+
|
87 |
+
# Potsdam vit small 1/31/22
|
88 |
+
# neg_inter_weight: 0.63
|
89 |
+
# pos_inter_weight: 0.25
|
90 |
+
# pos_intra_weight: 0.67
|
91 |
+
# neg_inter_shift: 0.46
|
92 |
+
# pos_inter_shift: 0.02
|
93 |
+
# pos_intra_shift: 0.08
|
94 |
+
|
95 |
+
# Cocostuff27 vit small 1/31/22
|
96 |
+
#neg_inter_weight: 0.63
|
97 |
+
#pos_inter_weight: 0.25
|
98 |
+
#pos_intra_weight: 0.67
|
99 |
+
#neg_inter_shift: 0.66
|
100 |
+
#pos_inter_shift: 0.02
|
101 |
+
#pos_intra_shift: 0.08
|
102 |
+
|
103 |
+
|
104 |
+
## Cocostuff27 10/3 vit_base
|
105 |
+
|
106 |
+
#neg_inter_weight: 0.1538476246415498
|
107 |
+
#pos_inter_weight: 1
|
108 |
+
#pos_intra_weight: 0.1
|
109 |
+
#
|
110 |
+
#neg_inter_shift: 1
|
111 |
+
#pos_inter_shift: 0.2
|
112 |
+
#pos_intra_shift: 0.12
|
113 |
+
|
114 |
+
|
115 |
+
## Cocostuff27 10/3 vit_small
|
116 |
+
#neg_inter_weight: .63
|
117 |
+
#pos_inter_weight: .25
|
118 |
+
#pos_intra_weight: .67
|
119 |
+
#
|
120 |
+
#neg_inter_shift: .16
|
121 |
+
#pos_inter_shift: .02
|
122 |
+
#pos_intra_shift: .08
|
123 |
+
|
124 |
+
|
125 |
+
|
126 |
+
## Cocostuff27 10/3 moco
|
127 |
+
#neg_inter_weight: .63
|
128 |
+
#pos_inter_weight: .25
|
129 |
+
#pos_intra_weight: .67
|
130 |
+
#
|
131 |
+
#neg_inter_shift: .26
|
132 |
+
#pos_inter_shift: .36
|
133 |
+
#pos_intra_shift: .32
|
134 |
+
|
135 |
+
#pos_inter_shift: .12
|
136 |
+
#pos_intra_shift: .18
|
137 |
+
|
138 |
+
## Cocostuff27
|
139 |
+
#neg_inter_weight: .72
|
140 |
+
#pos_inter_weight: .80
|
141 |
+
#pos_intra_weight: .29
|
142 |
+
#
|
143 |
+
#neg_inter_shift: .86
|
144 |
+
#pos_inter_shift: .04
|
145 |
+
#pos_intra_shift: .34
|
146 |
+
|
147 |
+
# Cityscapes 10/3
|
148 |
+
|
149 |
+
# neg_inter_weight: 0.9058762625226623
|
150 |
+
# pos_inter_weight: 0.577453483136995
|
151 |
+
# pos_intra_weight: 1
|
152 |
+
|
153 |
+
# neg_inter_shift: 0.31361241889448443
|
154 |
+
# pos_inter_shift: 0.1754346515479633
|
155 |
+
# pos_intra_shift: 0.45828472207
|
156 |
+
|
157 |
+
|
158 |
+
# Cityscapes
|
159 |
+
#neg_inter_weight: .72
|
160 |
+
#pos_inter_weight: .18
|
161 |
+
#pos_intra_weight: .46
|
162 |
+
#
|
163 |
+
#neg_inter_shift: .25
|
164 |
+
#pos_inter_shift: .20
|
165 |
+
#pos_intra_shift: .25
|
166 |
+
|
167 |
+
|
168 |
+
rec_weight: 0.0
|
169 |
+
repulsion_weight: 0.0
|
170 |
+
|
171 |
+
# CRF Params
|
172 |
+
crf_weight: 0.0
|
173 |
+
alpha: .5
|
174 |
+
beta: .15
|
175 |
+
gamma: .05
|
176 |
+
w1: 10.0
|
177 |
+
w2: 3.0
|
178 |
+
shift: 0.00
|
179 |
+
crf_samples: 1000
|
180 |
+
color_space: "rgb"
|
181 |
+
|
182 |
+
reset_probe_steps: ~
|
183 |
+
|
184 |
+
# Logging params
|
185 |
+
n_images: 5
|
186 |
+
scalar_log_freq: 10
|
187 |
+
checkpoint_freq: 50
|
188 |
+
val_freq: 100
|
189 |
+
hist_freq: 100
|
190 |
+
|
191 |
+
|
192 |
+
hydra:
|
193 |
+
run:
|
194 |
+
dir: "."
|
195 |
+
output_subdir: ~
|
196 |
+
#job_logging: "disabled"
|
197 |
+
#hydra_logging: "disabled"
|
biomap/data.py
ADDED
@@ -0,0 +1,584 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
from os.path import join
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch.multiprocessing
|
7 |
+
from PIL import Image
|
8 |
+
from scipy.io import loadmat
|
9 |
+
from torch.utils.data import DataLoader
|
10 |
+
from torch.utils.data import Dataset
|
11 |
+
from torchvision.datasets.cityscapes import Cityscapes
|
12 |
+
from torchvision.transforms.functional import to_pil_image
|
13 |
+
from tqdm import tqdm
|
14 |
+
|
15 |
+
|
16 |
+
def bit_get(val, idx):
|
17 |
+
"""Gets the bit value.
|
18 |
+
Args:
|
19 |
+
val: Input value, int or numpy int array.
|
20 |
+
idx: Which bit of the input val.
|
21 |
+
Returns:
|
22 |
+
The "idx"-th bit of input val.
|
23 |
+
"""
|
24 |
+
return (val >> idx) & 1
|
25 |
+
|
26 |
+
|
27 |
+
def create_pascal_label_colormap():
|
28 |
+
"""Creates a label colormap used in PASCAL VOC segmentation benchmark.
|
29 |
+
Returns:
|
30 |
+
A colormap for visualizing segmentation results.
|
31 |
+
"""
|
32 |
+
colormap = np.zeros((512, 3), dtype=int)
|
33 |
+
ind = np.arange(512, dtype=int)
|
34 |
+
|
35 |
+
for shift in reversed(list(range(8))):
|
36 |
+
for channel in range(3):
|
37 |
+
colormap[:, channel] |= bit_get(ind, channel) << shift
|
38 |
+
ind >>= 3
|
39 |
+
|
40 |
+
return colormap
|
41 |
+
|
42 |
+
|
43 |
+
def create_cityscapes_colormap():
|
44 |
+
colors = [(128, 64, 128),
|
45 |
+
(244, 35, 232),
|
46 |
+
(250, 170, 160),
|
47 |
+
(230, 150, 140),
|
48 |
+
(70, 70, 70),
|
49 |
+
(102, 102, 156),
|
50 |
+
(190, 153, 153),
|
51 |
+
(180, 165, 180),
|
52 |
+
(150, 100, 100),
|
53 |
+
(150, 120, 90),
|
54 |
+
(153, 153, 153),
|
55 |
+
(153, 153, 153),
|
56 |
+
(250, 170, 30),
|
57 |
+
(220, 220, 0),
|
58 |
+
(107, 142, 35),
|
59 |
+
(152, 251, 152),
|
60 |
+
(70, 130, 180),
|
61 |
+
(220, 20, 60),
|
62 |
+
(255, 0, 0),
|
63 |
+
(0, 0, 142),
|
64 |
+
(0, 0, 70),
|
65 |
+
(0, 60, 100),
|
66 |
+
(0, 0, 90),
|
67 |
+
(0, 0, 110),
|
68 |
+
(0, 80, 100),
|
69 |
+
(0, 0, 230),
|
70 |
+
(119, 11, 32),
|
71 |
+
(0, 0, 0)]
|
72 |
+
return np.array(colors)
|
73 |
+
|
74 |
+
|
75 |
+
class DirectoryDataset(Dataset):
|
76 |
+
def __init__(self, root, path, image_set, transform, target_transform):
|
77 |
+
super(DirectoryDataset, self).__init__()
|
78 |
+
self.split = image_set
|
79 |
+
self.dir = join(root, path)
|
80 |
+
self.img_dir = join(self.dir, "imgs", self.split)
|
81 |
+
self.label_dir = join(self.dir, "labels", self.split)
|
82 |
+
|
83 |
+
self.transform = transform
|
84 |
+
self.target_transform = target_transform
|
85 |
+
|
86 |
+
self.img_files = np.array(sorted(os.listdir(self.img_dir)))
|
87 |
+
assert len(self.img_files) > 0
|
88 |
+
if os.path.exists(join(self.dir, "labels")):
|
89 |
+
self.label_files = np.array(sorted(os.listdir(self.label_dir)))
|
90 |
+
assert len(self.img_files) == len(self.label_files)
|
91 |
+
else:
|
92 |
+
self.label_files = None
|
93 |
+
self.fine_to_coarse = {0: 0,
|
94 |
+
1: 1,
|
95 |
+
2: 2,
|
96 |
+
3: 3,
|
97 |
+
4: 4,
|
98 |
+
5: 5,
|
99 |
+
6: 6,
|
100 |
+
7: -1,
|
101 |
+
}
|
102 |
+
|
103 |
+
def __getitem__(self, index):
|
104 |
+
image_fn = self.img_files[index]
|
105 |
+
img = Image.open(join(self.img_dir, image_fn))
|
106 |
+
|
107 |
+
if self.label_files is not None:
|
108 |
+
label_fn = self.label_files[index]
|
109 |
+
label = Image.open(join(self.label_dir, label_fn))
|
110 |
+
|
111 |
+
seed = np.random.randint(2147483647)
|
112 |
+
random.seed(seed)
|
113 |
+
torch.manual_seed(seed)
|
114 |
+
img = self.transform(img)
|
115 |
+
|
116 |
+
if self.label_files is not None:
|
117 |
+
random.seed(seed)
|
118 |
+
torch.manual_seed(seed)
|
119 |
+
label = self.target_transform(label)
|
120 |
+
new_label_map = torch.zeros_like(label)
|
121 |
+
for fine, coarse in self.fine_to_coarse.items():
|
122 |
+
new_label_map[label == fine] = coarse
|
123 |
+
label = new_label_map
|
124 |
+
else:
|
125 |
+
label = torch.zeros(img.shape[1], img.shape[2], dtype=torch.int64) - 1
|
126 |
+
|
127 |
+
mask = (label > 0).to(torch.float32)
|
128 |
+
return img, label, mask
|
129 |
+
|
130 |
+
|
131 |
+
def __len__(self):
|
132 |
+
return len(self.img_files)
|
133 |
+
|
134 |
+
|
135 |
+
class Potsdam(Dataset):
|
136 |
+
def __init__(self, root, image_set, transform, target_transform, coarse_labels):
|
137 |
+
super(Potsdam, self).__init__()
|
138 |
+
self.split = image_set
|
139 |
+
self.root = os.path.join(root, "potsdam")
|
140 |
+
self.transform = transform
|
141 |
+
self.target_transform = target_transform
|
142 |
+
split_files = {
|
143 |
+
"train": ["labelled_train.txt"],
|
144 |
+
"unlabelled_train": ["unlabelled_train.txt"],
|
145 |
+
# "train": ["unlabelled_train.txt"],
|
146 |
+
"val": ["labelled_test.txt"],
|
147 |
+
"train+val": ["labelled_train.txt", "labelled_test.txt"],
|
148 |
+
"all": ["all.txt"]
|
149 |
+
}
|
150 |
+
assert self.split in split_files.keys()
|
151 |
+
|
152 |
+
self.files = []
|
153 |
+
for split_file in split_files[self.split]:
|
154 |
+
with open(join(self.root, split_file), "r") as f:
|
155 |
+
self.files.extend(fn.rstrip() for fn in f.readlines())
|
156 |
+
|
157 |
+
self.coarse_labels = coarse_labels
|
158 |
+
self.fine_to_coarse = {0: 0, 4: 0, # roads and cars
|
159 |
+
1: 1, 5: 1, # buildings and clutter
|
160 |
+
2: 2, 3: 2, # vegetation and trees
|
161 |
+
255: -1
|
162 |
+
}
|
163 |
+
|
164 |
+
def __getitem__(self, index):
|
165 |
+
image_id = self.files[index]
|
166 |
+
img = loadmat(join(self.root, "imgs", image_id + ".mat"))["img"]
|
167 |
+
img = to_pil_image(torch.from_numpy(img).permute(2, 0, 1)[:3]) # TODO add ir channel back
|
168 |
+
try:
|
169 |
+
label = loadmat(join(self.root, "gt", image_id + ".mat"))["gt"]
|
170 |
+
label = to_pil_image(torch.from_numpy(label).unsqueeze(-1).permute(2, 0, 1))
|
171 |
+
except FileNotFoundError:
|
172 |
+
label = to_pil_image(torch.ones(1, img.height, img.width))
|
173 |
+
|
174 |
+
seed = np.random.randint(2147483647)
|
175 |
+
random.seed(seed)
|
176 |
+
torch.manual_seed(seed)
|
177 |
+
img = self.transform(img)
|
178 |
+
|
179 |
+
random.seed(seed)
|
180 |
+
torch.manual_seed(seed)
|
181 |
+
label = self.target_transform(label).squeeze(0)
|
182 |
+
if self.coarse_labels:
|
183 |
+
new_label_map = torch.zeros_like(label)
|
184 |
+
for fine, coarse in self.fine_to_coarse.items():
|
185 |
+
new_label_map[label == fine] = coarse
|
186 |
+
label = new_label_map
|
187 |
+
|
188 |
+
mask = (label > 0).to(torch.float32)
|
189 |
+
return img, label, mask
|
190 |
+
|
191 |
+
def __len__(self):
|
192 |
+
return len(self.files)
|
193 |
+
|
194 |
+
|
195 |
+
class PotsdamRaw(Dataset):
|
196 |
+
def __init__(self, root, image_set, transform, target_transform, coarse_labels):
|
197 |
+
super(PotsdamRaw, self).__init__()
|
198 |
+
self.split = image_set
|
199 |
+
self.root = os.path.join(root, "potsdamraw", "processed")
|
200 |
+
self.transform = transform
|
201 |
+
self.target_transform = target_transform
|
202 |
+
self.files = []
|
203 |
+
for im_num in range(38):
|
204 |
+
for i_h in range(15):
|
205 |
+
for i_w in range(15):
|
206 |
+
self.files.append("{}_{}_{}.mat".format(im_num, i_h, i_w))
|
207 |
+
|
208 |
+
self.coarse_labels = coarse_labels
|
209 |
+
self.fine_to_coarse = {0: 0, 4: 0, # roads and cars
|
210 |
+
1: 1, 5: 1, # buildings and clutter
|
211 |
+
2: 2, 3: 2, # vegetation and trees
|
212 |
+
255: -1
|
213 |
+
}
|
214 |
+
|
215 |
+
def __getitem__(self, index):
|
216 |
+
image_id = self.files[index]
|
217 |
+
img = loadmat(join(self.root, "imgs", image_id))["img"]
|
218 |
+
img = to_pil_image(torch.from_numpy(img).permute(2, 0, 1)[:3]) # TODO add ir channel back
|
219 |
+
try:
|
220 |
+
label = loadmat(join(self.root, "gt", image_id))["gt"]
|
221 |
+
label = to_pil_image(torch.from_numpy(label).unsqueeze(-1).permute(2, 0, 1))
|
222 |
+
except FileNotFoundError:
|
223 |
+
label = to_pil_image(torch.ones(1, img.height, img.width))
|
224 |
+
|
225 |
+
seed = np.random.randint(2147483647)
|
226 |
+
random.seed(seed)
|
227 |
+
torch.manual_seed(seed)
|
228 |
+
img = self.transform(img)
|
229 |
+
|
230 |
+
random.seed(seed)
|
231 |
+
torch.manual_seed(seed)
|
232 |
+
label = self.target_transform(label).squeeze(0)
|
233 |
+
if self.coarse_labels:
|
234 |
+
new_label_map = torch.zeros_like(label)
|
235 |
+
for fine, coarse in self.fine_to_coarse.items():
|
236 |
+
new_label_map[label == fine] = coarse
|
237 |
+
label = new_label_map
|
238 |
+
|
239 |
+
mask = (label > 0).to(torch.float32)
|
240 |
+
return img, label, mask
|
241 |
+
|
242 |
+
def __len__(self):
|
243 |
+
return len(self.files)
|
244 |
+
|
245 |
+
|
246 |
+
class Coco(Dataset):
|
247 |
+
def __init__(self, root, image_set, transform, target_transform,
|
248 |
+
coarse_labels, exclude_things, subset=None):
|
249 |
+
super(Coco, self).__init__()
|
250 |
+
self.split = image_set
|
251 |
+
self.root = join(root, "cocostuff")
|
252 |
+
self.coarse_labels = coarse_labels
|
253 |
+
self.transform = transform
|
254 |
+
self.label_transform = target_transform
|
255 |
+
self.subset = subset
|
256 |
+
self.exclude_things = exclude_things
|
257 |
+
|
258 |
+
if self.subset is None:
|
259 |
+
self.image_list = "Coco164kFull_Stuff_Coarse.txt"
|
260 |
+
elif self.subset == 6: # IIC Coarse
|
261 |
+
self.image_list = "Coco164kFew_Stuff_6.txt"
|
262 |
+
elif self.subset == 7: # IIC Fine
|
263 |
+
self.image_list = "Coco164kFull_Stuff_Coarse_7.txt"
|
264 |
+
|
265 |
+
assert self.split in ["train", "val", "train+val"]
|
266 |
+
split_dirs = {
|
267 |
+
"train": ["train2017"],
|
268 |
+
"val": ["val2017"],
|
269 |
+
"train+val": ["train2017", "val2017"]
|
270 |
+
}
|
271 |
+
|
272 |
+
self.image_files = []
|
273 |
+
self.label_files = []
|
274 |
+
for split_dir in split_dirs[self.split]:
|
275 |
+
with open(join(self.root, "curated", split_dir, self.image_list), "r") as f:
|
276 |
+
img_ids = [fn.rstrip() for fn in f.readlines()]
|
277 |
+
for img_id in img_ids:
|
278 |
+
self.image_files.append(join(self.root, "images", split_dir, img_id + ".jpg"))
|
279 |
+
self.label_files.append(join(self.root, "annotations", split_dir, img_id + ".png"))
|
280 |
+
|
281 |
+
self.fine_to_coarse = {0: 9, 1: 11, 2: 11, 3: 11, 4: 11, 5: 11, 6: 11, 7: 11, 8: 11, 9: 8, 10: 8, 11: 8, 12: 8,
|
282 |
+
13: 8, 14: 8, 15: 7, 16: 7, 17: 7, 18: 7, 19: 7, 20: 7, 21: 7, 22: 7, 23: 7, 24: 7,
|
283 |
+
25: 6, 26: 6, 27: 6, 28: 6, 29: 6, 30: 6, 31: 6, 32: 6, 33: 10, 34: 10, 35: 10, 36: 10,
|
284 |
+
37: 10, 38: 10, 39: 10, 40: 10, 41: 10, 42: 10, 43: 5, 44: 5, 45: 5, 46: 5, 47: 5, 48: 5,
|
285 |
+
49: 5, 50: 5, 51: 2, 52: 2, 53: 2, 54: 2, 55: 2, 56: 2, 57: 2, 58: 2, 59: 2, 60: 2,
|
286 |
+
61: 3, 62: 3, 63: 3, 64: 3, 65: 3, 66: 3, 67: 3, 68: 3, 69: 3, 70: 3, 71: 0, 72: 0,
|
287 |
+
73: 0, 74: 0, 75: 0, 76: 0, 77: 1, 78: 1, 79: 1, 80: 1, 81: 1, 82: 1, 83: 4, 84: 4,
|
288 |
+
85: 4, 86: 4, 87: 4, 88: 4, 89: 4, 90: 4, 91: 17, 92: 17, 93: 22, 94: 20, 95: 20, 96: 22,
|
289 |
+
97: 15, 98: 25, 99: 16, 100: 13, 101: 12, 102: 12, 103: 17, 104: 17, 105: 23, 106: 15,
|
290 |
+
107: 15, 108: 17, 109: 15, 110: 21, 111: 15, 112: 25, 113: 13, 114: 13, 115: 13, 116: 13,
|
291 |
+
117: 13, 118: 22, 119: 26, 120: 14, 121: 14, 122: 15, 123: 22, 124: 21, 125: 21, 126: 24,
|
292 |
+
127: 20, 128: 22, 129: 15, 130: 17, 131: 16, 132: 15, 133: 22, 134: 24, 135: 21, 136: 17,
|
293 |
+
137: 25, 138: 16, 139: 21, 140: 17, 141: 22, 142: 16, 143: 21, 144: 21, 145: 25, 146: 21,
|
294 |
+
147: 26, 148: 21, 149: 24, 150: 20, 151: 17, 152: 14, 153: 21, 154: 26, 155: 15, 156: 23,
|
295 |
+
157: 20, 158: 21, 159: 24, 160: 15, 161: 24, 162: 22, 163: 25, 164: 15, 165: 20, 166: 17,
|
296 |
+
167: 17, 168: 22, 169: 14, 170: 18, 171: 18, 172: 18, 173: 18, 174: 18, 175: 18, 176: 18,
|
297 |
+
177: 26, 178: 26, 179: 19, 180: 19, 181: 24}
|
298 |
+
|
299 |
+
self._label_names = [
|
300 |
+
"ground-stuff",
|
301 |
+
"plant-stuff",
|
302 |
+
"sky-stuff",
|
303 |
+
]
|
304 |
+
self.cocostuff3_coarse_classes = [23, 22, 21]
|
305 |
+
self.first_stuff_index = 12
|
306 |
+
|
307 |
+
def __getitem__(self, index):
|
308 |
+
image_path = self.image_files[index]
|
309 |
+
label_path = self.label_files[index]
|
310 |
+
seed = np.random.randint(2147483647)
|
311 |
+
random.seed(seed)
|
312 |
+
torch.manual_seed(seed)
|
313 |
+
img = self.transform(Image.open(image_path).convert("RGB"))
|
314 |
+
|
315 |
+
random.seed(seed)
|
316 |
+
torch.manual_seed(seed)
|
317 |
+
label = self.label_transform(Image.open(label_path)).squeeze(0)
|
318 |
+
label[label == 255] = -1 # to be consistent with 10k
|
319 |
+
coarse_label = torch.zeros_like(label)
|
320 |
+
for fine, coarse in self.fine_to_coarse.items():
|
321 |
+
coarse_label[label == fine] = coarse
|
322 |
+
coarse_label[label == -1] = -1
|
323 |
+
|
324 |
+
if self.coarse_labels:
|
325 |
+
coarser_labels = -torch.ones_like(label)
|
326 |
+
for i, c in enumerate(self.cocostuff3_coarse_classes):
|
327 |
+
coarser_labels[coarse_label == c] = i
|
328 |
+
return img, coarser_labels, coarser_labels >= 0
|
329 |
+
else:
|
330 |
+
if self.exclude_things:
|
331 |
+
return img, coarse_label - self.first_stuff_index, (coarse_label >= self.first_stuff_index)
|
332 |
+
else:
|
333 |
+
return img, coarse_label, coarse_label >= 0
|
334 |
+
|
335 |
+
def __len__(self):
|
336 |
+
return len(self.image_files)
|
337 |
+
|
338 |
+
|
339 |
+
class CityscapesSeg(Dataset):
|
340 |
+
def __init__(self, root, image_set, transform, target_transform):
|
341 |
+
super(CityscapesSeg, self).__init__()
|
342 |
+
self.split = image_set
|
343 |
+
self.root = join(root, "cityscapes")
|
344 |
+
if image_set == "train":
|
345 |
+
# our_image_set = "train_extra"
|
346 |
+
# mode = "coarse"
|
347 |
+
our_image_set = "train"
|
348 |
+
mode = "fine"
|
349 |
+
else:
|
350 |
+
our_image_set = image_set
|
351 |
+
mode = "fine"
|
352 |
+
self.inner_loader = Cityscapes(self.root, our_image_set,
|
353 |
+
mode=mode,
|
354 |
+
target_type="semantic",
|
355 |
+
transform=None,
|
356 |
+
target_transform=None)
|
357 |
+
self.transform = transform
|
358 |
+
self.target_transform = target_transform
|
359 |
+
self.first_nonvoid = 7
|
360 |
+
|
361 |
+
def __getitem__(self, index):
|
362 |
+
if self.transform is not None:
|
363 |
+
image, target = self.inner_loader[index]
|
364 |
+
|
365 |
+
seed = np.random.randint(2147483647)
|
366 |
+
random.seed(seed)
|
367 |
+
torch.manual_seed(seed)
|
368 |
+
image = self.transform(image)
|
369 |
+
random.seed(seed)
|
370 |
+
torch.manual_seed(seed)
|
371 |
+
target = self.target_transform(target)
|
372 |
+
|
373 |
+
target = target - self.first_nonvoid
|
374 |
+
target[target < 0] = -1
|
375 |
+
mask = target == -1
|
376 |
+
return image, target.squeeze(0), mask
|
377 |
+
else:
|
378 |
+
return self.inner_loader[index]
|
379 |
+
|
380 |
+
def __len__(self):
|
381 |
+
return len(self.inner_loader)
|
382 |
+
|
383 |
+
|
384 |
+
class CroppedDataset(Dataset):
|
385 |
+
def __init__(self, root, dataset_name, crop_type, crop_ratio, image_set, transform, target_transform):
|
386 |
+
super(CroppedDataset, self).__init__()
|
387 |
+
self.dataset_name = dataset_name
|
388 |
+
self.split = image_set
|
389 |
+
self.root = join(root, "cropped", "{}_{}_crop_{}".format(dataset_name, crop_type, crop_ratio))
|
390 |
+
self.transform = transform
|
391 |
+
self.target_transform = target_transform
|
392 |
+
self.img_dir = join(self.root, "img", self.split)
|
393 |
+
self.label_dir = join(self.root, "label", self.split)
|
394 |
+
self.num_images = len(os.listdir(self.img_dir))
|
395 |
+
assert self.num_images == len(os.listdir(self.label_dir))
|
396 |
+
|
397 |
+
def __getitem__(self, index):
|
398 |
+
image = Image.open(join(self.img_dir, "{}.jpg".format(index))).convert('RGB')
|
399 |
+
target = Image.open(join(self.label_dir, "{}.png".format(index)))
|
400 |
+
|
401 |
+
seed = np.random.randint(2147483647)
|
402 |
+
random.seed(seed)
|
403 |
+
torch.manual_seed(seed)
|
404 |
+
image = self.transform(image)
|
405 |
+
random.seed(seed)
|
406 |
+
torch.manual_seed(seed)
|
407 |
+
target = self.target_transform(target)
|
408 |
+
|
409 |
+
target = target - 1
|
410 |
+
mask = target == -1
|
411 |
+
return image, target.squeeze(0), mask
|
412 |
+
|
413 |
+
def __len__(self):
|
414 |
+
return self.num_images
|
415 |
+
|
416 |
+
|
417 |
+
class MaterializedDataset(Dataset):
|
418 |
+
|
419 |
+
def __init__(self, ds):
|
420 |
+
self.ds = ds
|
421 |
+
self.materialized = []
|
422 |
+
loader = DataLoader(ds, num_workers=12, collate_fn=lambda l: l[0])
|
423 |
+
for batch in tqdm(loader):
|
424 |
+
self.materialized.append(batch)
|
425 |
+
|
426 |
+
def __len__(self):
|
427 |
+
return len(self.ds)
|
428 |
+
|
429 |
+
def __getitem__(self, ind):
|
430 |
+
return self.materialized[ind]
|
431 |
+
|
432 |
+
|
433 |
+
class ContrastiveSegDataset(Dataset):
|
434 |
+
def __init__(self,
|
435 |
+
pytorch_data_dir,
|
436 |
+
dataset_name,
|
437 |
+
crop_type,
|
438 |
+
image_set,
|
439 |
+
transform,
|
440 |
+
target_transform,
|
441 |
+
cfg,
|
442 |
+
aug_geometric_transform=None,
|
443 |
+
aug_photometric_transform=None,
|
444 |
+
num_neighbors=5,
|
445 |
+
compute_knns=False,
|
446 |
+
mask=False,
|
447 |
+
pos_labels=False,
|
448 |
+
pos_images=False,
|
449 |
+
extra_transform=None,
|
450 |
+
model_type_override=None
|
451 |
+
):
|
452 |
+
super(ContrastiveSegDataset).__init__()
|
453 |
+
self.num_neighbors = num_neighbors
|
454 |
+
self.image_set = image_set
|
455 |
+
self.dataset_name = dataset_name
|
456 |
+
self.mask = mask
|
457 |
+
self.pos_labels = pos_labels
|
458 |
+
self.pos_images = pos_images
|
459 |
+
self.extra_transform = extra_transform
|
460 |
+
|
461 |
+
if dataset_name == "potsdam":
|
462 |
+
self.n_classes = 3
|
463 |
+
dataset_class = Potsdam
|
464 |
+
extra_args = dict(coarse_labels=True)
|
465 |
+
elif dataset_name == "potsdamraw":
|
466 |
+
self.n_classes = 3
|
467 |
+
dataset_class = PotsdamRaw
|
468 |
+
extra_args = dict(coarse_labels=True)
|
469 |
+
elif dataset_name == "directory":
|
470 |
+
self.n_classes = cfg.dir_dataset_n_classes
|
471 |
+
dataset_class = DirectoryDataset
|
472 |
+
extra_args = dict(path=cfg.dir_dataset_name)
|
473 |
+
elif dataset_name == "cityscapes" and crop_type is None:
|
474 |
+
self.n_classes = 27
|
475 |
+
dataset_class = CityscapesSeg
|
476 |
+
extra_args = dict()
|
477 |
+
elif dataset_name == "cityscapes" and crop_type is not None:
|
478 |
+
self.n_classes = 27
|
479 |
+
dataset_class = CroppedDataset
|
480 |
+
extra_args = dict(dataset_name="cityscapes", crop_type=crop_type, crop_ratio=cfg.crop_ratio)
|
481 |
+
elif dataset_name == "cocostuff3":
|
482 |
+
self.n_classes = 3
|
483 |
+
dataset_class = Coco
|
484 |
+
extra_args = dict(coarse_labels=True, subset=6, exclude_things=True)
|
485 |
+
elif dataset_name == "cocostuff15":
|
486 |
+
self.n_classes = 15
|
487 |
+
dataset_class = Coco
|
488 |
+
extra_args = dict(coarse_labels=False, subset=7, exclude_things=True)
|
489 |
+
elif dataset_name == "cocostuff27" and crop_type is not None:
|
490 |
+
self.n_classes = 27
|
491 |
+
dataset_class = CroppedDataset
|
492 |
+
extra_args = dict(dataset_name="cocostuff27", crop_type=cfg.crop_type, crop_ratio=cfg.crop_ratio)
|
493 |
+
elif dataset_name == "cocostuff27" and crop_type is None:
|
494 |
+
self.n_classes = 27
|
495 |
+
dataset_class = Coco
|
496 |
+
extra_args = dict(coarse_labels=False, subset=None, exclude_things=False)
|
497 |
+
if image_set == "val":
|
498 |
+
extra_args["subset"] = 7
|
499 |
+
else:
|
500 |
+
raise ValueError("Unknown dataset: {}".format(dataset_name))
|
501 |
+
|
502 |
+
self.aug_geometric_transform = aug_geometric_transform
|
503 |
+
self.aug_photometric_transform = aug_photometric_transform
|
504 |
+
|
505 |
+
self.dataset = dataset_class(
|
506 |
+
root=pytorch_data_dir,
|
507 |
+
image_set=self.image_set,
|
508 |
+
transform=transform,
|
509 |
+
target_transform=target_transform, **extra_args)
|
510 |
+
|
511 |
+
if model_type_override is not None:
|
512 |
+
model_type = model_type_override
|
513 |
+
else:
|
514 |
+
model_type = cfg.model_type
|
515 |
+
|
516 |
+
nice_dataset_name = cfg.dir_dataset_name if dataset_name == "directory" else dataset_name
|
517 |
+
feature_cache_file = join(pytorch_data_dir, "nns", "nns_{}_{}_{}_{}_{}.npz".format(
|
518 |
+
model_type, nice_dataset_name, image_set, crop_type, cfg.res))
|
519 |
+
if pos_labels or pos_images:
|
520 |
+
if not os.path.exists(feature_cache_file) or compute_knns:
|
521 |
+
raise ValueError("could not find nn file {} please run precompute_knns".format(feature_cache_file))
|
522 |
+
else:
|
523 |
+
loaded = np.load(feature_cache_file)
|
524 |
+
self.nns = loaded["nns"]
|
525 |
+
assert len(self.dataset) == self.nns.shape[0]
|
526 |
+
|
527 |
+
def __len__(self):
|
528 |
+
return len(self.dataset)
|
529 |
+
|
530 |
+
def _set_seed(self, seed):
|
531 |
+
random.seed(seed) # apply this seed to img tranfsorms
|
532 |
+
torch.manual_seed(seed) # needed for torchvision 0.7
|
533 |
+
|
534 |
+
def __getitem__(self, ind):
|
535 |
+
pack = self.dataset[ind]
|
536 |
+
|
537 |
+
if self.pos_images or self.pos_labels:
|
538 |
+
ind_pos = self.nns[ind][torch.randint(low=1, high=self.num_neighbors + 1, size=[]).item()]
|
539 |
+
pack_pos = self.dataset[ind_pos]
|
540 |
+
|
541 |
+
seed = np.random.randint(2147483647) # make a seed with numpy generator
|
542 |
+
|
543 |
+
self._set_seed(seed)
|
544 |
+
coord_entries = torch.meshgrid([torch.linspace(-1, 1, pack[0].shape[1]),
|
545 |
+
torch.linspace(-1, 1, pack[0].shape[2])])
|
546 |
+
coord = torch.cat([t.unsqueeze(0) for t in coord_entries], 0)
|
547 |
+
|
548 |
+
if self.extra_transform is not None:
|
549 |
+
extra_trans = self.extra_transform
|
550 |
+
else:
|
551 |
+
extra_trans = lambda i, x: x
|
552 |
+
|
553 |
+
def squeeze_tuple(label_raw):
|
554 |
+
if type(label_raw) == tuple:
|
555 |
+
return tuple(x.squeeze() for x in label_raw)
|
556 |
+
else:
|
557 |
+
return label_raw.squeeze()
|
558 |
+
ret = {
|
559 |
+
"ind": ind,
|
560 |
+
"img": extra_trans(ind, pack[0]),
|
561 |
+
"label": squeeze_tuple(extra_trans(ind, pack[1]))
|
562 |
+
}
|
563 |
+
|
564 |
+
if self.pos_images:
|
565 |
+
ret["img_pos"] = extra_trans(ind, pack_pos[0])
|
566 |
+
ret["ind_pos"] = ind_pos
|
567 |
+
|
568 |
+
if self.mask:
|
569 |
+
ret["mask"] = pack[2]
|
570 |
+
|
571 |
+
if self.pos_labels:
|
572 |
+
ret["label_pos"] = squeeze_tuple(extra_trans(ind, pack_pos[1]))
|
573 |
+
ret["mask_pos"] = pack_pos[2]
|
574 |
+
|
575 |
+
if self.aug_photometric_transform is not None:
|
576 |
+
img_aug = self.aug_photometric_transform(self.aug_geometric_transform(pack[0]))
|
577 |
+
|
578 |
+
self._set_seed(seed)
|
579 |
+
coord_aug = self.aug_geometric_transform(coord)
|
580 |
+
|
581 |
+
ret["img_aug"] = img_aug
|
582 |
+
ret["coord_aug"] = coord_aug.permute(1, 2, 0)
|
583 |
+
|
584 |
+
return ret
|
biomap/dataset_generator/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .data_loader import DataLoader
|
2 |
+
|
3 |
+
|
4 |
+
__all__ = [
|
5 |
+
'DataLoader',
|
6 |
+
]
|
biomap/dataset_generator/data_loader.py
ADDED
@@ -0,0 +1,356 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datetime import datetime
|
2 |
+
import ee
|
3 |
+
from func_timeout import func_set_timeout
|
4 |
+
import pandas as pd
|
5 |
+
from PIL import Image
|
6 |
+
import requests
|
7 |
+
import tempfile
|
8 |
+
import io
|
9 |
+
from tqdm import tqdm
|
10 |
+
import functools
|
11 |
+
import re # Used in an eval statement
|
12 |
+
from typing import List
|
13 |
+
from typing import Union
|
14 |
+
from typing import Any
|
15 |
+
|
16 |
+
|
17 |
+
class DataLoader:
|
18 |
+
"""
|
19 |
+
Main class for loading and exploring data from satellite images.
|
20 |
+
The goal is to load an ImageCollection and to filter that collection according to needs, with methods like
|
21 |
+
filter, filterDate, filterBounds, select. These will work just like earth engine's methods with the same names.
|
22 |
+
|
23 |
+
This class, just like earth engine, works with lazy loading and compute. This means that running filterBounds
|
24 |
+
will not actually filter the image collection until required, e.g. when counting the images by accessing .count
|
25 |
+
property.
|
26 |
+
However, it will only load once the information it needs, unless additional filtering is made.
|
27 |
+
|
28 |
+
This works thanks to the signal_change decorator. If you develop a new filtering method for this class,
|
29 |
+
you will need to decorate your method with @signal_change.
|
30 |
+
In addition, if you develop a new method that will require to run getInfo to actually load data from
|
31 |
+
Google Earth Engine, you will need to use _get_timeout_info(your object before getInfo). This will run
|
32 |
+
getInfo with a timeout (currently set to 10 seconds).
|
33 |
+
It is important to use a timeout to avoid unexpected run times.
|
34 |
+
|
35 |
+
Usage:
|
36 |
+
>>> dl = DataLoader(satellite_name="COPERNICUS/S2_SR", \
|
37 |
+
start_date='2021-01-01', \
|
38 |
+
end_date='2021-01-15', \
|
39 |
+
bands=["TCI_R", "TCI_G", "TCI_B"], \
|
40 |
+
geographic_bounds=ee.Geometry.Point(*[5.238728194366604, 44.474864056855935]).buffer(500) \
|
41 |
+
)
|
42 |
+
|
43 |
+
Get a pandas dataframe with all pixel values as a timeseries:
|
44 |
+
>>> dl.getRegion(dl.bounds, 500)
|
45 |
+
>>> dl.region.head(2)
|
46 |
+
[Out]
|
47 |
+
id longitude latitude time B1 B2 B3 B4 B5 B6 ... WVP SCL TCI_R TCI_G TCI_B MSK_CLDPRB MSK_SNWPRB QA10 QA20 QA60
|
48 |
+
0 20210102T104441_20210102T104435_T31TFK 5.234932 44.473344 2021-01-02 10:48:36.299 6297 5955 5768 5773 5965 5883 ... 393 8 255 255 255 0 95 0 0 1024
|
49 |
+
1 20210104T103329_20210104T103331_T31TFK 5.234932 44.473344 2021-01-04 10:38:38.304 5547 5355 5184 5090 5254 5229 ... 314 9 255 255 255 29 9 0 0 1024
|
50 |
+
|
51 |
+
>>> dl.date_range
|
52 |
+
[Out]
|
53 |
+
{'max': datetime.datetime(2021, 1, 14, 11, 38, 39, 208000),
|
54 |
+
'min': datetime.datetime(2021, 1, 2, 11, 48, 36, 299000)}
|
55 |
+
|
56 |
+
>>> dl.count
|
57 |
+
[Out]
|
58 |
+
6
|
59 |
+
|
60 |
+
>>> dl.collection_info # constains a html description of the dataset in "description"
|
61 |
+
|
62 |
+
>>> dl.image_ids
|
63 |
+
[Out]
|
64 |
+
['COPERNICUS/S2_SR/20210102T104441_20210102T104435_T31TFK',
|
65 |
+
'COPERNICUS/S2_SR/20210104T103329_20210104T103331_T31TFK',
|
66 |
+
'COPERNICUS/S2_SR/20210107T104329_20210107T104328_T31TFK',
|
67 |
+
'COPERNICUS/S2_SR/20210109T103421_20210109T103431_T31TFK',
|
68 |
+
'COPERNICUS/S2_SR/20210112T104411_20210112T104438_T31TFK',
|
69 |
+
'COPERNICUS/S2_SR/20210114T103309_20210114T103305_T31TFK']
|
70 |
+
|
71 |
+
# Download the image
|
72 |
+
>>> img = dl.download_image(dl.image_ids[3])
|
73 |
+
|
74 |
+
# Download all images as a list
|
75 |
+
>>> imgs = dl.download_all_images(scale=1)
|
76 |
+
|
77 |
+
"""
|
78 |
+
def __init__(self,
|
79 |
+
satellite_name: str,
|
80 |
+
bands: Union[List, str] = None,
|
81 |
+
start_date: str = None,
|
82 |
+
end_date: str = None,
|
83 |
+
geographic_bounds: ee.geometry = None,
|
84 |
+
scale: int = 10,
|
85 |
+
crs: str = "EPSG:32630"
|
86 |
+
):
|
87 |
+
"""
|
88 |
+
|
89 |
+
Args:
|
90 |
+
satellite_name: satellite to use. Examples: COPERNICUS/S2_SR, COPERNICUS/CORINE/V20/100m.
|
91 |
+
See https://developers.google.com/earth-engine/datasets for the full list.
|
92 |
+
bands: list of bands to load.
|
93 |
+
start_date: lowest possible date. Might be lower than the actual date of the first picture.
|
94 |
+
end_date: Latest possible date.
|
95 |
+
geographic_bounds: Region of interest.
|
96 |
+
"""
|
97 |
+
self.satellite_name = satellite_name
|
98 |
+
if isinstance(bands, str):
|
99 |
+
bands = [bands]
|
100 |
+
self.bands = bands if bands is not None else list()
|
101 |
+
if start_date is None or end_date is None:
|
102 |
+
assert (start_date is not None) and (end_date is not None), "start_date and end_date must both be provided"
|
103 |
+
self.start_date = start_date
|
104 |
+
self.end_date = end_date
|
105 |
+
self.bounds = geographic_bounds
|
106 |
+
|
107 |
+
# Lazy computed
|
108 |
+
self._available_images = None
|
109 |
+
|
110 |
+
# Start getting info from google cloud
|
111 |
+
if satellite_name:
|
112 |
+
self.image_collection = ee.ImageCollection(self.satellite_name)
|
113 |
+
if self.bounds:
|
114 |
+
self.filterBounds(self.bounds)
|
115 |
+
if self.start_date is not None:
|
116 |
+
self.filterDate(self.start_date, self.end_date)
|
117 |
+
self.scale = scale
|
118 |
+
self.crs = crs
|
119 |
+
self.image_list = None
|
120 |
+
self._df_image_list = None
|
121 |
+
self.image_collection_info = None
|
122 |
+
self._date_range = None
|
123 |
+
self.date_filter_change = False
|
124 |
+
self._count = None
|
125 |
+
|
126 |
+
# Bool for caching
|
127 |
+
self.filter_change = True
|
128 |
+
self._describe = None
|
129 |
+
|
130 |
+
def signal_change(func):
|
131 |
+
"""Signals that additional filtering was performed. To be used
|
132 |
+
as a decorator."""
|
133 |
+
@functools.wraps(func)
|
134 |
+
def wrap(self, *args, **kwargs):
|
135 |
+
self.filter_change = True
|
136 |
+
self.date_filter_change = True
|
137 |
+
return func(self, *args, **kwargs)
|
138 |
+
return wrap
|
139 |
+
|
140 |
+
@staticmethod
|
141 |
+
@func_set_timeout(10)
|
142 |
+
def _get_timeout_info(instance: Any):
|
143 |
+
"""Runs getInfo on anything that is passed, with a timeout."""
|
144 |
+
return instance.getInfo()
|
145 |
+
|
146 |
+
@staticmethod
|
147 |
+
def _authenticate_gee():
|
148 |
+
"""Authenticates earth engine if needed, and initializes."""
|
149 |
+
try:
|
150 |
+
ee.Initialize()
|
151 |
+
except Exception as e:
|
152 |
+
# Trigger the authentication flow.
|
153 |
+
ee.Authenticate()
|
154 |
+
# Initialize the library.
|
155 |
+
ee.Initialize()
|
156 |
+
|
157 |
+
def filter(self, ee_filter: ee.Filter):
|
158 |
+
"""Applies a filter to the image_collection attribute. This can be useful for example
|
159 |
+
to filter out clouds
|
160 |
+
|
161 |
+
Args:
|
162 |
+
ee_filter: Filter to apply, must be an instance of ee.Filter.
|
163 |
+
|
164 |
+
Returns: self, for operation chaining as possible with the earth engine API.
|
165 |
+
|
166 |
+
"""
|
167 |
+
self.image_collection = self.image_collection.filter(ee_filter)
|
168 |
+
|
169 |
+
return self
|
170 |
+
|
171 |
+
@property
|
172 |
+
def count(self):
|
173 |
+
"""Number of images in the ImageCollection"""
|
174 |
+
if self.filter_change or self._count is None:
|
175 |
+
self._count = self._get_timeout_info(self.image_collection.size())
|
176 |
+
self.filter_change = False
|
177 |
+
return self._count
|
178 |
+
|
179 |
+
@property
|
180 |
+
def available_images(self):
|
181 |
+
"""Gets the ImageCollection info"""
|
182 |
+
if self.filter_change or self._available_images is None:
|
183 |
+
self._available_images = self._get_timeout_info(self.image_collection)
|
184 |
+
return self._available_images
|
185 |
+
|
186 |
+
@signal_change
|
187 |
+
def filterDate(self, *args, **kwargs):
|
188 |
+
"""Wrapper for the filterDate method in earth engine on the ImageCollection"""
|
189 |
+
self.image_collection = self.image_collection.filterDate(*args, **kwargs)
|
190 |
+
return self
|
191 |
+
|
192 |
+
@signal_change
|
193 |
+
def getRegion(self, *args, **kwargs):
|
194 |
+
"""Wrapper for the getRegion method in earth engine on the ImageCollection.
|
195 |
+
Caveat! getRegion does not return an image collection, so the image_list attribute gets
|
196 |
+
updated instead of the image_collection attribute. However, the instance of the DataLoader class
|
197 |
+
is still returned, so this could be chained with another method on ImageCollection, which wouldn't be
|
198 |
+
possible using earth engine.
|
199 |
+
"""
|
200 |
+
self.image_list = self.image_collection.getRegion(*args, **kwargs)
|
201 |
+
return self
|
202 |
+
|
203 |
+
@signal_change
|
204 |
+
def filterBounds(self, geometry, *args, **kwargs):
|
205 |
+
"""Wrapper for the filterBounds method in earth engine on the ImageCollection"""
|
206 |
+
self.image_collection = self.image_collection.filterBounds(geometry, *args, **kwargs)
|
207 |
+
self.bounds = geometry
|
208 |
+
return self
|
209 |
+
|
210 |
+
@signal_change
|
211 |
+
def select(self, *bands, **kwargs):
|
212 |
+
"""Wrapper for the select method in earth engine on the ImageCollection"""
|
213 |
+
self.image_collection = self.image_collection.select(*bands, **kwargs)
|
214 |
+
self.bands = list(set(self.bands) | set(bands)) # Unique bands
|
215 |
+
return self
|
216 |
+
|
217 |
+
@property
|
218 |
+
def date_range(self):
|
219 |
+
"""Gets the actual date range of the images in the image collection."""
|
220 |
+
if self.date_filter_change or self._date_range is None:
|
221 |
+
date_range = self.image_collection.reduceColumns(ee.Reducer.minMax(), ["system:time_start"]).getInfo()
|
222 |
+
self._date_range = {key: datetime.fromtimestamp(value/1e3) for key, value in date_range.items()}
|
223 |
+
self.date_filter_change = False
|
224 |
+
return self._date_range
|
225 |
+
|
226 |
+
@property
|
227 |
+
def region(self):
|
228 |
+
"""Gets a time series as a pandas DataFrame of the band values for the specified region."""
|
229 |
+
if self.filter_change:
|
230 |
+
if self.image_list is None:
|
231 |
+
self.getRegion()
|
232 |
+
res_list = self._get_timeout_info(self.image_list)
|
233 |
+
df = pd.DataFrame(res_list[1:], columns=res_list[0])
|
234 |
+
df.loc[:, "time"] = pd.to_datetime(df.loc[:, "time"], unit="ms")
|
235 |
+
self._df_image_list = df
|
236 |
+
self.filter_change = False
|
237 |
+
return self._df_image_list
|
238 |
+
|
239 |
+
@property
|
240 |
+
def collection_info(self):
|
241 |
+
"""Runs getInfo on the image collection (the first time the next time the previously
|
242 |
+
populated attribute will be returned)."""
|
243 |
+
if self.count > 5000:
|
244 |
+
raise Exception("Too many images to load. Try filtering more")
|
245 |
+
if self.filter_change or self.image_collection_info is None:
|
246 |
+
self.image_collection_info = self._get_timeout_info(self.image_collection)
|
247 |
+
return self.image_collection_info
|
248 |
+
|
249 |
+
@property
|
250 |
+
def image_ids(self):
|
251 |
+
"""list of names of available images in the image collection"""
|
252 |
+
return [i["id"] for i in self.collection_info["features"]]
|
253 |
+
|
254 |
+
def __repr__(self):
|
255 |
+
try:
|
256 |
+
return f"""
|
257 |
+
Size: {self.count}
|
258 |
+
|
259 |
+
Dataset date ranges:
|
260 |
+
From: {self.date_range["min"]}
|
261 |
+
To: {self.date_range["max"]}
|
262 |
+
|
263 |
+
Selected bands:
|
264 |
+
{self.bands}
|
265 |
+
|
266 |
+
"""
|
267 |
+
except Exception as e:
|
268 |
+
raise Exception("Impossible to represent the dataset. Try filtering more. Error handling to do.")
|
269 |
+
|
270 |
+
def reproject(self, image, **kwargs):
|
271 |
+
def resolve(name: str):
|
272 |
+
# Resolve crs
|
273 |
+
if name in kwargs:
|
274 |
+
item = kwargs[name]
|
275 |
+
elif getattr(self, name):
|
276 |
+
item = getattr(self, name)
|
277 |
+
else:
|
278 |
+
item = None
|
279 |
+
return item
|
280 |
+
crs = resolve("crs")
|
281 |
+
scale = resolve("scale")
|
282 |
+
if crs is not None or scale is not None:
|
283 |
+
image = image.reproject(crs, None, scale)
|
284 |
+
return image
|
285 |
+
|
286 |
+
def download_image(self, image_id: str, **kwargs):
|
287 |
+
"""Downloads an image based on its id / name. The additional arguments are passed
|
288 |
+
to getThumbUrl, and could be scale, max, min...
|
289 |
+
"""
|
290 |
+
img = ee.Image(image_id).select(*self.bands)
|
291 |
+
img = self.reproject(img, **kwargs)
|
292 |
+
input_args = {'region': self.bounds}
|
293 |
+
input_args.update(**kwargs)
|
294 |
+
all_bands = self.collection_info["features"][0]["bands"]
|
295 |
+
selected_bands = [band for i, band in enumerate(all_bands) if all_bands[i]["id"] in self.bands]
|
296 |
+
if "min" not in input_args:
|
297 |
+
input_args.update({"min": selected_bands[0]["data_type"]["min"]})
|
298 |
+
if "max" not in input_args:
|
299 |
+
input_args.update({"max": selected_bands[0]["data_type"]["max"]})
|
300 |
+
url = img.getThumbUrl(input_args)
|
301 |
+
buffer = tempfile.SpooledTemporaryFile(max_size=1e9)
|
302 |
+
r = requests.get(url, stream=True)
|
303 |
+
if r.status_code == 200:
|
304 |
+
downloaded = 0
|
305 |
+
# filesize = int(r.headers['content-length'])
|
306 |
+
for chunk in r.iter_content(chunk_size=1024):
|
307 |
+
downloaded += len(chunk)
|
308 |
+
buffer.write(chunk)
|
309 |
+
buffer.seek(0)
|
310 |
+
img = Image.open(io.BytesIO(buffer.read()))
|
311 |
+
buffer.close()
|
312 |
+
return img
|
313 |
+
|
314 |
+
@staticmethod
|
315 |
+
def _regex(regex: str, im_id_list: List[str], include: bool) -> list:
|
316 |
+
"""
|
317 |
+
Filters the im_id_list based on a regular expression. This is useful before downloading
|
318 |
+
a collection of images. For example, using (.*)TXT with include=True will only download images
|
319 |
+
that end with TXT, wich for Nantes means filtering out empty or half empty images.
|
320 |
+
Args:
|
321 |
+
regex: python regex as a strng
|
322 |
+
im_id_list: list, image id list
|
323 |
+
include: whether to include or exclude elements that match the regex.
|
324 |
+
|
325 |
+
Returns: filtered list.
|
326 |
+
|
327 |
+
"""
|
328 |
+
expression = "re.match('{regex}', '{im_id}') is not None"
|
329 |
+
if not include:
|
330 |
+
expression = "not " + expression
|
331 |
+
filtered_list = list()
|
332 |
+
for im_id in im_id_list:
|
333 |
+
if eval(expression.format(regex=regex, im_id=im_id)):
|
334 |
+
filtered_list.append(im_id)
|
335 |
+
return filtered_list
|
336 |
+
|
337 |
+
def download_all_images(self, regex_exclude: str = None, regex_include: str = None, **kwargs):
|
338 |
+
"""
|
339 |
+
Runs download_image in a for loop around the available images.
|
340 |
+
Makes it possible to filter images to download based on a regex.
|
341 |
+
Args:
|
342 |
+
regex_exclude: any image that matches this regex will be excluded.
|
343 |
+
regex_include: any image that matches this regex will be included
|
344 |
+
**kwargs: arguments to be passed to getThumbUrl
|
345 |
+
|
346 |
+
Returns: list of PIL images
|
347 |
+
"""
|
348 |
+
images = list()
|
349 |
+
image_ids = self.image_ids
|
350 |
+
if regex_exclude is not None:
|
351 |
+
image_ids = self._regex(regex_exclude, image_ids, include=False)
|
352 |
+
if regex_include is not None:
|
353 |
+
image_ids = self._regex(regex_include, image_ids, include=True)
|
354 |
+
for i in tqdm(range(len(image_ids))):
|
355 |
+
images.append(self.download_image(image_ids[i], **kwargs))
|
356 |
+
return images
|
biomap/dino/utils.py
ADDED
@@ -0,0 +1,619 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""
|
15 |
+
Misc functions.
|
16 |
+
|
17 |
+
Mostly copy-paste from torchvision references or other public repos like DETR:
|
18 |
+
https://github.com/facebookresearch/detr/blob/master/util/misc.py
|
19 |
+
"""
|
20 |
+
import os
|
21 |
+
import sys
|
22 |
+
import time
|
23 |
+
import math
|
24 |
+
import random
|
25 |
+
import datetime
|
26 |
+
import subprocess
|
27 |
+
from collections import defaultdict, deque
|
28 |
+
|
29 |
+
import numpy as np
|
30 |
+
import torch
|
31 |
+
from torch import nn
|
32 |
+
import torch.distributed as dist
|
33 |
+
from PIL import ImageFilter, ImageOps
|
34 |
+
|
35 |
+
|
36 |
+
class GaussianBlur(object):
|
37 |
+
"""
|
38 |
+
Apply Gaussian Blur to the PIL image.
|
39 |
+
"""
|
40 |
+
def __init__(self, p=0.5, radius_min=0.1, radius_max=2.):
|
41 |
+
self.prob = p
|
42 |
+
self.radius_min = radius_min
|
43 |
+
self.radius_max = radius_max
|
44 |
+
|
45 |
+
def __call__(self, img):
|
46 |
+
do_it = random.random() <= self.prob
|
47 |
+
if not do_it:
|
48 |
+
return img
|
49 |
+
|
50 |
+
return img.filter(
|
51 |
+
ImageFilter.GaussianBlur(
|
52 |
+
radius=random.uniform(self.radius_min, self.radius_max)
|
53 |
+
)
|
54 |
+
)
|
55 |
+
|
56 |
+
|
57 |
+
class Solarization(object):
|
58 |
+
"""
|
59 |
+
Apply Solarization to the PIL image.
|
60 |
+
"""
|
61 |
+
def __init__(self, p):
|
62 |
+
self.p = p
|
63 |
+
|
64 |
+
def __call__(self, img):
|
65 |
+
if random.random() < self.p:
|
66 |
+
return ImageOps.solarize(img)
|
67 |
+
else:
|
68 |
+
return img
|
69 |
+
|
70 |
+
|
71 |
+
def load_pretrained_weights(model, pretrained_weights, checkpoint_key, model_name, patch_size):
|
72 |
+
if os.path.isfile(pretrained_weights):
|
73 |
+
state_dict = torch.load(pretrained_weights, map_location="cpu")
|
74 |
+
if checkpoint_key is not None and checkpoint_key in state_dict:
|
75 |
+
print(f"Take key {checkpoint_key} in provided checkpoint dict")
|
76 |
+
state_dict = state_dict[checkpoint_key]
|
77 |
+
# remove `module.` prefix
|
78 |
+
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
|
79 |
+
# remove `backbone.` prefix induced by multicrop wrapper
|
80 |
+
state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
|
81 |
+
msg = model.load_state_dict(state_dict, strict=False)
|
82 |
+
print('Pretrained weights found at {} and loaded with msg: {}'.format(pretrained_weights, msg))
|
83 |
+
else:
|
84 |
+
print("Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.")
|
85 |
+
url = None
|
86 |
+
if model_name == "vit_small" and patch_size == 16:
|
87 |
+
url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
|
88 |
+
elif model_name == "vit_small" and patch_size == 8:
|
89 |
+
url = "dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth"
|
90 |
+
elif model_name == "vit_base" and patch_size == 16:
|
91 |
+
url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
|
92 |
+
elif model_name == "vit_base" and patch_size == 8:
|
93 |
+
url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
|
94 |
+
if url is not None:
|
95 |
+
print("Since no pretrained weights have been provided, we load the reference pretrained DINO weights.")
|
96 |
+
state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
|
97 |
+
model.load_state_dict(state_dict, strict=True)
|
98 |
+
else:
|
99 |
+
print("There is no reference weights available for this model => We use random weights.")
|
100 |
+
|
101 |
+
|
102 |
+
def clip_gradients(model, clip):
|
103 |
+
norms = []
|
104 |
+
for name, p in model.named_parameters():
|
105 |
+
if p.grad is not None:
|
106 |
+
param_norm = p.grad.data.norm(2)
|
107 |
+
norms.append(param_norm.item())
|
108 |
+
clip_coef = clip / (param_norm + 1e-6)
|
109 |
+
if clip_coef < 1:
|
110 |
+
p.grad.data.mul_(clip_coef)
|
111 |
+
return norms
|
112 |
+
|
113 |
+
|
114 |
+
def cancel_gradients_last_layer(epoch, model, freeze_last_layer):
|
115 |
+
if epoch >= freeze_last_layer:
|
116 |
+
return
|
117 |
+
for n, p in model.named_parameters():
|
118 |
+
if "last_layer" in n:
|
119 |
+
p.grad = None
|
120 |
+
|
121 |
+
|
122 |
+
def restart_from_checkpoint(ckp_path, run_variables=None, **kwargs):
|
123 |
+
"""
|
124 |
+
Re-start from checkpoint
|
125 |
+
"""
|
126 |
+
if not os.path.isfile(ckp_path):
|
127 |
+
return
|
128 |
+
print("Found checkpoint at {}".format(ckp_path))
|
129 |
+
|
130 |
+
# open checkpoint file
|
131 |
+
checkpoint = torch.load(ckp_path, map_location="cpu")
|
132 |
+
|
133 |
+
# key is what to look for in the checkpoint file
|
134 |
+
# value is the object to load
|
135 |
+
# example: {'state_dict': model}
|
136 |
+
for key, value in kwargs.items():
|
137 |
+
if key in checkpoint and value is not None:
|
138 |
+
try:
|
139 |
+
msg = value.load_state_dict(checkpoint[key], strict=False)
|
140 |
+
print("=> loaded {} from checkpoint '{}' with msg {}".format(key, ckp_path, msg))
|
141 |
+
except TypeError:
|
142 |
+
try:
|
143 |
+
msg = value.load_state_dict(checkpoint[key])
|
144 |
+
print("=> loaded {} from checkpoint '{}'".format(key, ckp_path))
|
145 |
+
except ValueError:
|
146 |
+
print("=> failed to load {} from checkpoint '{}'".format(key, ckp_path))
|
147 |
+
else:
|
148 |
+
print("=> failed to load {} from checkpoint '{}'".format(key, ckp_path))
|
149 |
+
|
150 |
+
# re load variable important for the run
|
151 |
+
if run_variables is not None:
|
152 |
+
for var_name in run_variables:
|
153 |
+
if var_name in checkpoint:
|
154 |
+
run_variables[var_name] = checkpoint[var_name]
|
155 |
+
|
156 |
+
|
157 |
+
def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0):
|
158 |
+
warmup_schedule = np.array([])
|
159 |
+
warmup_iters = warmup_epochs * niter_per_ep
|
160 |
+
if warmup_epochs > 0:
|
161 |
+
warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
|
162 |
+
|
163 |
+
iters = np.arange(epochs * niter_per_ep - warmup_iters)
|
164 |
+
schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
|
165 |
+
|
166 |
+
schedule = np.concatenate((warmup_schedule, schedule))
|
167 |
+
assert len(schedule) == epochs * niter_per_ep
|
168 |
+
return schedule
|
169 |
+
|
170 |
+
|
171 |
+
def bool_flag(s):
|
172 |
+
"""
|
173 |
+
Parse boolean arguments from the command line.
|
174 |
+
"""
|
175 |
+
FALSY_STRINGS = {"off", "false", "0"}
|
176 |
+
TRUTHY_STRINGS = {"on", "true", "1"}
|
177 |
+
if s.lower() in FALSY_STRINGS:
|
178 |
+
return False
|
179 |
+
elif s.lower() in TRUTHY_STRINGS:
|
180 |
+
return True
|
181 |
+
else:
|
182 |
+
raise argparse.ArgumentTypeError("invalid value for a boolean flag")
|
183 |
+
|
184 |
+
|
185 |
+
def fix_random_seeds(seed=31):
|
186 |
+
"""
|
187 |
+
Fix random seeds.
|
188 |
+
"""
|
189 |
+
torch.manual_seed(seed)
|
190 |
+
torch.cuda.manual_seed_all(seed)
|
191 |
+
np.random.seed(seed)
|
192 |
+
|
193 |
+
|
194 |
+
class SmoothedValue(object):
|
195 |
+
"""Track a series of values and provide access to smoothed values over a
|
196 |
+
window or the global series average.
|
197 |
+
"""
|
198 |
+
|
199 |
+
def __init__(self, window_size=20, fmt=None):
|
200 |
+
if fmt is None:
|
201 |
+
fmt = "{median:.6f} ({global_avg:.6f})"
|
202 |
+
self.deque = deque(maxlen=window_size)
|
203 |
+
self.total = 0.0
|
204 |
+
self.count = 0
|
205 |
+
self.fmt = fmt
|
206 |
+
|
207 |
+
def update(self, value, n=1):
|
208 |
+
self.deque.append(value)
|
209 |
+
self.count += n
|
210 |
+
self.total += value * n
|
211 |
+
|
212 |
+
def synchronize_between_processes(self):
|
213 |
+
"""
|
214 |
+
Warning: does not synchronize the deque!
|
215 |
+
"""
|
216 |
+
if not is_dist_avail_and_initialized():
|
217 |
+
return
|
218 |
+
t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
|
219 |
+
dist.barrier()
|
220 |
+
dist.all_reduce(t)
|
221 |
+
t = t.tolist()
|
222 |
+
self.count = int(t[0])
|
223 |
+
self.total = t[1]
|
224 |
+
|
225 |
+
@property
|
226 |
+
def median(self):
|
227 |
+
d = torch.tensor(list(self.deque))
|
228 |
+
return d.median().item()
|
229 |
+
|
230 |
+
@property
|
231 |
+
def avg(self):
|
232 |
+
d = torch.tensor(list(self.deque), dtype=torch.float32)
|
233 |
+
return d.mean().item()
|
234 |
+
|
235 |
+
@property
|
236 |
+
def global_avg(self):
|
237 |
+
return self.total / self.count
|
238 |
+
|
239 |
+
@property
|
240 |
+
def max(self):
|
241 |
+
return max(self.deque)
|
242 |
+
|
243 |
+
@property
|
244 |
+
def value(self):
|
245 |
+
return self.deque[-1]
|
246 |
+
|
247 |
+
def __str__(self):
|
248 |
+
return self.fmt.format(
|
249 |
+
median=self.median,
|
250 |
+
avg=self.avg,
|
251 |
+
global_avg=self.global_avg,
|
252 |
+
max=self.max,
|
253 |
+
value=self.value)
|
254 |
+
|
255 |
+
|
256 |
+
def reduce_dict(input_dict, average=True):
|
257 |
+
"""
|
258 |
+
Args:
|
259 |
+
input_dict (dict): all the values will be reduced
|
260 |
+
average (bool): whether to do average or sum
|
261 |
+
Reduce the values in the dictionary from all processes so that all processes
|
262 |
+
have the averaged results. Returns a dict with the same fields as
|
263 |
+
input_dict, after reduction.
|
264 |
+
"""
|
265 |
+
world_size = get_world_size()
|
266 |
+
if world_size < 2:
|
267 |
+
return input_dict
|
268 |
+
with torch.no_grad():
|
269 |
+
names = []
|
270 |
+
values = []
|
271 |
+
# sort the keys so that they are consistent across processes
|
272 |
+
for k in sorted(input_dict.keys()):
|
273 |
+
names.append(k)
|
274 |
+
values.append(input_dict[k])
|
275 |
+
values = torch.stack(values, dim=0)
|
276 |
+
dist.all_reduce(values)
|
277 |
+
if average:
|
278 |
+
values /= world_size
|
279 |
+
reduced_dict = {k: v for k, v in zip(names, values)}
|
280 |
+
return reduced_dict
|
281 |
+
|
282 |
+
|
283 |
+
class MetricLogger(object):
|
284 |
+
def __init__(self, delimiter="\t"):
|
285 |
+
self.meters = defaultdict(SmoothedValue)
|
286 |
+
self.delimiter = delimiter
|
287 |
+
|
288 |
+
def update(self, **kwargs):
|
289 |
+
for k, v in kwargs.items():
|
290 |
+
if isinstance(v, torch.Tensor):
|
291 |
+
v = v.item()
|
292 |
+
assert isinstance(v, (float, int))
|
293 |
+
self.meters[k].update(v)
|
294 |
+
|
295 |
+
def __getattr__(self, attr):
|
296 |
+
if attr in self.meters:
|
297 |
+
return self.meters[attr]
|
298 |
+
if attr in self.__dict__:
|
299 |
+
return self.__dict__[attr]
|
300 |
+
raise AttributeError("'{}' object has no attribute '{}'".format(
|
301 |
+
type(self).__name__, attr))
|
302 |
+
|
303 |
+
def __str__(self):
|
304 |
+
loss_str = []
|
305 |
+
for name, meter in self.meters.items():
|
306 |
+
loss_str.append(
|
307 |
+
"{}: {}".format(name, str(meter))
|
308 |
+
)
|
309 |
+
return self.delimiter.join(loss_str)
|
310 |
+
|
311 |
+
def synchronize_between_processes(self):
|
312 |
+
for meter in self.meters.values():
|
313 |
+
meter.synchronize_between_processes()
|
314 |
+
|
315 |
+
def add_meter(self, name, meter):
|
316 |
+
self.meters[name] = meter
|
317 |
+
|
318 |
+
def log_every(self, iterable, print_freq, header=None):
|
319 |
+
i = 0
|
320 |
+
if not header:
|
321 |
+
header = ''
|
322 |
+
start_time = time.time()
|
323 |
+
end = time.time()
|
324 |
+
iter_time = SmoothedValue(fmt='{avg:.6f}')
|
325 |
+
data_time = SmoothedValue(fmt='{avg:.6f}')
|
326 |
+
space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
|
327 |
+
if torch.cuda.is_available():
|
328 |
+
log_msg = self.delimiter.join([
|
329 |
+
header,
|
330 |
+
'[{0' + space_fmt + '}/{1}]',
|
331 |
+
'eta: {eta}',
|
332 |
+
'{meters}',
|
333 |
+
'time: {time}',
|
334 |
+
'data: {data}',
|
335 |
+
'max mem: {memory:.0f}'
|
336 |
+
])
|
337 |
+
else:
|
338 |
+
log_msg = self.delimiter.join([
|
339 |
+
header,
|
340 |
+
'[{0' + space_fmt + '}/{1}]',
|
341 |
+
'eta: {eta}',
|
342 |
+
'{meters}',
|
343 |
+
'time: {time}',
|
344 |
+
'data: {data}'
|
345 |
+
])
|
346 |
+
MB = 1024.0 * 1024.0
|
347 |
+
for obj in iterable:
|
348 |
+
data_time.update(time.time() - end)
|
349 |
+
yield obj
|
350 |
+
iter_time.update(time.time() - end)
|
351 |
+
if i % print_freq == 0 or i == len(iterable) - 1:
|
352 |
+
eta_seconds = iter_time.global_avg * (len(iterable) - i)
|
353 |
+
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
|
354 |
+
if torch.cuda.is_available():
|
355 |
+
print(log_msg.format(
|
356 |
+
i, len(iterable), eta=eta_string,
|
357 |
+
meters=str(self),
|
358 |
+
time=str(iter_time), data=str(data_time),
|
359 |
+
memory=torch.cuda.max_memory_allocated() / MB))
|
360 |
+
else:
|
361 |
+
print(log_msg.format(
|
362 |
+
i, len(iterable), eta=eta_string,
|
363 |
+
meters=str(self),
|
364 |
+
time=str(iter_time), data=str(data_time)))
|
365 |
+
i += 1
|
366 |
+
end = time.time()
|
367 |
+
total_time = time.time() - start_time
|
368 |
+
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
|
369 |
+
print('{} Total time: {} ({:.6f} s / it)'.format(
|
370 |
+
header, total_time_str, total_time / len(iterable)))
|
371 |
+
|
372 |
+
|
373 |
+
def get_sha():
|
374 |
+
cwd = os.path.dirname(os.path.abspath(__file__))
|
375 |
+
|
376 |
+
def _run(command):
|
377 |
+
return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
|
378 |
+
sha = 'N/A'
|
379 |
+
diff = "clean"
|
380 |
+
branch = 'N/A'
|
381 |
+
try:
|
382 |
+
sha = _run(['git', 'rev-parse', 'HEAD'])
|
383 |
+
subprocess.check_output(['git', 'diff'], cwd=cwd)
|
384 |
+
diff = _run(['git', 'diff-index', 'HEAD'])
|
385 |
+
diff = "has uncommited changes" if diff else "clean"
|
386 |
+
branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
|
387 |
+
except Exception:
|
388 |
+
pass
|
389 |
+
message = f"sha: {sha}, status: {diff}, branch: {branch}"
|
390 |
+
return message
|
391 |
+
|
392 |
+
|
393 |
+
def is_dist_avail_and_initialized():
|
394 |
+
if not dist.is_available():
|
395 |
+
return False
|
396 |
+
if not dist.is_initialized():
|
397 |
+
return False
|
398 |
+
return True
|
399 |
+
|
400 |
+
|
401 |
+
def get_world_size():
|
402 |
+
if not is_dist_avail_and_initialized():
|
403 |
+
return 1
|
404 |
+
return dist.get_world_size()
|
405 |
+
|
406 |
+
|
407 |
+
def get_rank():
|
408 |
+
if not is_dist_avail_and_initialized():
|
409 |
+
return 0
|
410 |
+
return dist.get_rank()
|
411 |
+
|
412 |
+
|
413 |
+
def is_main_process():
|
414 |
+
return get_rank() == 0
|
415 |
+
|
416 |
+
|
417 |
+
def save_on_master(*args, **kwargs):
|
418 |
+
if is_main_process():
|
419 |
+
torch.save(*args, **kwargs)
|
420 |
+
|
421 |
+
|
422 |
+
def setup_for_distributed(is_master):
|
423 |
+
"""
|
424 |
+
This function disables printing when not in master process
|
425 |
+
"""
|
426 |
+
import builtins as __builtin__
|
427 |
+
builtin_print = __builtin__.print
|
428 |
+
|
429 |
+
def print(*args, **kwargs):
|
430 |
+
force = kwargs.pop('force', False)
|
431 |
+
if is_master or force:
|
432 |
+
builtin_print(*args, **kwargs)
|
433 |
+
|
434 |
+
__builtin__.print = print
|
435 |
+
|
436 |
+
|
437 |
+
def init_distributed_mode(args):
|
438 |
+
# launched with torch.distributed.launch
|
439 |
+
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
|
440 |
+
args.rank = int(os.environ["RANK"])
|
441 |
+
args.world_size = int(os.environ['WORLD_SIZE'])
|
442 |
+
args.gpu = int(os.environ['LOCAL_RANK'])
|
443 |
+
# launched with submitit on a slurm cluster
|
444 |
+
elif 'SLURM_PROCID' in os.environ:
|
445 |
+
args.rank = int(os.environ['SLURM_PROCID'])
|
446 |
+
args.gpu = args.rank % torch.cuda.device_count()
|
447 |
+
# launched naively with `python main_dino.py`
|
448 |
+
# we manually add MASTER_ADDR and MASTER_PORT to env variables
|
449 |
+
elif torch.cuda.is_available():
|
450 |
+
print('Will run the code on one GPU.')
|
451 |
+
args.rank, args.gpu, args.world_size = 0, 0, 1
|
452 |
+
os.environ['MASTER_ADDR'] = '127.0.0.1'
|
453 |
+
os.environ['MASTER_PORT'] = '29500'
|
454 |
+
else:
|
455 |
+
print('Does not support training without GPU.')
|
456 |
+
sys.exit(1)
|
457 |
+
|
458 |
+
dist.init_process_group(
|
459 |
+
backend="nccl",
|
460 |
+
init_method=args.dist_url,
|
461 |
+
world_size=args.world_size,
|
462 |
+
rank=args.rank,
|
463 |
+
)
|
464 |
+
|
465 |
+
torch.cuda.set_device(args.gpu)
|
466 |
+
print('| distributed init (rank {}): {}'.format(
|
467 |
+
args.rank, args.dist_url), flush=True)
|
468 |
+
dist.barrier()
|
469 |
+
setup_for_distributed(args.rank == 0)
|
470 |
+
|
471 |
+
|
472 |
+
def accuracy(output, target, topk=(1,)):
|
473 |
+
"""Computes the accuracy over the k top predictions for the specified values of k"""
|
474 |
+
maxk = max(topk)
|
475 |
+
batch_size = target.size(0)
|
476 |
+
_, pred = output.topk(maxk, 1, True, True)
|
477 |
+
pred = pred.t()
|
478 |
+
correct = pred.eq(target.reshape(1, -1).expand_as(pred))
|
479 |
+
return [correct[:k].reshape(-1).float().sum(0) * 100. / batch_size for k in topk]
|
480 |
+
|
481 |
+
|
482 |
+
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
483 |
+
# Cut & paste from PyTorch official master until it's in a few official releases - RW
|
484 |
+
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
|
485 |
+
def norm_cdf(x):
|
486 |
+
# Computes standard normal cumulative distribution function
|
487 |
+
return (1. + math.erf(x / math.sqrt(2.))) / 2.
|
488 |
+
|
489 |
+
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
490 |
+
warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
491 |
+
"The distribution of values may be incorrect.",
|
492 |
+
stacklevel=2)
|
493 |
+
|
494 |
+
with torch.no_grad():
|
495 |
+
# Values are generated by using a truncated uniform distribution and
|
496 |
+
# then using the inverse CDF for the normal distribution.
|
497 |
+
# Get upper and lower cdf values
|
498 |
+
l = norm_cdf((a - mean) / std)
|
499 |
+
u = norm_cdf((b - mean) / std)
|
500 |
+
|
501 |
+
# Uniformly fill tensor with values from [l, u], then translate to
|
502 |
+
# [2l-1, 2u-1].
|
503 |
+
tensor.uniform_(2 * l - 1, 2 * u - 1)
|
504 |
+
|
505 |
+
# Use inverse cdf transform for normal distribution to get truncated
|
506 |
+
# standard normal
|
507 |
+
tensor.erfinv_()
|
508 |
+
|
509 |
+
# Transform to proper mean, std
|
510 |
+
tensor.mul_(std * math.sqrt(2.))
|
511 |
+
tensor.add_(mean)
|
512 |
+
|
513 |
+
# Clamp to ensure it's in the proper range
|
514 |
+
tensor.clamp_(min=a, max=b)
|
515 |
+
return tensor
|
516 |
+
|
517 |
+
|
518 |
+
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
|
519 |
+
# type: (Tensor, float, float, float, float) -> Tensor
|
520 |
+
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
521 |
+
|
522 |
+
|
523 |
+
class LARS(torch.optim.Optimizer):
|
524 |
+
"""
|
525 |
+
Almost copy-paste from https://github.com/facebookresearch/barlowtwins/blob/main/main.py
|
526 |
+
"""
|
527 |
+
def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, eta=0.001,
|
528 |
+
weight_decay_filter=None, lars_adaptation_filter=None):
|
529 |
+
defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum,
|
530 |
+
eta=eta, weight_decay_filter=weight_decay_filter,
|
531 |
+
lars_adaptation_filter=lars_adaptation_filter)
|
532 |
+
super().__init__(params, defaults)
|
533 |
+
|
534 |
+
@torch.no_grad()
|
535 |
+
def step(self):
|
536 |
+
for g in self.param_groups:
|
537 |
+
for p in g['params']:
|
538 |
+
dp = p.grad
|
539 |
+
|
540 |
+
if dp is None:
|
541 |
+
continue
|
542 |
+
|
543 |
+
if p.ndim != 1:
|
544 |
+
dp = dp.add(p, alpha=g['weight_decay'])
|
545 |
+
|
546 |
+
if p.ndim != 1:
|
547 |
+
param_norm = torch.norm(p)
|
548 |
+
update_norm = torch.norm(dp)
|
549 |
+
one = torch.ones_like(param_norm)
|
550 |
+
q = torch.where(param_norm > 0.,
|
551 |
+
torch.where(update_norm > 0,
|
552 |
+
(g['eta'] * param_norm / update_norm), one), one)
|
553 |
+
dp = dp.mul(q)
|
554 |
+
|
555 |
+
param_state = self.state[p]
|
556 |
+
if 'mu' not in param_state:
|
557 |
+
param_state['mu'] = torch.zeros_like(p)
|
558 |
+
mu = param_state['mu']
|
559 |
+
mu.mul_(g['momentum']).add_(dp)
|
560 |
+
|
561 |
+
p.add_(mu, alpha=-g['lr'])
|
562 |
+
|
563 |
+
|
564 |
+
class MultiCropWrapper(nn.Module):
|
565 |
+
"""
|
566 |
+
Perform forward pass separately on each resolution input.
|
567 |
+
The inputs corresponding to a single resolution are clubbed and single
|
568 |
+
forward is run on the same resolution inputs. Hence we do several
|
569 |
+
forward passes = number of different resolutions used. We then
|
570 |
+
concatenate all the output features and run the head forward on these
|
571 |
+
concatenated features.
|
572 |
+
"""
|
573 |
+
def __init__(self, backbone, head):
|
574 |
+
super(MultiCropWrapper, self).__init__()
|
575 |
+
# disable layers dedicated to ImageNet labels classification
|
576 |
+
backbone.fc, backbone.head = nn.Identity(), nn.Identity()
|
577 |
+
self.backbone = backbone
|
578 |
+
self.head = head
|
579 |
+
|
580 |
+
def forward(self, x):
|
581 |
+
# convert to list
|
582 |
+
if not isinstance(x, list):
|
583 |
+
x = [x]
|
584 |
+
idx_crops = torch.cumsum(torch.unique_consecutive(
|
585 |
+
torch.tensor([inp.shape[-1] for inp in x]),
|
586 |
+
return_counts=True,
|
587 |
+
)[1], 0)
|
588 |
+
start_idx = 0
|
589 |
+
for end_idx in idx_crops:
|
590 |
+
_out = self.backbone(torch.cat(x[start_idx: end_idx]))
|
591 |
+
if start_idx == 0:
|
592 |
+
output = _out
|
593 |
+
else:
|
594 |
+
output = torch.cat((output, _out))
|
595 |
+
start_idx = end_idx
|
596 |
+
# Run the head forward on the concatenated features.
|
597 |
+
return self.head(output)
|
598 |
+
|
599 |
+
|
600 |
+
def get_params_groups(model):
|
601 |
+
regularized = []
|
602 |
+
not_regularized = []
|
603 |
+
for name, param in model.named_parameters():
|
604 |
+
if not param.requires_grad:
|
605 |
+
continue
|
606 |
+
# we do not regularize biases nor Norm parameters
|
607 |
+
if name.endswith(".bias") or len(param.shape) == 1:
|
608 |
+
not_regularized.append(param)
|
609 |
+
else:
|
610 |
+
regularized.append(param)
|
611 |
+
return [{'params': regularized}, {'params': not_regularized, 'weight_decay': 0.}]
|
612 |
+
|
613 |
+
|
614 |
+
def has_batchnorms(model):
|
615 |
+
bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
|
616 |
+
for name, module in model.named_modules():
|
617 |
+
if isinstance(module, bn_types):
|
618 |
+
return True
|
619 |
+
return False
|
biomap/dino/vision_transformer.py
ADDED
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""
|
15 |
+
Mostly copy-paste from timm library.
|
16 |
+
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
17 |
+
"""
|
18 |
+
import math
|
19 |
+
from functools import partial
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import torch.nn as nn
|
23 |
+
from dino.utils import trunc_normal_
|
24 |
+
|
25 |
+
def drop_path(x, drop_prob: float = 0., training: bool = False):
|
26 |
+
if drop_prob == 0. or not training:
|
27 |
+
return x
|
28 |
+
keep_prob = 1 - drop_prob
|
29 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
30 |
+
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
31 |
+
random_tensor.floor_() # binarize
|
32 |
+
output = x.div(keep_prob) * random_tensor
|
33 |
+
return output
|
34 |
+
|
35 |
+
|
36 |
+
class DropPath(nn.Module):
|
37 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
38 |
+
"""
|
39 |
+
def __init__(self, drop_prob=None):
|
40 |
+
super(DropPath, self).__init__()
|
41 |
+
self.drop_prob = drop_prob
|
42 |
+
|
43 |
+
def forward(self, x):
|
44 |
+
return drop_path(x, self.drop_prob, self.training)
|
45 |
+
|
46 |
+
|
47 |
+
class Mlp(nn.Module):
|
48 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
49 |
+
super().__init__()
|
50 |
+
out_features = out_features or in_features
|
51 |
+
hidden_features = hidden_features or in_features
|
52 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
53 |
+
self.act = act_layer()
|
54 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
55 |
+
self.drop = nn.Dropout(drop)
|
56 |
+
|
57 |
+
def forward(self, x):
|
58 |
+
x = self.fc1(x)
|
59 |
+
x = self.act(x)
|
60 |
+
x = self.drop(x)
|
61 |
+
x = self.fc2(x)
|
62 |
+
x = self.drop(x)
|
63 |
+
return x
|
64 |
+
|
65 |
+
|
66 |
+
class Attention(nn.Module):
|
67 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
68 |
+
super().__init__()
|
69 |
+
self.num_heads = num_heads
|
70 |
+
head_dim = dim // num_heads
|
71 |
+
self.scale = qk_scale or head_dim ** -0.5
|
72 |
+
|
73 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
74 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
75 |
+
self.proj = nn.Linear(dim, dim)
|
76 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
77 |
+
|
78 |
+
def forward(self, x, return_qkv=False):
|
79 |
+
B, N, C = x.shape
|
80 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
81 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
82 |
+
|
83 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
84 |
+
attn = attn.softmax(dim=-1)
|
85 |
+
attn = self.attn_drop(attn)
|
86 |
+
|
87 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
88 |
+
x = self.proj(x)
|
89 |
+
x = self.proj_drop(x)
|
90 |
+
return x,attn, qkv
|
91 |
+
|
92 |
+
|
93 |
+
|
94 |
+
class Block(nn.Module):
|
95 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
96 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
97 |
+
super().__init__()
|
98 |
+
self.norm1 = norm_layer(dim)
|
99 |
+
self.attn = Attention(
|
100 |
+
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
101 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
102 |
+
self.norm2 = norm_layer(dim)
|
103 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
104 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
105 |
+
|
106 |
+
def forward(self, x, return_attention=False, return_qkv = False):
|
107 |
+
y, attn, qkv = self.attn(self.norm1(x))
|
108 |
+
if return_attention:
|
109 |
+
return attn
|
110 |
+
x = x + self.drop_path(y)
|
111 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
112 |
+
if return_qkv:
|
113 |
+
return x,attn, qkv
|
114 |
+
return x
|
115 |
+
|
116 |
+
|
117 |
+
class PatchEmbed(nn.Module):
|
118 |
+
""" Image to Patch Embedding
|
119 |
+
"""
|
120 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
|
121 |
+
super().__init__()
|
122 |
+
num_patches = (img_size // patch_size) * (img_size // patch_size)
|
123 |
+
self.img_size = img_size
|
124 |
+
self.patch_size = patch_size
|
125 |
+
self.num_patches = num_patches
|
126 |
+
|
127 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
128 |
+
|
129 |
+
def forward(self, x):
|
130 |
+
B, C, H, W = x.shape
|
131 |
+
x = self.proj(x).flatten(2).transpose(1, 2)
|
132 |
+
return x
|
133 |
+
|
134 |
+
|
135 |
+
class VisionTransformer(nn.Module):
|
136 |
+
""" Vision Transformer """
|
137 |
+
def __init__(self, img_size=[224], patch_size=16, in_chans=3, num_classes=0, embed_dim=768, depth=12,
|
138 |
+
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
|
139 |
+
drop_path_rate=0., norm_layer=nn.LayerNorm, **kwargs):
|
140 |
+
super().__init__()
|
141 |
+
|
142 |
+
self.num_features = self.embed_dim = embed_dim
|
143 |
+
|
144 |
+
self.patch_embed = PatchEmbed(
|
145 |
+
img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
146 |
+
num_patches = self.patch_embed.num_patches
|
147 |
+
|
148 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
149 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
150 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
151 |
+
|
152 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
153 |
+
self.blocks = nn.ModuleList([
|
154 |
+
Block(
|
155 |
+
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
156 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
|
157 |
+
for i in range(depth)])
|
158 |
+
self.norm = norm_layer(embed_dim)
|
159 |
+
|
160 |
+
# Classifier head
|
161 |
+
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
162 |
+
|
163 |
+
trunc_normal_(self.pos_embed, std=.02)
|
164 |
+
trunc_normal_(self.cls_token, std=.02)
|
165 |
+
self.apply(self._init_weights)
|
166 |
+
|
167 |
+
def _init_weights(self, m):
|
168 |
+
if isinstance(m, nn.Linear):
|
169 |
+
trunc_normal_(m.weight, std=.02)
|
170 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
171 |
+
nn.init.constant_(m.bias, 0)
|
172 |
+
elif isinstance(m, nn.LayerNorm):
|
173 |
+
nn.init.constant_(m.bias, 0)
|
174 |
+
nn.init.constant_(m.weight, 1.0)
|
175 |
+
|
176 |
+
def interpolate_pos_encoding(self, x, w, h):
|
177 |
+
npatch = x.shape[1] - 1
|
178 |
+
N = self.pos_embed.shape[1] - 1
|
179 |
+
if npatch == N and w == h:
|
180 |
+
return self.pos_embed
|
181 |
+
class_pos_embed = self.pos_embed[:, 0]
|
182 |
+
patch_pos_embed = self.pos_embed[:, 1:]
|
183 |
+
dim = x.shape[-1]
|
184 |
+
w0 = w // self.patch_embed.patch_size
|
185 |
+
h0 = h // self.patch_embed.patch_size
|
186 |
+
# we add a small number to avoid floating point error in the interpolation
|
187 |
+
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
188 |
+
w0, h0 = w0 + 0.1, h0 + 0.1
|
189 |
+
patch_pos_embed = nn.functional.interpolate(
|
190 |
+
patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
|
191 |
+
scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
|
192 |
+
mode='bicubic',
|
193 |
+
)
|
194 |
+
assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
|
195 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
196 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
197 |
+
|
198 |
+
def prepare_tokens(self, x):
|
199 |
+
B, nc, w, h = x.shape
|
200 |
+
x = self.patch_embed(x) # patch linear embedding
|
201 |
+
|
202 |
+
# add the [CLS] token to the embed patch tokens
|
203 |
+
cls_tokens = self.cls_token.expand(B, -1, -1)
|
204 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
205 |
+
|
206 |
+
# add positional encoding to each token
|
207 |
+
x = x + self.interpolate_pos_encoding(x, w, h)
|
208 |
+
|
209 |
+
return self.pos_drop(x)
|
210 |
+
|
211 |
+
def forward(self, x):
|
212 |
+
x = self.prepare_tokens(x)
|
213 |
+
for blk in self.blocks:
|
214 |
+
x = blk(x)
|
215 |
+
x = self.norm(x)
|
216 |
+
return x[:, 0]
|
217 |
+
|
218 |
+
def forward_feats(self, x):
|
219 |
+
x = self.prepare_tokens(x)
|
220 |
+
for blk in self.blocks:
|
221 |
+
x = blk(x)
|
222 |
+
x = self.norm(x)
|
223 |
+
return x
|
224 |
+
|
225 |
+
def get_intermediate_feat(self, x, n=1):
|
226 |
+
x = self.prepare_tokens(x)
|
227 |
+
# we return the output tokens from the `n` last blocks
|
228 |
+
feat = []
|
229 |
+
attns = []
|
230 |
+
qkvs = []
|
231 |
+
for i, blk in enumerate(self.blocks):
|
232 |
+
x,attn,qkv = blk(x, return_qkv=True)
|
233 |
+
if len(self.blocks) - i <= n:
|
234 |
+
feat.append(self.norm(x))
|
235 |
+
qkvs.append(qkv)
|
236 |
+
attns.append(attn)
|
237 |
+
return feat, attns, qkvs
|
238 |
+
|
239 |
+
def get_last_selfattention(self, x):
|
240 |
+
x = self.prepare_tokens(x)
|
241 |
+
for i, blk in enumerate(self.blocks):
|
242 |
+
if i < len(self.blocks) - 1:
|
243 |
+
x = blk(x)
|
244 |
+
else:
|
245 |
+
# return attention of the last block
|
246 |
+
return blk(x, return_attention=True)
|
247 |
+
|
248 |
+
def get_intermediate_layers(self, x, n=1):
|
249 |
+
x = self.prepare_tokens(x)
|
250 |
+
# we return the output tokens from the `n` last blocks
|
251 |
+
output = []
|
252 |
+
for i, blk in enumerate(self.blocks):
|
253 |
+
x = blk(x)
|
254 |
+
if len(self.blocks) - i <= n:
|
255 |
+
output.append(self.norm(x))
|
256 |
+
return output
|
257 |
+
|
258 |
+
|
259 |
+
def vit_tiny(patch_size=16, **kwargs):
|
260 |
+
model = VisionTransformer(
|
261 |
+
patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
|
262 |
+
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
263 |
+
return model
|
264 |
+
|
265 |
+
|
266 |
+
def vit_small(patch_size=16, **kwargs):
|
267 |
+
model = VisionTransformer(
|
268 |
+
patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
|
269 |
+
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
270 |
+
return model
|
271 |
+
|
272 |
+
|
273 |
+
def vit_base(patch_size=16, **kwargs):
|
274 |
+
model = VisionTransformer(
|
275 |
+
patch_size=patch_size, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
|
276 |
+
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
277 |
+
return model
|
278 |
+
|
279 |
+
|
280 |
+
class DINOHead(nn.Module):
|
281 |
+
def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256):
|
282 |
+
super().__init__()
|
283 |
+
nlayers = max(nlayers, 1)
|
284 |
+
if nlayers == 1:
|
285 |
+
self.mlp = nn.Linear(in_dim, bottleneck_dim)
|
286 |
+
else:
|
287 |
+
layers = [nn.Linear(in_dim, hidden_dim)]
|
288 |
+
if use_bn:
|
289 |
+
layers.append(nn.BatchNorm1d(hidden_dim))
|
290 |
+
layers.append(nn.GELU())
|
291 |
+
for _ in range(nlayers - 2):
|
292 |
+
layers.append(nn.Linear(hidden_dim, hidden_dim))
|
293 |
+
if use_bn:
|
294 |
+
layers.append(nn.BatchNorm1d(hidden_dim))
|
295 |
+
layers.append(nn.GELU())
|
296 |
+
layers.append(nn.Linear(hidden_dim, bottleneck_dim))
|
297 |
+
self.mlp = nn.Sequential(*layers)
|
298 |
+
self.apply(self._init_weights)
|
299 |
+
self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
|
300 |
+
self.last_layer.weight_g.data.fill_(1)
|
301 |
+
if norm_last_layer:
|
302 |
+
self.last_layer.weight_g.requires_grad = False
|
303 |
+
|
304 |
+
def _init_weights(self, m):
|
305 |
+
if isinstance(m, nn.Linear):
|
306 |
+
trunc_normal_(m.weight, std=.02)
|
307 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
308 |
+
nn.init.constant_(m.bias, 0)
|
309 |
+
|
310 |
+
def forward(self, x):
|
311 |
+
x = self.mlp(x)
|
312 |
+
x = nn.functional.normalize(x, dim=-1, p=2)
|
313 |
+
x = self.last_layer(x)
|
314 |
+
return x
|
biomap/helper.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.multiprocessing
|
2 |
+
import torchvision.transforms as T
|
3 |
+
import numpy as np
|
4 |
+
from utils import transform_to_pil, create_video
|
5 |
+
from utils_gee import extract_img, transform_ee_img
|
6 |
+
from dateutil.relativedelta import relativedelta
|
7 |
+
import datetime
|
8 |
+
from dateutil.relativedelta import relativedelta
|
9 |
+
import cv2
|
10 |
+
|
11 |
+
from joblib import Parallel, cpu_count, delayed
|
12 |
+
|
13 |
+
def get_image(location, d1, d2):
|
14 |
+
print(f"getting image for {d1} to {d2}")
|
15 |
+
try:
|
16 |
+
img = extract_img(location, d1, d2)
|
17 |
+
img_test = transform_ee_img(
|
18 |
+
img, max=0.3
|
19 |
+
)
|
20 |
+
return img_test
|
21 |
+
except Exception as err:
|
22 |
+
print(err)
|
23 |
+
return
|
24 |
+
|
25 |
+
|
26 |
+
def inference_on_location(model, latitude = 2.98, longitude = 48.81, start_date=2020, end_date=2022):
|
27 |
+
"""Performe an inference on the latitude and longitude between the start date and the end date
|
28 |
+
|
29 |
+
Args:
|
30 |
+
latitude (float): the latitude of the landscape
|
31 |
+
longitude (float): the longitude of the landscape
|
32 |
+
start_date (str): the start date for our inference
|
33 |
+
end_date (str): the end date for our inference
|
34 |
+
model (_type_, optional): _description_. Defaults to model.
|
35 |
+
|
36 |
+
Returns:
|
37 |
+
img, labeled_img,biodiv_score: the original landscape, the labeled landscape and the biodiversity score and the landscape
|
38 |
+
"""
|
39 |
+
assert end_date > start_date, "end date must be stricly higher than start date"
|
40 |
+
location = [float(latitude), float(longitude)]
|
41 |
+
|
42 |
+
# Extract img numpy from earth engine and transform it to PIL img
|
43 |
+
dates = [datetime.datetime(start_date, 1, 1, 0, 0, 0)]
|
44 |
+
while dates[-1] < datetime.datetime(end_date, 1, 1, 0, 0, 0):
|
45 |
+
dates.append(dates[-1] + relativedelta(months=1))
|
46 |
+
|
47 |
+
dates = [d.strftime("%Y-%m-%d") for d in dates]
|
48 |
+
|
49 |
+
all_image = Parallel(n_jobs=cpu_count(), prefer="threads")(delayed(get_image)(location, d1,d2) for d1, d2 in zip(dates[:-1],dates[1:]))
|
50 |
+
all_image = [image for image in all_image if image is not None]
|
51 |
+
|
52 |
+
|
53 |
+
|
54 |
+
|
55 |
+
# tensorize & normalize img
|
56 |
+
preprocess = T.Compose(
|
57 |
+
[
|
58 |
+
T.ToPILImage(),
|
59 |
+
T.Resize((320, 320)),
|
60 |
+
# T.CenterCrop(224),
|
61 |
+
T.ToTensor(),
|
62 |
+
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
63 |
+
]
|
64 |
+
)
|
65 |
+
|
66 |
+
# Preprocess opened img
|
67 |
+
x = torch.stack([preprocess(imag) for imag in all_image]).cpu()
|
68 |
+
|
69 |
+
# launch inference on cpu
|
70 |
+
# x = torch.unsqueeze(x, dim=0).cpu()
|
71 |
+
model = model.cpu()
|
72 |
+
|
73 |
+
with torch.no_grad():
|
74 |
+
feats, code = model.net(x)
|
75 |
+
linear_pred = model.linear_probe(x, code)
|
76 |
+
linear_pred = linear_pred.argmax(1)
|
77 |
+
outputs = [{
|
78 |
+
"img": torch.unsqueeze(img, dim=0).detach().cpu(),
|
79 |
+
"linear_preds": torch.unsqueeze(linear_pred, dim=0).detach().cpu(),
|
80 |
+
} for img, linear_pred in zip(x, linear_pred)]
|
81 |
+
all_img = []
|
82 |
+
all_label = []
|
83 |
+
all_labeled_img = []
|
84 |
+
for output in outputs:
|
85 |
+
img, label, labeled_img = transform_to_pil(output)
|
86 |
+
all_img.append(img)
|
87 |
+
all_label.append(label)
|
88 |
+
all_labeled_img.append(labeled_img)
|
89 |
+
|
90 |
+
all_labeled_img = [np.array(pil_image)[:, :, ::-1] for pil_image in all_labeled_img]
|
91 |
+
create_video(all_labeled_img, output_path='output/output.mp4')
|
92 |
+
|
93 |
+
# all_labeled_img = [np.array(pil_image)[:, :, ::-1] for pil_image in all_img]
|
94 |
+
# create_video(all_labeled_img, output_path='raw.mp4')
|
95 |
+
|
96 |
+
return 'output.mp4'
|
97 |
+
|
98 |
+
def inference_on_location_and_month(model, latitude = 2.98, longitude = 48.81, start_date = '2020-03-20'):
|
99 |
+
"""Performe an inference on the latitude and longitude between the start date and the end date
|
100 |
+
|
101 |
+
Args:
|
102 |
+
latitude (float): the latitude of the landscape
|
103 |
+
longitude (float): the longitude of the landscape
|
104 |
+
start_date (str): the start date for our inference
|
105 |
+
end_date (str): the end date for our inference
|
106 |
+
model (_type_, optional): _description_. Defaults to model.
|
107 |
+
|
108 |
+
Returns:
|
109 |
+
img, labeled_img,biodiv_score: the original landscape, the labeled landscape and the biodiversity score and the landscape
|
110 |
+
"""
|
111 |
+
location = [float(latitude), float(longitude)]
|
112 |
+
|
113 |
+
# Extract img numpy from earth engine and transform it to PIL img
|
114 |
+
end_date = datetime.datetime.strptime(start_date, "%Y-%m-%d") + relativedelta(months=1)
|
115 |
+
end_date = datetime.datetime.strftime(end_date, "%Y-%m-%d")
|
116 |
+
img = extract_img(location, start_date, end_date)
|
117 |
+
img_test = transform_ee_img(
|
118 |
+
img, max=0.3
|
119 |
+
) # max value is the value from numpy file that will be equal to 255
|
120 |
+
|
121 |
+
# tensorize & normalize img
|
122 |
+
preprocess = T.Compose(
|
123 |
+
[
|
124 |
+
T.ToPILImage(),
|
125 |
+
T.Resize((320, 320)),
|
126 |
+
# T.CenterCrop(224),
|
127 |
+
T.ToTensor(),
|
128 |
+
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
129 |
+
]
|
130 |
+
)
|
131 |
+
|
132 |
+
# Preprocess opened img
|
133 |
+
x = preprocess(img_test)
|
134 |
+
|
135 |
+
# launch inference on cpu
|
136 |
+
x = torch.unsqueeze(x, dim=0).cpu()
|
137 |
+
model = model.cpu()
|
138 |
+
|
139 |
+
with torch.no_grad():
|
140 |
+
feats, code = model.net(x)
|
141 |
+
linear_pred = model.linear_probe(x, code)
|
142 |
+
linear_pred = linear_pred.argmax(1)
|
143 |
+
output = {
|
144 |
+
"img": x[: model.cfg.n_images].detach().cpu(),
|
145 |
+
"linear_preds": linear_pred[: model.cfg.n_images].detach().cpu(),
|
146 |
+
}
|
147 |
+
nb_values = []
|
148 |
+
for i in range(7):
|
149 |
+
nb_values.append(np.count_nonzero(output['linear_preds'][0] == i+1))
|
150 |
+
scores_init = [2,3,4,3,1,4,0]
|
151 |
+
score = sum(x * y for x, y in zip(scores_init, nb_values)) / sum(nb_values) / max(scores_init)
|
152 |
+
|
153 |
+
img, label, labeled_img = transform_to_pil(output)
|
154 |
+
return img, labeled_img,score
|
155 |
+
|
156 |
+
|
157 |
+
if __name__ == "__main__":
|
158 |
+
import logging
|
159 |
+
import hydra
|
160 |
+
|
161 |
+
|
162 |
+
from model import LitUnsupervisedSegmenter
|
163 |
+
logging.basicConfig(filename='example.log', encoding='utf-8', level=logging.INFO)
|
164 |
+
# Initialize hydra with configs
|
165 |
+
hydra.initialize(config_path="configs", job_name="corine")
|
166 |
+
cfg = hydra.compose(config_name="my_train_config.yml")
|
167 |
+
logging.info(f"config : {cfg}")
|
168 |
+
# Load the model
|
169 |
+
|
170 |
+
nbclasses = cfg.dir_dataset_n_classes
|
171 |
+
model = LitUnsupervisedSegmenter(nbclasses, cfg)
|
172 |
+
logging.info(f"Model Initialiazed")
|
173 |
+
|
174 |
+
model_path = "checkpoint/model/model.pt"
|
175 |
+
saved_state_dict = torch.load(model_path, map_location=torch.device("cpu"))
|
176 |
+
logging.info(f"Model weights Loaded")
|
177 |
+
model.load_state_dict(saved_state_dict)
|
178 |
+
logging.info(f"Model Loaded")
|
179 |
+
inference_on_location(model)
|
biomap/inference.py
ADDED
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.multiprocessing
|
2 |
+
import torchvision.transforms as T
|
3 |
+
from utils import transform_to_pil
|
4 |
+
|
5 |
+
def inference(image, model):
|
6 |
+
# tensorize & normalize img
|
7 |
+
preprocess = T.Compose(
|
8 |
+
[
|
9 |
+
T.ToPILImage(),
|
10 |
+
T.Resize((320, 320)),
|
11 |
+
# T.CenterCrop(224),
|
12 |
+
T.ToTensor(),
|
13 |
+
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
14 |
+
]
|
15 |
+
)
|
16 |
+
|
17 |
+
# Preprocess opened img
|
18 |
+
x = preprocess(image)
|
19 |
+
|
20 |
+
# launch inference on cpu
|
21 |
+
x = torch.unsqueeze(x, dim=0).cpu()
|
22 |
+
model = model.cpu()
|
23 |
+
|
24 |
+
with torch.no_grad():
|
25 |
+
feats, code = model.net(x)
|
26 |
+
linear_pred = model.linear_probe(x, code)
|
27 |
+
linear_pred = linear_pred.argmax(1)
|
28 |
+
output = {
|
29 |
+
"img": x[: model.cfg.n_images].detach().cpu(),
|
30 |
+
"linear_preds": linear_pred[: model.cfg.n_images].detach().cpu(),
|
31 |
+
}
|
32 |
+
|
33 |
+
img, label, labeled_img = transform_to_pil(output)
|
34 |
+
return img, labeled_img, label
|
35 |
+
|
36 |
+
|
37 |
+
if __name__ == "__main__":
|
38 |
+
import hydra
|
39 |
+
from model import LitUnsupervisedSegmenter
|
40 |
+
from utils_gee import extract_img, transform_ee_img
|
41 |
+
latitude = 2.98
|
42 |
+
longitude = 48.81
|
43 |
+
start_date = '2020-03-20'
|
44 |
+
end_date = '2020-04-20'
|
45 |
+
|
46 |
+
location = [float(latitude), float(longitude)]
|
47 |
+
# Extract img numpy from earth engine and transform it to PIL img
|
48 |
+
img = extract_img(location, start_date, end_date)
|
49 |
+
image = transform_ee_img(
|
50 |
+
img, max=0.3
|
51 |
+
) # max value is the value from numpy file that will be equal to 255
|
52 |
+
print("image loaded")
|
53 |
+
# Initialize hydra with configs
|
54 |
+
hydra.initialize(config_path="configs", job_name="corine")
|
55 |
+
cfg = hydra.compose(config_name="my_train_config.yml")
|
56 |
+
|
57 |
+
# Load the model
|
58 |
+
model_path = "checkpoint/model/model.pt"
|
59 |
+
saved_state_dict = torch.load(model_path, map_location=torch.device("cpu"))
|
60 |
+
|
61 |
+
nbclasses = cfg.dir_dataset_n_classes
|
62 |
+
|
63 |
+
model = LitUnsupervisedSegmenter(nbclasses, cfg)
|
64 |
+
print("model initialized")
|
65 |
+
model.load_state_dict(saved_state_dict)
|
66 |
+
print("model loaded")
|
67 |
+
# img.save("output/image.png")
|
68 |
+
img, labeled_img, label = inference(image, model)
|
69 |
+
img.save("output/img.png")
|
70 |
+
label.save("output/label.png")
|
71 |
+
labeled_img.save("output/labeled_img.png")
|
72 |
+
|
73 |
+
|
74 |
+
|
75 |
+
|
76 |
+
# def get_list_date(start_date, end_date):
|
77 |
+
# """Get all the date between the start date and the end date
|
78 |
+
|
79 |
+
# Args:
|
80 |
+
# start_date (str): start date at the format '%Y-%m-%d'
|
81 |
+
# end_date (str): end date at the format '%Y-%m-%d'
|
82 |
+
|
83 |
+
# Returns:
|
84 |
+
# list[str]: all the date between the start date and the end date
|
85 |
+
# """
|
86 |
+
# start_date = datetime.datetime.strptime(start_date, "%Y-%m-%d").date()
|
87 |
+
# end_date = datetime.datetime.strptime(end_date, "%Y-%m-%d").date()
|
88 |
+
# list_date = [start_date]
|
89 |
+
# date = start_date
|
90 |
+
# while date < end_date:
|
91 |
+
# date = date + datetime.timedelta(days=1)
|
92 |
+
# list_date.append(date)
|
93 |
+
# list_date.append(end_date)
|
94 |
+
# list_date2 = [x.strftime("%Y-%m-%d") for x in list_date]
|
95 |
+
# return list_date2
|
96 |
+
|
97 |
+
|
98 |
+
# def get_length_interval(start_date, end_date):
|
99 |
+
# """Return how many days there is between the start date and the end date
|
100 |
+
|
101 |
+
# Args:
|
102 |
+
# start_date (str): start date at the format '%Y-%m-%d'
|
103 |
+
# end_date (str): end date at the format '%Y-%m-%d'
|
104 |
+
|
105 |
+
# Returns:
|
106 |
+
# int : number of days between start date and the end date
|
107 |
+
# """
|
108 |
+
# try:
|
109 |
+
# return len(get_list_date(start_date, end_date))
|
110 |
+
# except ValueError:
|
111 |
+
# return 0
|
112 |
+
|
113 |
+
|
114 |
+
# def infer_unique_date(latitude, longitude, date, model=model):
|
115 |
+
# """Perform an inference on a latitude and a longitude at a specific date
|
116 |
+
|
117 |
+
# Args:
|
118 |
+
# latitude (float): the latitude of the landscape
|
119 |
+
# longitude (float): the longitude of the landscape
|
120 |
+
# date (str): date for the inference at the format '%Y-%m-%d'
|
121 |
+
# model (_type_, optional): _description_. Defaults to model.
|
122 |
+
|
123 |
+
# Returns:
|
124 |
+
# img, labeled_img,biodiv_score: the original landscape, the labeled landscape and the biodiversity score and the landscape
|
125 |
+
# """
|
126 |
+
# start_date = date
|
127 |
+
# end_date = date
|
128 |
+
# location = [float(latitude), float(longitude)]
|
129 |
+
# # Extract img numpy from earth engine and transform it to PIL img
|
130 |
+
# img = extract_img(location, start_date, end_date)
|
131 |
+
# img_test = transform_ee_img(
|
132 |
+
# img, max=0.3
|
133 |
+
# ) # max value is the value from numpy file that will be equal to 255
|
134 |
+
|
135 |
+
# # tensorize & normalize img
|
136 |
+
# preprocess = T.Compose(
|
137 |
+
# [
|
138 |
+
# T.ToPILImage(),
|
139 |
+
# T.Resize((320, 320)),
|
140 |
+
# # T.CenterCrop(224),
|
141 |
+
# T.ToTensor(),
|
142 |
+
# T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
143 |
+
# ]
|
144 |
+
# )
|
145 |
+
|
146 |
+
# # Preprocess opened img
|
147 |
+
# x = preprocess(img_test)
|
148 |
+
|
149 |
+
# # launch inference on cpu
|
150 |
+
# x = torch.unsqueeze(x, dim=0).cpu()
|
151 |
+
# model = model.cpu()
|
152 |
+
|
153 |
+
# with torch.no_grad():
|
154 |
+
# feats, code = model.net(x)
|
155 |
+
# linear_pred = model.linear_probe(x, code)
|
156 |
+
# linear_pred = linear_pred.argmax(1)
|
157 |
+
# output = {
|
158 |
+
# "img": x[: model.cfg.n_images].detach().cpu(),
|
159 |
+
# "linear_preds": linear_pred[: model.cfg.n_images].detach().cpu(),
|
160 |
+
# }
|
161 |
+
|
162 |
+
# img, label, labeled_img = transform_to_pil(output)
|
163 |
+
# biodiv_score = compute_biodiv_score(labeled_img)
|
164 |
+
# return img, labeled_img, biodiv_score
|
165 |
+
|
166 |
+
|
167 |
+
# def get_img_array(start_date, end_date, latitude, longitude, model=model):
|
168 |
+
# list_date = get_list_date(start_date, end_date)
|
169 |
+
# list_img = []
|
170 |
+
# for date in list_date:
|
171 |
+
# list_img.append(img)
|
172 |
+
# return list_img
|
173 |
+
|
174 |
+
|
175 |
+
# def variable_outputs(start_date, end_date, latitude, longitude, day, model=model):
|
176 |
+
# """Perform an inference on the day number day starting from the start at the latitude and longitude selected
|
177 |
+
|
178 |
+
# Args:
|
179 |
+
# latitude (float): the latitude of the landscape
|
180 |
+
# longitude (float): the longitude of the landscape
|
181 |
+
# start_date (str): the start date for our inference
|
182 |
+
# end_date (str): the end date for our inference
|
183 |
+
# model (_type_, optional): _description_. Defaults to model.
|
184 |
+
|
185 |
+
# Returns:
|
186 |
+
# img,labeled_img,biodiv_score: the original landscape, the labeled landscape and the biodiversity score and the landscape at the selected, longitude, latitude and date
|
187 |
+
# """
|
188 |
+
# list_date = get_list_date(start_date, end_date)
|
189 |
+
# k = int(day)
|
190 |
+
# date = list_date[k]
|
191 |
+
# img, labeled_img, biodiv_score = infer_unique_date(
|
192 |
+
# latitude, longitude, date, model=model
|
193 |
+
# )
|
194 |
+
# return img, labeled_img, biodiv_score
|
195 |
+
|
196 |
+
|
197 |
+
# def variable_outputs2(
|
198 |
+
# start_date, end_date, latitude, longitude, day_number, model=model
|
199 |
+
# ):
|
200 |
+
# """Perform an inference on the day number day starting from the start at the latitude and longitude selected
|
201 |
+
|
202 |
+
# Args:
|
203 |
+
# latitude (float): the latitude of the landscape
|
204 |
+
# longitude (float): the longitude of the landscape
|
205 |
+
# start_date (str): the start date for our inference
|
206 |
+
# end_date (str): the end date for our inference
|
207 |
+
# model (_type_, optional): _description_. Defaults to model.
|
208 |
+
|
209 |
+
# Returns:
|
210 |
+
# list[img,labeled_img,biodiv_score]: the original landscape, the labeled landscape and the biodiversity score and the landscape at the selected, longitude, latitude and date
|
211 |
+
# """
|
212 |
+
# list_date = get_list_date(start_date, end_date)
|
213 |
+
# k = int(day_number)
|
214 |
+
# date = list_date[k]
|
215 |
+
# img, labeled_img, biodiv_score = infer_unique_date(
|
216 |
+
# latitude, longitude, date, model=model
|
217 |
+
# )
|
218 |
+
# return [img, labeled_img, biodiv_score]
|
219 |
+
|
220 |
+
|
221 |
+
# def gif_maker(img_array):
|
222 |
+
# output_file = "test2.mkv"
|
223 |
+
# image_test = img_array[0]
|
224 |
+
# size = (320, 320)
|
225 |
+
# print(size)
|
226 |
+
# out = cv2.VideoWriter(
|
227 |
+
# output_file, cv2.VideoWriter_fourcc(*"avc1"), 15, frameSize=size
|
228 |
+
# )
|
229 |
+
# for i in range(len(img_array)):
|
230 |
+
# image = img_array[i]
|
231 |
+
# pix = np.array(image.getdata())
|
232 |
+
# out.write(pix)
|
233 |
+
# out.release()
|
234 |
+
# return output_file
|
235 |
+
|
236 |
+
|
237 |
+
# def infer_multiple_date(start_date, end_date, latitude, longitude, model=model):
|
238 |
+
# """Perform an inference on all the dates between the start date and the end date at the latitude and longitude
|
239 |
+
|
240 |
+
# Args:
|
241 |
+
# latitude (float): the latitude of the landscape
|
242 |
+
# longitude (float): the longitude of the landscape
|
243 |
+
# start_date (str): the start date for our inference
|
244 |
+
# end_date (str): the end date for our inference
|
245 |
+
# model (_type_, optional): _description_. Defaults to model.
|
246 |
+
|
247 |
+
# Returns:
|
248 |
+
# list_img,list_labeled_img,list_biodiv_score: list of the original landscape, the labeled landscape and the biodiversity score and the landscape
|
249 |
+
# """
|
250 |
+
# list_date = get_list_date(start_date, end_date)
|
251 |
+
# list_img = []
|
252 |
+
# list_labeled_img = []
|
253 |
+
# list_biodiv_score = []
|
254 |
+
# for date in list_date:
|
255 |
+
# img, labeled_img, biodiv_score = infer_unique_date(
|
256 |
+
# latitude, longitude, date, model=model
|
257 |
+
# )
|
258 |
+
# list_img.append(img)
|
259 |
+
# list_labeled_img.append(labeled_img)
|
260 |
+
# list_biodiv_score.append(biodiv_score)
|
261 |
+
# return gif_maker(list_img), gif_maker(list_labeled_img), list_biodiv_score[0]
|
biomap/label.png
ADDED
![]() |
biomap/model.py
ADDED
@@ -0,0 +1,453 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from utils import *
|
2 |
+
from modules import *
|
3 |
+
from data import *
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import pytorch_lightning as pl
|
6 |
+
import torch.multiprocessing
|
7 |
+
import seaborn as sns
|
8 |
+
import unet
|
9 |
+
|
10 |
+
class LitUnsupervisedSegmenter(pl.LightningModule):
|
11 |
+
def __init__(self, n_classes, cfg):
|
12 |
+
super().__init__()
|
13 |
+
self.cfg = cfg
|
14 |
+
self.n_classes = n_classes
|
15 |
+
|
16 |
+
if not cfg.continuous:
|
17 |
+
dim = n_classes
|
18 |
+
else:
|
19 |
+
dim = cfg.dim
|
20 |
+
|
21 |
+
data_dir = join(cfg.output_root, "data")
|
22 |
+
if cfg.arch == "feature-pyramid":
|
23 |
+
cut_model = load_model(cfg.model_type, data_dir).cuda()
|
24 |
+
self.net = FeaturePyramidNet(
|
25 |
+
cfg.granularity, cut_model, dim, cfg.continuous
|
26 |
+
)
|
27 |
+
elif cfg.arch == "dino":
|
28 |
+
self.net = DinoFeaturizer(dim, cfg)
|
29 |
+
else:
|
30 |
+
raise ValueError("Unknown arch {}".format(cfg.arch))
|
31 |
+
|
32 |
+
self.train_cluster_probe = ClusterLookup(dim, n_classes)
|
33 |
+
|
34 |
+
self.cluster_probe = ClusterLookup(dim, n_classes + cfg.extra_clusters)
|
35 |
+
# self.linear_probe = nn.Conv2d(dim, n_classes, (1, 1))
|
36 |
+
# self.linear_probe = nn.Sequential(OrderedDict([
|
37 |
+
# ('conv1', nn.Conv2d(dim, 2*n_classes, (7, 7), padding='same')),
|
38 |
+
# ('relu1', nn.ReLU()),
|
39 |
+
# ('conv2', nn.Conv2d(2*n_classes, n_classes, (3, 3), padding='same'))
|
40 |
+
# ]))
|
41 |
+
self.linear_probe = unet.AuxUNet(
|
42 |
+
enc_chs=(3, 32, 64, 128, 256),
|
43 |
+
dec_chs=(256, 128, 64, 32),
|
44 |
+
aux_ch=70,
|
45 |
+
num_class=n_classes,
|
46 |
+
)
|
47 |
+
|
48 |
+
self.decoder = nn.Conv2d(dim, self.net.n_feats, (1, 1))
|
49 |
+
|
50 |
+
self.cluster_metrics = UnsupervisedMetrics(
|
51 |
+
"test/cluster/", n_classes, cfg.extra_clusters, True
|
52 |
+
)
|
53 |
+
self.linear_metrics = UnsupervisedMetrics("test/linear/", n_classes, 0, False)
|
54 |
+
|
55 |
+
self.test_cluster_metrics = UnsupervisedMetrics(
|
56 |
+
"final/cluster/", n_classes, cfg.extra_clusters, True
|
57 |
+
)
|
58 |
+
self.test_linear_metrics = UnsupervisedMetrics(
|
59 |
+
"final/linear/", n_classes, 0, False
|
60 |
+
)
|
61 |
+
|
62 |
+
self.linear_probe_loss_fn = torch.nn.CrossEntropyLoss()
|
63 |
+
self.crf_loss_fn = ContrastiveCRFLoss(
|
64 |
+
cfg.crf_samples, cfg.alpha, cfg.beta, cfg.gamma, cfg.w1, cfg.w2, cfg.shift
|
65 |
+
)
|
66 |
+
|
67 |
+
self.contrastive_corr_loss_fn = ContrastiveCorrelationLoss(cfg)
|
68 |
+
for p in self.contrastive_corr_loss_fn.parameters():
|
69 |
+
p.requires_grad = False
|
70 |
+
|
71 |
+
self.automatic_optimization = False
|
72 |
+
|
73 |
+
if self.cfg.dataset_name.startswith("cityscapes"):
|
74 |
+
self.label_cmap = create_cityscapes_colormap()
|
75 |
+
else:
|
76 |
+
self.label_cmap = create_pascal_label_colormap()
|
77 |
+
|
78 |
+
self.val_steps = 0
|
79 |
+
self.save_hyperparameters()
|
80 |
+
|
81 |
+
def forward(self, x):
|
82 |
+
# in lightning, forward defines the prediction/inference actions
|
83 |
+
return self.net(x)[1]
|
84 |
+
|
85 |
+
def training_step(self, batch, batch_idx):
|
86 |
+
# training_step defined the train loop.
|
87 |
+
# It is independent of forward
|
88 |
+
net_optim, linear_probe_optim, cluster_probe_optim = self.optimizers()
|
89 |
+
|
90 |
+
net_optim.zero_grad()
|
91 |
+
linear_probe_optim.zero_grad()
|
92 |
+
cluster_probe_optim.zero_grad()
|
93 |
+
|
94 |
+
with torch.no_grad():
|
95 |
+
ind = batch["ind"]
|
96 |
+
img = batch["img"]
|
97 |
+
img_aug = batch["img_aug"]
|
98 |
+
coord_aug = batch["coord_aug"]
|
99 |
+
img_pos = batch["img_pos"]
|
100 |
+
label = batch["label"]
|
101 |
+
label_pos = batch["label_pos"]
|
102 |
+
|
103 |
+
feats, code = self.net(img)
|
104 |
+
if self.cfg.correspondence_weight > 0:
|
105 |
+
feats_pos, code_pos = self.net(img_pos)
|
106 |
+
log_args = dict(sync_dist=False, rank_zero_only=True)
|
107 |
+
|
108 |
+
if self.cfg.use_true_labels:
|
109 |
+
signal = one_hot_feats(label + 1, self.n_classes + 1)
|
110 |
+
signal_pos = one_hot_feats(label_pos + 1, self.n_classes + 1)
|
111 |
+
else:
|
112 |
+
signal = feats
|
113 |
+
signal_pos = feats_pos
|
114 |
+
|
115 |
+
loss = 0
|
116 |
+
|
117 |
+
should_log_hist = (
|
118 |
+
(self.cfg.hist_freq is not None)
|
119 |
+
and (self.global_step % self.cfg.hist_freq == 0)
|
120 |
+
and (self.global_step > 0)
|
121 |
+
)
|
122 |
+
if self.cfg.use_salience:
|
123 |
+
salience = batch["mask"].to(torch.float32).squeeze(1)
|
124 |
+
salience_pos = batch["mask_pos"].to(torch.float32).squeeze(1)
|
125 |
+
else:
|
126 |
+
salience = None
|
127 |
+
salience_pos = None
|
128 |
+
|
129 |
+
if self.cfg.correspondence_weight > 0:
|
130 |
+
(
|
131 |
+
pos_intra_loss,
|
132 |
+
pos_intra_cd,
|
133 |
+
pos_inter_loss,
|
134 |
+
pos_inter_cd,
|
135 |
+
neg_inter_loss,
|
136 |
+
neg_inter_cd,
|
137 |
+
) = self.contrastive_corr_loss_fn(
|
138 |
+
signal,
|
139 |
+
signal_pos,
|
140 |
+
salience,
|
141 |
+
salience_pos,
|
142 |
+
code,
|
143 |
+
code_pos,
|
144 |
+
)
|
145 |
+
|
146 |
+
if should_log_hist:
|
147 |
+
self.logger.experiment.add_histogram(
|
148 |
+
"intra_cd", pos_intra_cd, self.global_step
|
149 |
+
)
|
150 |
+
self.logger.experiment.add_histogram(
|
151 |
+
"inter_cd", pos_inter_cd, self.global_step
|
152 |
+
)
|
153 |
+
self.logger.experiment.add_histogram(
|
154 |
+
"neg_cd", neg_inter_cd, self.global_step
|
155 |
+
)
|
156 |
+
neg_inter_loss = neg_inter_loss.mean()
|
157 |
+
pos_intra_loss = pos_intra_loss.mean()
|
158 |
+
pos_inter_loss = pos_inter_loss.mean()
|
159 |
+
self.log("loss/pos_intra", pos_intra_loss, **log_args)
|
160 |
+
self.log("loss/pos_inter", pos_inter_loss, **log_args)
|
161 |
+
self.log("loss/neg_inter", neg_inter_loss, **log_args)
|
162 |
+
self.log("cd/pos_intra", pos_intra_cd.mean(), **log_args)
|
163 |
+
self.log("cd/pos_inter", pos_inter_cd.mean(), **log_args)
|
164 |
+
self.log("cd/neg_inter", neg_inter_cd.mean(), **log_args)
|
165 |
+
|
166 |
+
loss += (
|
167 |
+
self.cfg.pos_inter_weight * pos_inter_loss
|
168 |
+
+ self.cfg.pos_intra_weight * pos_intra_loss
|
169 |
+
+ self.cfg.neg_inter_weight * neg_inter_loss
|
170 |
+
) * self.cfg.correspondence_weight
|
171 |
+
|
172 |
+
if self.cfg.rec_weight > 0:
|
173 |
+
rec_feats = self.decoder(code)
|
174 |
+
rec_loss = -(norm(rec_feats) * norm(feats)).sum(1).mean()
|
175 |
+
self.log("loss/rec", rec_loss, **log_args)
|
176 |
+
loss += self.cfg.rec_weight * rec_loss
|
177 |
+
|
178 |
+
if self.cfg.aug_alignment_weight > 0:
|
179 |
+
orig_feats_aug, orig_code_aug = self.net(img_aug)
|
180 |
+
downsampled_coord_aug = resize(
|
181 |
+
coord_aug.permute(0, 3, 1, 2), orig_code_aug.shape[2]
|
182 |
+
).permute(0, 2, 3, 1)
|
183 |
+
aug_alignment = -torch.einsum(
|
184 |
+
"bkhw,bkhw->bhw",
|
185 |
+
norm(sample(code, downsampled_coord_aug)),
|
186 |
+
norm(orig_code_aug),
|
187 |
+
).mean()
|
188 |
+
self.log("loss/aug_alignment", aug_alignment, **log_args)
|
189 |
+
loss += self.cfg.aug_alignment_weight * aug_alignment
|
190 |
+
|
191 |
+
if self.cfg.crf_weight > 0:
|
192 |
+
crf = self.crf_loss_fn(resize(img, 56), norm(resize(code, 56))).mean()
|
193 |
+
self.log("loss/crf", crf, **log_args)
|
194 |
+
loss += self.cfg.crf_weight * crf
|
195 |
+
|
196 |
+
flat_label = label.reshape(-1)
|
197 |
+
mask = (flat_label >= 0) & (flat_label < self.n_classes)
|
198 |
+
|
199 |
+
detached_code = torch.clone(code.detach())
|
200 |
+
|
201 |
+
# pdb.set_trace()
|
202 |
+
|
203 |
+
linear_logits = self.linear_probe(img, detached_code)
|
204 |
+
linear_logits = F.interpolate(
|
205 |
+
linear_logits, label.shape[-2:], mode="bilinear", align_corners=False
|
206 |
+
)
|
207 |
+
linear_logits = linear_logits.permute(0, 2, 3, 1).reshape(-1, self.n_classes)
|
208 |
+
linear_loss = self.linear_probe_loss_fn(
|
209 |
+
linear_logits[mask], flat_label[mask]
|
210 |
+
).mean()
|
211 |
+
loss += linear_loss
|
212 |
+
self.log("loss/linear", linear_loss, **log_args)
|
213 |
+
|
214 |
+
cluster_loss, cluster_probs = self.cluster_probe(detached_code, None)
|
215 |
+
loss += cluster_loss
|
216 |
+
self.log("loss/cluster", cluster_loss, **log_args)
|
217 |
+
self.log("loss/total", loss, **log_args)
|
218 |
+
|
219 |
+
self.manual_backward(loss)
|
220 |
+
net_optim.step()
|
221 |
+
cluster_probe_optim.step()
|
222 |
+
linear_probe_optim.step()
|
223 |
+
|
224 |
+
if (
|
225 |
+
self.cfg.reset_probe_steps is not None
|
226 |
+
and self.global_step == self.cfg.reset_probe_steps
|
227 |
+
):
|
228 |
+
print("RESETTING PROBES")
|
229 |
+
self.linear_probe.reset_parameters()
|
230 |
+
self.cluster_probe.reset_parameters()
|
231 |
+
self.trainer.optimizers[1] = torch.optim.Adam(
|
232 |
+
list(self.linear_probe.parameters()), lr=5e-3
|
233 |
+
)
|
234 |
+
self.trainer.optimizers[2] = torch.optim.Adam(
|
235 |
+
list(self.cluster_probe.parameters()), lr=5e-3
|
236 |
+
)
|
237 |
+
|
238 |
+
if self.global_step % 2000 == 0 and self.global_step > 0:
|
239 |
+
print("RESETTING TFEVENT FILE")
|
240 |
+
# Make a new tfevent file
|
241 |
+
self.logger.experiment.close()
|
242 |
+
self.logger.experiment._get_file_writer()
|
243 |
+
|
244 |
+
return loss
|
245 |
+
|
246 |
+
def on_train_start(self):
|
247 |
+
tb_metrics = {**self.linear_metrics.compute(), **self.cluster_metrics.compute()}
|
248 |
+
self.logger.log_hyperparams(self.cfg, tb_metrics)
|
249 |
+
|
250 |
+
def validation_step(self, batch, batch_idx):
|
251 |
+
img = batch["img"]
|
252 |
+
label = batch["label"]
|
253 |
+
self.net.eval()
|
254 |
+
|
255 |
+
with torch.no_grad():
|
256 |
+
feats, code = self.net(img)
|
257 |
+
|
258 |
+
# code = F.interpolate(code, label.shape[-2:], mode='bilinear', align_corners=False)
|
259 |
+
# linear_preds = self.linear_probe(code)
|
260 |
+
linear_preds = self.linear_probe(img, code)
|
261 |
+
linear_preds = linear_preds.argmax(1)
|
262 |
+
self.linear_metrics.update(linear_preds, label)
|
263 |
+
|
264 |
+
code = F.interpolate(
|
265 |
+
code, label.shape[-2:], mode="bilinear", align_corners=False
|
266 |
+
)
|
267 |
+
cluster_loss, cluster_preds = self.cluster_probe(code, None)
|
268 |
+
cluster_preds = cluster_preds.argmax(1)
|
269 |
+
self.cluster_metrics.update(cluster_preds, label)
|
270 |
+
|
271 |
+
return {
|
272 |
+
"img": img[: self.cfg.n_images].detach().cpu(),
|
273 |
+
"linear_preds": linear_preds[: self.cfg.n_images].detach().cpu(),
|
274 |
+
"cluster_preds": cluster_preds[: self.cfg.n_images].detach().cpu(),
|
275 |
+
"label": label[: self.cfg.n_images].detach().cpu(),
|
276 |
+
}
|
277 |
+
|
278 |
+
def validation_epoch_end(self, outputs) -> None:
|
279 |
+
super().validation_epoch_end(outputs)
|
280 |
+
with torch.no_grad():
|
281 |
+
tb_metrics = {
|
282 |
+
**self.linear_metrics.compute(),
|
283 |
+
**self.cluster_metrics.compute(),
|
284 |
+
}
|
285 |
+
|
286 |
+
if self.trainer.is_global_zero and not self.cfg.submitting_to_aml:
|
287 |
+
# output_num = 0
|
288 |
+
output_num = random.randint(0, len(outputs) - 1)
|
289 |
+
output = {k: v.detach().cpu() for k, v in outputs[output_num].items()}
|
290 |
+
|
291 |
+
# pdb.set_trace()
|
292 |
+
alpha = 0.4
|
293 |
+
n_rows = 6
|
294 |
+
fig, ax = plt.subplots(
|
295 |
+
n_rows,
|
296 |
+
self.cfg.n_images,
|
297 |
+
figsize=(self.cfg.n_images * 3, n_rows * 3),
|
298 |
+
)
|
299 |
+
for i in range(self.cfg.n_images):
|
300 |
+
try:
|
301 |
+
rbg_img = prep_for_plot(output["img"][i])
|
302 |
+
true_label = output["label"].squeeze()[i]
|
303 |
+
true_label[true_label == -1] = 7
|
304 |
+
except:
|
305 |
+
continue
|
306 |
+
# ax[0, i].imshow(prep_for_plot(output["img"][i]))
|
307 |
+
# ax[1, i].imshow(self.label_cmap[output["label"].squeeze()[i]])
|
308 |
+
# ax[2, i].imshow(self.label_cmap[output["linear_preds"][i]])
|
309 |
+
# ax[3, i].imshow(self.label_cmap[self.cluster_metrics.map_clusters(output["cluster_preds"][i])])
|
310 |
+
ax[0, i].imshow(rbg_img)
|
311 |
+
|
312 |
+
ax[1, i].imshow(rbg_img)
|
313 |
+
ax[1, i].imshow(true_label, alpha=alpha, cmap=cmap, norm=norm)
|
314 |
+
|
315 |
+
ax[2, i].imshow(rbg_img)
|
316 |
+
pred_label = output["linear_preds"][i]
|
317 |
+
ax[2, i].imshow(pred_label, alpha=alpha, cmap=cmap, norm=norm)
|
318 |
+
|
319 |
+
ax[3, i].imshow(rbg_img)
|
320 |
+
retouched_label = retouch_label(pred_label.numpy(), true_label)
|
321 |
+
ax[3, i].imshow(retouched_label, alpha=alpha, cmap=cmap, norm=norm)
|
322 |
+
|
323 |
+
ax[4, i].imshow(rbg_img)
|
324 |
+
pred_label = self.cluster_metrics.map_clusters(
|
325 |
+
output["cluster_preds"][i]
|
326 |
+
)
|
327 |
+
ax[4, i].imshow(pred_label, alpha=alpha, cmap=cmap, norm=norm)
|
328 |
+
# ax[3, i].imshow(map_clusters_with_label(true_label, pred_label), alpha=0.5, cmap=cmap, norm=norm)
|
329 |
+
|
330 |
+
ax[5, i].imshow(rbg_img)
|
331 |
+
retouched_label = retouch_label(pred_label.numpy(), true_label)
|
332 |
+
ax[5, i].imshow(retouched_label, alpha=alpha, cmap=cmap, norm=norm)
|
333 |
+
|
334 |
+
ax[0, 0].set_ylabel("Image", fontsize=16)
|
335 |
+
ax[1, 0].set_ylabel("Label", fontsize=16)
|
336 |
+
ax[2, 0].set_ylabel("UNet Probe", fontsize=16)
|
337 |
+
ax[3, 0].set_ylabel("Retouched UNet Probe", fontsize=16)
|
338 |
+
ax[4, 0].set_ylabel("Cluster Probe", fontsize=16)
|
339 |
+
ax[5, 0].set_ylabel("Retouched cluster Probe", fontsize=16)
|
340 |
+
remove_axes(ax)
|
341 |
+
plt.tight_layout()
|
342 |
+
add_plot(self.logger.experiment, "plot_labels", self.global_step)
|
343 |
+
|
344 |
+
if self.cfg.has_labels:
|
345 |
+
fig = plt.figure(figsize=(13, 10))
|
346 |
+
ax = fig.gca()
|
347 |
+
hist = (
|
348 |
+
self.cluster_metrics.histogram.detach().cpu().to(torch.float32)
|
349 |
+
)
|
350 |
+
hist /= torch.clamp_min(hist.sum(dim=0, keepdim=True), 1)
|
351 |
+
sns.heatmap(hist.t(), annot=False, fmt="g", ax=ax, cmap="Blues")
|
352 |
+
ax.set_xlabel("Predicted labels")
|
353 |
+
ax.set_ylabel("True labels")
|
354 |
+
names = get_class_labels(self.cfg.dataset_name)
|
355 |
+
if self.cfg.extra_clusters:
|
356 |
+
names = names + ["Extra"]
|
357 |
+
ax.set_xticks(np.arange(0, len(names)) + 0.5)
|
358 |
+
ax.set_yticks(np.arange(0, len(names)) + 0.5)
|
359 |
+
ax.xaxis.tick_top()
|
360 |
+
ax.xaxis.set_ticklabels(names, fontsize=14)
|
361 |
+
ax.yaxis.set_ticklabels(names, fontsize=14)
|
362 |
+
colors = [self.label_cmap[i] / 255.0 for i in range(len(names))]
|
363 |
+
[
|
364 |
+
t.set_color(colors[i])
|
365 |
+
for i, t in enumerate(ax.xaxis.get_ticklabels())
|
366 |
+
]
|
367 |
+
[
|
368 |
+
t.set_color(colors[i])
|
369 |
+
for i, t in enumerate(ax.yaxis.get_ticklabels())
|
370 |
+
]
|
371 |
+
# ax.yaxis.get_ticklabels()[-1].set_color(self.label_cmap[0] / 255.0)
|
372 |
+
# ax.xaxis.get_ticklabels()[-1].set_color(self.label_cmap[0] / 255.0)
|
373 |
+
plt.xticks(rotation=90)
|
374 |
+
plt.yticks(rotation=0)
|
375 |
+
ax.vlines(
|
376 |
+
np.arange(0, len(names) + 1),
|
377 |
+
color=[0.5, 0.5, 0.5],
|
378 |
+
*ax.get_xlim()
|
379 |
+
)
|
380 |
+
ax.hlines(
|
381 |
+
np.arange(0, len(names) + 1),
|
382 |
+
color=[0.5, 0.5, 0.5],
|
383 |
+
*ax.get_ylim()
|
384 |
+
)
|
385 |
+
plt.tight_layout()
|
386 |
+
add_plot(self.logger.experiment, "conf_matrix", self.global_step)
|
387 |
+
|
388 |
+
all_bars = torch.cat(
|
389 |
+
[
|
390 |
+
self.cluster_metrics.histogram.sum(0).cpu(),
|
391 |
+
self.cluster_metrics.histogram.sum(1).cpu(),
|
392 |
+
],
|
393 |
+
axis=0,
|
394 |
+
)
|
395 |
+
ymin = max(all_bars.min() * 0.8, 1)
|
396 |
+
ymax = all_bars.max() * 1.2
|
397 |
+
|
398 |
+
fig, ax = plt.subplots(1, 2, figsize=(2 * 5, 1 * 4))
|
399 |
+
ax[0].bar(
|
400 |
+
range(self.n_classes + self.cfg.extra_clusters),
|
401 |
+
self.cluster_metrics.histogram.sum(0).cpu(),
|
402 |
+
tick_label=names,
|
403 |
+
color=colors,
|
404 |
+
)
|
405 |
+
ax[0].set_ylim(ymin, ymax)
|
406 |
+
ax[0].set_title("Label Frequency")
|
407 |
+
ax[0].set_yscale("log")
|
408 |
+
ax[0].tick_params(axis="x", labelrotation=90)
|
409 |
+
|
410 |
+
ax[1].bar(
|
411 |
+
range(self.n_classes + self.cfg.extra_clusters),
|
412 |
+
self.cluster_metrics.histogram.sum(1).cpu(),
|
413 |
+
tick_label=names,
|
414 |
+
color=colors,
|
415 |
+
)
|
416 |
+
ax[1].set_ylim(ymin, ymax)
|
417 |
+
ax[1].set_title("Cluster Frequency")
|
418 |
+
ax[1].set_yscale("log")
|
419 |
+
ax[1].tick_params(axis="x", labelrotation=90)
|
420 |
+
|
421 |
+
plt.tight_layout()
|
422 |
+
add_plot(
|
423 |
+
self.logger.experiment, "label frequency", self.global_step
|
424 |
+
)
|
425 |
+
|
426 |
+
if self.global_step > 2:
|
427 |
+
self.log_dict(tb_metrics)
|
428 |
+
|
429 |
+
if self.trainer.is_global_zero and self.cfg.azureml_logging:
|
430 |
+
from azureml.core.run import Run
|
431 |
+
|
432 |
+
run_logger = Run.get_context()
|
433 |
+
for metric, value in tb_metrics.items():
|
434 |
+
run_logger.log(metric, value)
|
435 |
+
|
436 |
+
self.linear_metrics.reset()
|
437 |
+
self.cluster_metrics.reset()
|
438 |
+
|
439 |
+
def configure_optimizers(self):
|
440 |
+
main_params = list(self.net.parameters())
|
441 |
+
|
442 |
+
if self.cfg.rec_weight > 0:
|
443 |
+
main_params.extend(self.decoder.parameters())
|
444 |
+
|
445 |
+
net_optim = torch.optim.Adam(main_params, lr=self.cfg.lr)
|
446 |
+
linear_probe_optim = torch.optim.Adam(
|
447 |
+
list(self.linear_probe.parameters()), lr=5e-3
|
448 |
+
)
|
449 |
+
cluster_probe_optim = torch.optim.Adam(
|
450 |
+
list(self.cluster_probe.parameters()), lr=5e-3
|
451 |
+
)
|
452 |
+
|
453 |
+
return net_optim, linear_probe_optim, cluster_probe_optim
|
biomap/modules.py
ADDED
@@ -0,0 +1,472 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
from utils import *
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import dino.vision_transformer as vits
|
6 |
+
|
7 |
+
import pdb
|
8 |
+
|
9 |
+
class LambdaLayer(nn.Module):
|
10 |
+
def __init__(self, lambd):
|
11 |
+
super(LambdaLayer, self).__init__()
|
12 |
+
self.lambd = lambd
|
13 |
+
|
14 |
+
def forward(self, x):
|
15 |
+
return self.lambd(x)
|
16 |
+
|
17 |
+
|
18 |
+
class DinoFeaturizer(nn.Module):
|
19 |
+
|
20 |
+
def __init__(self, dim, cfg):
|
21 |
+
super().__init__()
|
22 |
+
self.cfg = cfg
|
23 |
+
self.dim = dim
|
24 |
+
patch_size = self.cfg.dino_patch_size
|
25 |
+
self.patch_size = patch_size
|
26 |
+
self.feat_type = self.cfg.dino_feat_type
|
27 |
+
arch = self.cfg.model_type
|
28 |
+
self.model = vits.__dict__[arch](
|
29 |
+
patch_size=patch_size,
|
30 |
+
num_classes=0)
|
31 |
+
for p in self.model.parameters():
|
32 |
+
p.requires_grad = False
|
33 |
+
# pdb.set_trace()
|
34 |
+
self.model=self.model.cpu()
|
35 |
+
self.model.eval()
|
36 |
+
self.dropout = torch.nn.Dropout2d(p=.1)
|
37 |
+
|
38 |
+
if arch == "vit_small" and patch_size == 16:
|
39 |
+
url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
|
40 |
+
elif arch == "vit_small" and patch_size == 8:
|
41 |
+
url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth"
|
42 |
+
elif arch == "vit_base" and patch_size == 16:
|
43 |
+
url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
|
44 |
+
elif arch == "vit_base" and patch_size == 8:
|
45 |
+
url = "dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
|
46 |
+
else:
|
47 |
+
raise ValueError("Unknown arch and patch size")
|
48 |
+
|
49 |
+
if cfg.pretrained_weights is not None:
|
50 |
+
state_dict = torch.load(cfg.pretrained_weights, map_location="cpu")
|
51 |
+
state_dict = state_dict["teacher"]
|
52 |
+
# remove `module.` prefix
|
53 |
+
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
|
54 |
+
# remove `backbone.` prefix induced by multicrop wrapper
|
55 |
+
state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
|
56 |
+
|
57 |
+
# state_dict = {k.replace("projection_head", "mlp"): v for k, v in state_dict.items()}
|
58 |
+
# state_dict = {k.replace("prototypes", "last_layer"): v for k, v in state_dict.items()}
|
59 |
+
|
60 |
+
msg = self.model.load_state_dict(state_dict, strict=False)
|
61 |
+
print('Pretrained weights found at {} and loaded with msg: {}'.format(cfg.pretrained_weights, msg))
|
62 |
+
else:
|
63 |
+
print("Since no pretrained weights have been provided, we load the reference pretrained DINO weights.")
|
64 |
+
state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/" + url)
|
65 |
+
self.model.load_state_dict(state_dict, strict=True)
|
66 |
+
|
67 |
+
if arch == "vit_small":
|
68 |
+
self.n_feats = 384
|
69 |
+
else:
|
70 |
+
self.n_feats = 768
|
71 |
+
self.cluster1 = self.make_clusterer(self.n_feats)
|
72 |
+
self.proj_type = cfg.projection_type
|
73 |
+
if self.proj_type == "nonlinear":
|
74 |
+
self.cluster2 = self.make_nonlinear_clusterer(self.n_feats)
|
75 |
+
|
76 |
+
def make_clusterer(self, in_channels):
|
77 |
+
return torch.nn.Sequential(
|
78 |
+
torch.nn.Conv2d(in_channels, self.dim, (1, 1))) # ,
|
79 |
+
|
80 |
+
def make_nonlinear_clusterer(self, in_channels):
|
81 |
+
return torch.nn.Sequential(
|
82 |
+
torch.nn.Conv2d(in_channels, in_channels, (1, 1)),
|
83 |
+
torch.nn.ReLU(),
|
84 |
+
torch.nn.Conv2d(in_channels, self.dim, (1, 1)))
|
85 |
+
|
86 |
+
def forward(self, img, n=1, return_class_feat=False):
|
87 |
+
self.model.eval()
|
88 |
+
with torch.no_grad():
|
89 |
+
assert (img.shape[2] % self.patch_size == 0)
|
90 |
+
assert (img.shape[3] % self.patch_size == 0)
|
91 |
+
|
92 |
+
# get selected layer activations
|
93 |
+
feat, attn, qkv = self.model.get_intermediate_feat(img, n=n)
|
94 |
+
feat, attn, qkv = feat[0], attn[0], qkv[0]
|
95 |
+
|
96 |
+
feat_h = img.shape[2] // self.patch_size
|
97 |
+
feat_w = img.shape[3] // self.patch_size
|
98 |
+
|
99 |
+
if self.feat_type == "feat":
|
100 |
+
image_feat = feat[:, 1:, :].reshape(feat.shape[0], feat_h, feat_w, -1).permute(0, 3, 1, 2)
|
101 |
+
elif self.feat_type == "KK":
|
102 |
+
image_k = qkv[1, :, :, 1:, :].reshape(feat.shape[0], 6, feat_h, feat_w, -1)
|
103 |
+
B, H, I, J, D = image_k.shape
|
104 |
+
image_feat = image_k.permute(0, 1, 4, 2, 3).reshape(B, H * D, I, J)
|
105 |
+
else:
|
106 |
+
raise ValueError("Unknown feat type:{}".format(self.feat_type))
|
107 |
+
|
108 |
+
if return_class_feat:
|
109 |
+
return feat[:, :1, :].reshape(feat.shape[0], 1, 1, -1).permute(0, 3, 1, 2)
|
110 |
+
|
111 |
+
if self.proj_type is not None:
|
112 |
+
code = self.cluster1(self.dropout(image_feat))
|
113 |
+
if self.proj_type == "nonlinear":
|
114 |
+
code += self.cluster2(self.dropout(image_feat))
|
115 |
+
else:
|
116 |
+
code = image_feat
|
117 |
+
|
118 |
+
if self.cfg.dropout:
|
119 |
+
return self.dropout(image_feat), code
|
120 |
+
else:
|
121 |
+
return image_feat, code
|
122 |
+
|
123 |
+
|
124 |
+
class ResizeAndClassify(nn.Module):
|
125 |
+
|
126 |
+
def __init__(self, dim: int, size: int, n_classes: int):
|
127 |
+
super(ResizeAndClassify, self).__init__()
|
128 |
+
self.size = size
|
129 |
+
self.predictor = torch.nn.Sequential(
|
130 |
+
torch.nn.Conv2d(dim, n_classes, (1, 1)),
|
131 |
+
torch.nn.LogSoftmax(1))
|
132 |
+
|
133 |
+
def forward(self, x):
|
134 |
+
return F.interpolate(self.predictor.forward(x), self.size, mode="bilinear", align_corners=False)
|
135 |
+
|
136 |
+
|
137 |
+
class ClusterLookup(nn.Module):
|
138 |
+
|
139 |
+
def __init__(self, dim: int, n_classes: int):
|
140 |
+
super(ClusterLookup, self).__init__()
|
141 |
+
self.n_classes = n_classes
|
142 |
+
self.dim = dim
|
143 |
+
self.clusters = torch.nn.Parameter(torch.randn(n_classes, dim))
|
144 |
+
|
145 |
+
def reset_parameters(self):
|
146 |
+
with torch.no_grad():
|
147 |
+
self.clusters.copy_(torch.randn(self.n_classes, self.dim))
|
148 |
+
|
149 |
+
def forward(self, x, alpha, log_probs=False):
|
150 |
+
normed_clusters = F.normalize(self.clusters, dim=1)
|
151 |
+
normed_features = F.normalize(x, dim=1)
|
152 |
+
inner_products = torch.einsum("bchw,nc->bnhw", normed_features, normed_clusters)
|
153 |
+
|
154 |
+
if alpha is None:
|
155 |
+
cluster_probs = F.one_hot(torch.argmax(inner_products, dim=1), self.clusters.shape[0]) \
|
156 |
+
.permute(0, 3, 1, 2).to(torch.float32)
|
157 |
+
else:
|
158 |
+
cluster_probs = nn.functional.softmax(inner_products * alpha, dim=1)
|
159 |
+
|
160 |
+
cluster_loss = -(cluster_probs * inner_products).sum(1).mean()
|
161 |
+
if log_probs:
|
162 |
+
return nn.functional.log_softmax(inner_products * alpha, dim=1)
|
163 |
+
else:
|
164 |
+
return cluster_loss, cluster_probs
|
165 |
+
|
166 |
+
|
167 |
+
class FeaturePyramidNet(nn.Module):
|
168 |
+
|
169 |
+
@staticmethod
|
170 |
+
def _helper(x):
|
171 |
+
# TODO remove this hard coded 56
|
172 |
+
return F.interpolate(x, 56, mode="bilinear", align_corners=False).unsqueeze(-1)
|
173 |
+
|
174 |
+
def make_clusterer(self, in_channels):
|
175 |
+
return torch.nn.Sequential(
|
176 |
+
torch.nn.Conv2d(in_channels, self.dim, (1, 1)),
|
177 |
+
LambdaLayer(FeaturePyramidNet._helper))
|
178 |
+
|
179 |
+
def make_nonlinear_clusterer(self, in_channels):
|
180 |
+
return torch.nn.Sequential(
|
181 |
+
torch.nn.Conv2d(in_channels, in_channels, (1, 1)),
|
182 |
+
torch.nn.ReLU(),
|
183 |
+
torch.nn.Conv2d(in_channels, in_channels, (1, 1)),
|
184 |
+
torch.nn.ReLU(),
|
185 |
+
torch.nn.Conv2d(in_channels, self.dim, (1, 1)),
|
186 |
+
LambdaLayer(FeaturePyramidNet._helper))
|
187 |
+
|
188 |
+
def __init__(self, granularity, cut_model, dim, continuous):
|
189 |
+
super(FeaturePyramidNet, self).__init__()
|
190 |
+
self.layer_nums = [5, 6, 7]
|
191 |
+
self.spatial_resolutions = [7, 14, 28, 56]
|
192 |
+
self.feat_channels = [2048, 1024, 512, 3]
|
193 |
+
self.extra_channels = [128, 64, 32, 32]
|
194 |
+
self.granularity = granularity
|
195 |
+
self.encoder = NetWithActivations(cut_model, self.layer_nums)
|
196 |
+
self.dim = dim
|
197 |
+
self.continuous = continuous
|
198 |
+
self.n_feats = self.dim
|
199 |
+
|
200 |
+
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
|
201 |
+
|
202 |
+
assert granularity in {1, 2, 3, 4}
|
203 |
+
self.cluster1 = self.make_clusterer(self.feat_channels[0])
|
204 |
+
self.cluster1_nl = self.make_nonlinear_clusterer(self.feat_channels[0])
|
205 |
+
|
206 |
+
if granularity >= 2:
|
207 |
+
# self.conv1 = DoubleConv(self.feat_channels[0], self.extra_channels[0])
|
208 |
+
# self.conv2 = DoubleConv(self.extra_channels[0] + self.feat_channels[1], self.extra_channels[1])
|
209 |
+
self.conv2 = DoubleConv(self.feat_channels[0] + self.feat_channels[1], self.extra_channels[1])
|
210 |
+
self.cluster2 = self.make_clusterer(self.extra_channels[1])
|
211 |
+
if granularity >= 3:
|
212 |
+
self.conv3 = DoubleConv(self.extra_channels[1] + self.feat_channels[2], self.extra_channels[2])
|
213 |
+
self.cluster3 = self.make_clusterer(self.extra_channels[2])
|
214 |
+
if granularity >= 4:
|
215 |
+
self.conv4 = DoubleConv(self.extra_channels[2] + self.feat_channels[3], self.extra_channels[3])
|
216 |
+
self.cluster4 = self.make_clusterer(self.extra_channels[3])
|
217 |
+
|
218 |
+
def c(self, x, y):
|
219 |
+
return torch.cat([x, y], dim=1)
|
220 |
+
|
221 |
+
def forward(self, x):
|
222 |
+
with torch.no_grad():
|
223 |
+
feats = self.encoder(x)
|
224 |
+
low_res_feats = feats[self.layer_nums[-1]]
|
225 |
+
|
226 |
+
all_clusters = []
|
227 |
+
|
228 |
+
# all_clusters.append(self.cluster1(low_res_feats) + self.cluster1_nl(low_res_feats))
|
229 |
+
all_clusters.append(self.cluster1(low_res_feats))
|
230 |
+
|
231 |
+
if self.granularity >= 2:
|
232 |
+
# f1 = self.conv1(low_res_feats)
|
233 |
+
# f1_up = self.up(f1)
|
234 |
+
f1_up = self.up(low_res_feats)
|
235 |
+
f2 = self.conv2(self.c(f1_up, feats[self.layer_nums[-2]]))
|
236 |
+
all_clusters.append(self.cluster2(f2))
|
237 |
+
if self.granularity >= 3:
|
238 |
+
f2_up = self.up(f2)
|
239 |
+
f3 = self.conv3(self.c(f2_up, feats[self.layer_nums[-3]]))
|
240 |
+
all_clusters.append(self.cluster3(f3))
|
241 |
+
if self.granularity >= 4:
|
242 |
+
f3_up = self.up(f3)
|
243 |
+
final_size = self.spatial_resolutions[-1]
|
244 |
+
f4 = self.conv4(self.c(f3_up, F.interpolate(
|
245 |
+
x, (final_size, final_size), mode="bilinear", align_corners=False)))
|
246 |
+
all_clusters.append(self.cluster4(f4))
|
247 |
+
|
248 |
+
avg_code = torch.cat(all_clusters, 4).mean(4)
|
249 |
+
|
250 |
+
if self.continuous:
|
251 |
+
clusters = avg_code
|
252 |
+
else:
|
253 |
+
clusters = torch.log_softmax(avg_code, 1)
|
254 |
+
|
255 |
+
return low_res_feats, clusters
|
256 |
+
|
257 |
+
|
258 |
+
class DoubleConv(nn.Module):
|
259 |
+
"""(convolution => [BN] => ReLU) * 2"""
|
260 |
+
|
261 |
+
def __init__(self, in_channels, out_channels, mid_channels=None):
|
262 |
+
super().__init__()
|
263 |
+
if not mid_channels:
|
264 |
+
mid_channels = out_channels
|
265 |
+
self.double_conv = nn.Sequential(
|
266 |
+
nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
|
267 |
+
nn.BatchNorm2d(mid_channels),
|
268 |
+
nn.ReLU(),
|
269 |
+
nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
|
270 |
+
nn.BatchNorm2d(out_channels),
|
271 |
+
nn.ReLU()
|
272 |
+
)
|
273 |
+
|
274 |
+
def forward(self, x):
|
275 |
+
return self.double_conv(x)
|
276 |
+
|
277 |
+
|
278 |
+
def norm(t):
|
279 |
+
return F.normalize(t, dim=1, eps=1e-10)
|
280 |
+
|
281 |
+
|
282 |
+
def average_norm(t):
|
283 |
+
return t / t.square().sum(1, keepdim=True).sqrt().mean()
|
284 |
+
|
285 |
+
|
286 |
+
def tensor_correlation(a, b):
|
287 |
+
return torch.einsum("nchw,ncij->nhwij", a, b)
|
288 |
+
|
289 |
+
|
290 |
+
def sample(t: torch.Tensor, coords: torch.Tensor):
|
291 |
+
return F.grid_sample(t, coords.permute(0, 2, 1, 3), padding_mode='border', align_corners=True)
|
292 |
+
|
293 |
+
|
294 |
+
@torch.jit.script
|
295 |
+
def super_perm(size: int, device: torch.device):
|
296 |
+
perm = torch.randperm(size, device=device, dtype=torch.long)
|
297 |
+
perm[perm == torch.arange(size, device=device)] += 1
|
298 |
+
return perm % size
|
299 |
+
|
300 |
+
|
301 |
+
def sample_nonzero_locations(t, target_size):
|
302 |
+
nonzeros = torch.nonzero(t)
|
303 |
+
coords = torch.zeros(target_size, dtype=nonzeros.dtype, device=nonzeros.device)
|
304 |
+
n = target_size[1] * target_size[2]
|
305 |
+
for i in range(t.shape[0]):
|
306 |
+
selected_nonzeros = nonzeros[nonzeros[:, 0] == i]
|
307 |
+
if selected_nonzeros.shape[0] == 0:
|
308 |
+
selected_coords = torch.randint(t.shape[1], size=(n, 2), device=nonzeros.device)
|
309 |
+
else:
|
310 |
+
selected_coords = selected_nonzeros[torch.randint(len(selected_nonzeros), size=(n,)), 1:]
|
311 |
+
coords[i, :, :, :] = selected_coords.reshape(target_size[1], target_size[2], 2)
|
312 |
+
coords = coords.to(torch.float32) / t.shape[1]
|
313 |
+
coords = coords * 2 - 1
|
314 |
+
return torch.flip(coords, dims=[-1])
|
315 |
+
|
316 |
+
|
317 |
+
class ContrastiveCorrelationLoss(nn.Module):
|
318 |
+
|
319 |
+
def __init__(self, cfg, ):
|
320 |
+
super(ContrastiveCorrelationLoss, self).__init__()
|
321 |
+
self.cfg = cfg
|
322 |
+
|
323 |
+
def standard_scale(self, t):
|
324 |
+
t1 = t - t.mean()
|
325 |
+
t2 = t1 / t1.std()
|
326 |
+
return t2
|
327 |
+
|
328 |
+
def helper(self, f1, f2, c1, c2, shift):
|
329 |
+
with torch.no_grad():
|
330 |
+
# Comes straight from backbone which is currently frozen. this saves mem.
|
331 |
+
fd = tensor_correlation(norm(f1), norm(f2))
|
332 |
+
|
333 |
+
if self.cfg.pointwise:
|
334 |
+
old_mean = fd.mean()
|
335 |
+
fd -= fd.mean([3, 4], keepdim=True)
|
336 |
+
fd = fd - fd.mean() + old_mean
|
337 |
+
|
338 |
+
cd = tensor_correlation(norm(c1), norm(c2))
|
339 |
+
|
340 |
+
if self.cfg.zero_clamp:
|
341 |
+
min_val = 0.0
|
342 |
+
else:
|
343 |
+
min_val = -9999.0
|
344 |
+
|
345 |
+
if self.cfg.stabalize:
|
346 |
+
loss = - cd.clamp(min_val, .8) * (fd - shift)
|
347 |
+
else:
|
348 |
+
loss = - cd.clamp(min_val) * (fd - shift)
|
349 |
+
|
350 |
+
return loss, cd
|
351 |
+
|
352 |
+
def forward(self,
|
353 |
+
orig_feats: torch.Tensor, orig_feats_pos: torch.Tensor,
|
354 |
+
orig_salience: torch.Tensor, orig_salience_pos: torch.Tensor,
|
355 |
+
orig_code: torch.Tensor, orig_code_pos: torch.Tensor,
|
356 |
+
):
|
357 |
+
|
358 |
+
coord_shape = [orig_feats.shape[0], self.cfg.feature_samples, self.cfg.feature_samples, 2]
|
359 |
+
|
360 |
+
if self.cfg.use_salience:
|
361 |
+
coords1_nonzero = sample_nonzero_locations(orig_salience, coord_shape)
|
362 |
+
coords2_nonzero = sample_nonzero_locations(orig_salience_pos, coord_shape)
|
363 |
+
coords1_reg = torch.rand(coord_shape, device=orig_feats.device) * 2 - 1
|
364 |
+
coords2_reg = torch.rand(coord_shape, device=orig_feats.device) * 2 - 1
|
365 |
+
mask = (torch.rand(coord_shape[:-1], device=orig_feats.device) > .1).unsqueeze(-1).to(torch.float32)
|
366 |
+
coords1 = coords1_nonzero * mask + coords1_reg * (1 - mask)
|
367 |
+
coords2 = coords2_nonzero * mask + coords2_reg * (1 - mask)
|
368 |
+
else:
|
369 |
+
coords1 = torch.rand(coord_shape, device=orig_feats.device) * 2 - 1
|
370 |
+
coords2 = torch.rand(coord_shape, device=orig_feats.device) * 2 - 1
|
371 |
+
|
372 |
+
feats = sample(orig_feats, coords1)
|
373 |
+
code = sample(orig_code, coords1)
|
374 |
+
|
375 |
+
feats_pos = sample(orig_feats_pos, coords2)
|
376 |
+
code_pos = sample(orig_code_pos, coords2)
|
377 |
+
|
378 |
+
pos_intra_loss, pos_intra_cd = self.helper(
|
379 |
+
feats, feats, code, code, self.cfg.pos_intra_shift)
|
380 |
+
pos_inter_loss, pos_inter_cd = self.helper(
|
381 |
+
feats, feats_pos, code, code_pos, self.cfg.pos_inter_shift)
|
382 |
+
|
383 |
+
neg_losses = []
|
384 |
+
neg_cds = []
|
385 |
+
for i in range(self.cfg.neg_samples):
|
386 |
+
perm_neg = super_perm(orig_feats.shape[0], orig_feats.device)
|
387 |
+
feats_neg = sample(orig_feats[perm_neg], coords2)
|
388 |
+
code_neg = sample(orig_code[perm_neg], coords2)
|
389 |
+
neg_inter_loss, neg_inter_cd = self.helper(
|
390 |
+
feats, feats_neg, code, code_neg, self.cfg.neg_inter_shift)
|
391 |
+
neg_losses.append(neg_inter_loss)
|
392 |
+
neg_cds.append(neg_inter_cd)
|
393 |
+
neg_inter_loss = torch.cat(neg_losses, axis=0)
|
394 |
+
neg_inter_cd = torch.cat(neg_cds, axis=0)
|
395 |
+
|
396 |
+
return (pos_intra_loss.mean(),
|
397 |
+
pos_intra_cd,
|
398 |
+
pos_inter_loss.mean(),
|
399 |
+
pos_inter_cd,
|
400 |
+
neg_inter_loss,
|
401 |
+
neg_inter_cd)
|
402 |
+
|
403 |
+
|
404 |
+
class Decoder(nn.Module):
|
405 |
+
def __init__(self, code_channels, feat_channels):
|
406 |
+
super().__init__()
|
407 |
+
self.linear = torch.nn.Conv2d(code_channels, feat_channels, (1, 1))
|
408 |
+
self.nonlinear = torch.nn.Sequential(
|
409 |
+
torch.nn.Conv2d(code_channels, code_channels, (1, 1)),
|
410 |
+
torch.nn.ReLU(),
|
411 |
+
torch.nn.Conv2d(code_channels, code_channels, (1, 1)),
|
412 |
+
torch.nn.ReLU(),
|
413 |
+
torch.nn.Conv2d(code_channels, feat_channels, (1, 1)))
|
414 |
+
|
415 |
+
def forward(self, x):
|
416 |
+
return self.linear(x) + self.nonlinear(x)
|
417 |
+
|
418 |
+
|
419 |
+
class NetWithActivations(torch.nn.Module):
|
420 |
+
def __init__(self, model, layer_nums):
|
421 |
+
super(NetWithActivations, self).__init__()
|
422 |
+
self.layers = nn.ModuleList(model.children())
|
423 |
+
self.layer_nums = []
|
424 |
+
for l in layer_nums:
|
425 |
+
if l < 0:
|
426 |
+
self.layer_nums.append(len(self.layers) + l)
|
427 |
+
else:
|
428 |
+
self.layer_nums.append(l)
|
429 |
+
self.layer_nums = set(sorted(self.layer_nums))
|
430 |
+
|
431 |
+
def forward(self, x):
|
432 |
+
activations = {}
|
433 |
+
for ln, l in enumerate(self.layers):
|
434 |
+
x = l(x)
|
435 |
+
if ln in self.layer_nums:
|
436 |
+
activations[ln] = x
|
437 |
+
return activations
|
438 |
+
|
439 |
+
|
440 |
+
class ContrastiveCRFLoss(nn.Module):
|
441 |
+
|
442 |
+
def __init__(self, n_samples, alpha, beta, gamma, w1, w2, shift):
|
443 |
+
super(ContrastiveCRFLoss, self).__init__()
|
444 |
+
self.alpha = alpha
|
445 |
+
self.beta = beta
|
446 |
+
self.gamma = gamma
|
447 |
+
self.w1 = w1
|
448 |
+
self.w2 = w2
|
449 |
+
self.n_samples = n_samples
|
450 |
+
self.shift = shift
|
451 |
+
|
452 |
+
def forward(self, guidance, clusters):
|
453 |
+
device = clusters.device
|
454 |
+
assert (guidance.shape[0] == clusters.shape[0])
|
455 |
+
assert (guidance.shape[2:] == clusters.shape[2:])
|
456 |
+
h = guidance.shape[2]
|
457 |
+
w = guidance.shape[3]
|
458 |
+
|
459 |
+
coords = torch.cat([
|
460 |
+
torch.randint(0, h, size=[1, self.n_samples], device=device),
|
461 |
+
torch.randint(0, w, size=[1, self.n_samples], device=device)], 0)
|
462 |
+
|
463 |
+
selected_guidance = guidance[:, :, coords[0, :], coords[1, :]]
|
464 |
+
coord_diff = (coords.unsqueeze(-1) - coords.unsqueeze(1)).square().sum(0).unsqueeze(0)
|
465 |
+
guidance_diff = (selected_guidance.unsqueeze(-1) - selected_guidance.unsqueeze(2)).square().sum(1)
|
466 |
+
|
467 |
+
sim_kernel = self.w1 * torch.exp(- coord_diff / (2 * self.alpha) - guidance_diff / (2 * self.beta)) + \
|
468 |
+
self.w2 * torch.exp(- coord_diff / (2 * self.gamma)) - self.shift
|
469 |
+
|
470 |
+
selected_clusters = clusters[:, :, coords[0, :], coords[1, :]]
|
471 |
+
cluster_sims = torch.einsum("nka,nkb->nab", selected_clusters, selected_clusters)
|
472 |
+
return -(cluster_sims * sim_kernel)
|
biomap/output/img.png
ADDED
![]() |
biomap/output/img_6.png
ADDED
![]() |
biomap/output/label.png
ADDED
![]() |
biomap/output/labeled_img.png
ADDED
![]() |
biomap/plot_functions.py
ADDED
@@ -0,0 +1,778 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
|
3 |
+
import hydra
|
4 |
+
import matplotlib as mpl
|
5 |
+
from utils import prep_for_plot
|
6 |
+
|
7 |
+
import torch.multiprocessing
|
8 |
+
import torchvision.transforms as T
|
9 |
+
# import matplotlib.pyplot as plt
|
10 |
+
from model import LitUnsupervisedSegmenter
|
11 |
+
colors = ('red', 'palegreen', 'green', 'steelblue', 'blue', 'yellow', 'lightgrey')
|
12 |
+
class_names = ('Buildings', 'Cultivation', 'Natural green', 'Wetland', 'Water', 'Infrastructure', 'Background')
|
13 |
+
cmap = mpl.colors.ListedColormap(colors)
|
14 |
+
#from train_segmentation import LitUnsupervisedSegmenter, cmap
|
15 |
+
|
16 |
+
from utils_gee import extract_img, transform_ee_img
|
17 |
+
|
18 |
+
import plotly.graph_objects as go
|
19 |
+
import plotly.express as px
|
20 |
+
import numpy as np
|
21 |
+
from plotly.subplots import make_subplots
|
22 |
+
|
23 |
+
import os
|
24 |
+
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
|
25 |
+
|
26 |
+
|
27 |
+
colors = ('red', 'palegreen', 'green', 'steelblue', 'blue', 'yellow', 'lightgrey')
|
28 |
+
class_names = ('Buildings', 'Cultivation', 'Natural green', 'Wetland', 'Water', 'Infrastructure', 'Background')
|
29 |
+
scores_init = [2,3,4,3,1,4,0]
|
30 |
+
|
31 |
+
# Import model configs
|
32 |
+
hydra.initialize(config_path="configs", job_name="corine")
|
33 |
+
cfg = hydra.compose(config_name="my_train_config.yml")
|
34 |
+
|
35 |
+
nbclasses = cfg.dir_dataset_n_classes
|
36 |
+
|
37 |
+
# Load Model
|
38 |
+
model_path = "checkpoint/model/model.pt"
|
39 |
+
saved_state_dict = torch.load(model_path,map_location=torch.device('cpu'))
|
40 |
+
|
41 |
+
model = LitUnsupervisedSegmenter(nbclasses, cfg)
|
42 |
+
model.load_state_dict(saved_state_dict)
|
43 |
+
|
44 |
+
from PIL import Image
|
45 |
+
|
46 |
+
import hydra
|
47 |
+
|
48 |
+
from utils import prep_for_plot
|
49 |
+
|
50 |
+
import torch.multiprocessing
|
51 |
+
import torchvision.transforms as T
|
52 |
+
# import matplotlib.pyplot as plt
|
53 |
+
|
54 |
+
from model import LitUnsupervisedSegmenter
|
55 |
+
|
56 |
+
from utils_gee import extract_img, transform_ee_img
|
57 |
+
|
58 |
+
import plotly.graph_objects as go
|
59 |
+
import plotly.express as px
|
60 |
+
import numpy as np
|
61 |
+
from plotly.subplots import make_subplots
|
62 |
+
|
63 |
+
import os
|
64 |
+
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
|
65 |
+
|
66 |
+
|
67 |
+
colors = ('red', 'palegreen', 'green', 'steelblue', 'blue', 'yellow', 'lightgrey')
|
68 |
+
cmap = mpl.colors.ListedColormap(colors)
|
69 |
+
class_names = ('Buildings', 'Cultivation', 'Natural green', 'Wetland', 'Water', 'Infrastructure', 'Background')
|
70 |
+
scores_init = [2,3,4,3,1,4,0]
|
71 |
+
|
72 |
+
# Import model configs
|
73 |
+
#hydra.initialize(config_path="configs", job_name="corine")
|
74 |
+
cfg = hydra.compose(config_name="my_train_config.yml")
|
75 |
+
|
76 |
+
nbclasses = cfg.dir_dataset_n_classes
|
77 |
+
|
78 |
+
# Load Model
|
79 |
+
model_path = "checkpoint/model/model.pt"
|
80 |
+
saved_state_dict = torch.load(model_path,map_location=torch.device('cpu'))
|
81 |
+
|
82 |
+
model = LitUnsupervisedSegmenter(nbclasses, cfg)
|
83 |
+
model.load_state_dict(saved_state_dict)
|
84 |
+
|
85 |
+
|
86 |
+
#normalize img
|
87 |
+
preprocess = T.Compose([
|
88 |
+
T.ToPILImage(),
|
89 |
+
T.Resize((320,320)),
|
90 |
+
# T.CenterCrop(224),
|
91 |
+
T.ToTensor(),
|
92 |
+
T.Normalize(
|
93 |
+
mean=[0.485, 0.456, 0.406],
|
94 |
+
std=[0.229, 0.224, 0.225]
|
95 |
+
)
|
96 |
+
])
|
97 |
+
|
98 |
+
# Function that look for img on EE and segment it
|
99 |
+
# -- 3 ways possible to avoid cloudy environment -- monthly / bi-monthly / yearly meaned img
|
100 |
+
|
101 |
+
def segment_loc(location, month, year, how = "month", month_end = '12', year_end = None) :
|
102 |
+
if how == 'month':
|
103 |
+
img = extract_img(location, year +'-'+ month +'-01', year +'-'+ month +'-28')
|
104 |
+
elif how == 'year' :
|
105 |
+
if year_end == None :
|
106 |
+
img = extract_img(location, year +'-'+ month +'-01', year +'-'+ month_end +'-28', width = 0.04 , len = 0.04)
|
107 |
+
else :
|
108 |
+
img = extract_img(location, year +'-'+ month +'-01', year_end +'-'+ month_end +'-28', width = 0.04 , len = 0.04)
|
109 |
+
|
110 |
+
|
111 |
+
img_test= transform_ee_img(img, max = 0.25)
|
112 |
+
|
113 |
+
# Preprocess opened img
|
114 |
+
x = preprocess(img_test)
|
115 |
+
x = torch.unsqueeze(x, dim=0).cpu()
|
116 |
+
# model=model.cpu()
|
117 |
+
|
118 |
+
with torch.no_grad():
|
119 |
+
feats, code = model.net(x)
|
120 |
+
linear_preds = model.linear_probe(x, code)
|
121 |
+
linear_preds = linear_preds.argmax(1)
|
122 |
+
outputs = {
|
123 |
+
'img': x[:model.cfg.n_images].detach().cpu(),
|
124 |
+
'linear_preds': linear_preds[:model.cfg.n_images].detach().cpu()
|
125 |
+
}
|
126 |
+
return outputs
|
127 |
+
|
128 |
+
|
129 |
+
# Function that look for all img on EE and extract all segments with the date as first output arg
|
130 |
+
|
131 |
+
def segment_group(location, start_date, end_date, how = 'month') :
|
132 |
+
outputs = []
|
133 |
+
st_month = int(start_date[5:7])
|
134 |
+
end_month = int(end_date[5:7])
|
135 |
+
|
136 |
+
st_year = int(start_date[0:4])
|
137 |
+
end_year = int(end_date[0:4])
|
138 |
+
|
139 |
+
|
140 |
+
|
141 |
+
for year in range(st_year, end_year+1) :
|
142 |
+
|
143 |
+
if year != end_year :
|
144 |
+
last = 12
|
145 |
+
else :
|
146 |
+
last = end_month
|
147 |
+
|
148 |
+
if year != st_year:
|
149 |
+
start = 1
|
150 |
+
else :
|
151 |
+
start = st_month
|
152 |
+
|
153 |
+
if how == 'month' :
|
154 |
+
for month in range(start, last + 1):
|
155 |
+
month_str = f"{month:0>2d}"
|
156 |
+
year_str = str(year)
|
157 |
+
|
158 |
+
outputs.append((year_str + '-' + month_str, segment_loc(location, month_str, year_str)))
|
159 |
+
|
160 |
+
elif how == 'year' :
|
161 |
+
outputs.append((str(year) + '-' + f"{start:0>2d}", segment_loc(location, f"{start:0>2d}", str(year), how = 'year', month_end=f"{last:0>2d}")))
|
162 |
+
|
163 |
+
elif how == '2months' :
|
164 |
+
for month in range(start, last + 1):
|
165 |
+
month_str = f"{month:0>2d}"
|
166 |
+
year_str = str(year)
|
167 |
+
month_end = (month) % 12 +1
|
168 |
+
if month_end < month :
|
169 |
+
year_end = year +1
|
170 |
+
else :
|
171 |
+
year_end = year
|
172 |
+
month_end= f"{month_end:0>2d}"
|
173 |
+
year_end = str(year_end)
|
174 |
+
|
175 |
+
outputs.append((year_str + '-' + month_str, segment_loc(location, month_str, year_str,how = 'year', month_end=month_end, year_end=year_end)))
|
176 |
+
|
177 |
+
|
178 |
+
return outputs
|
179 |
+
|
180 |
+
|
181 |
+
# Function that transforms an output to PIL images
|
182 |
+
|
183 |
+
def transform_to_pil(outputs,alpha=0.3):
|
184 |
+
# Transform img with torch
|
185 |
+
img = torch.moveaxis(prep_for_plot(outputs['img'][0]),-1,0)
|
186 |
+
img=T.ToPILImage()(img)
|
187 |
+
|
188 |
+
# Transform label by saving it then open it
|
189 |
+
# label = outputs['linear_preds'][0]
|
190 |
+
# plt.imsave('label.png',label,cmap=cmap)
|
191 |
+
# label = Image.open('label.png')
|
192 |
+
|
193 |
+
cmaplist = np.array([np.array(cmap(i)) for i in range(cmap.N)])
|
194 |
+
labels = np.array(outputs['linear_preds'][0])-1
|
195 |
+
label = T.ToPILImage()((cmaplist[labels]*255).astype(np.uint8))
|
196 |
+
|
197 |
+
|
198 |
+
# Overlay labels with img wit alpha
|
199 |
+
background = img.convert("RGBA")
|
200 |
+
overlay = label.convert("RGBA")
|
201 |
+
|
202 |
+
labeled_img = Image.blend(background, overlay, alpha)
|
203 |
+
|
204 |
+
return img, label, labeled_img
|
205 |
+
|
206 |
+
|
207 |
+
|
208 |
+
# Function that extract labeled_img(PIL) and nb_values(number of pixels for each class) and the score for each observation
|
209 |
+
|
210 |
+
def values_from_output(output):
|
211 |
+
imgs = transform_to_pil(output,alpha = 0.3)
|
212 |
+
|
213 |
+
img = imgs[0]
|
214 |
+
img = np.array(img.convert('RGB'))
|
215 |
+
|
216 |
+
labeled_img = imgs[2]
|
217 |
+
labeled_img = np.array(labeled_img.convert('RGB'))
|
218 |
+
|
219 |
+
nb_values = []
|
220 |
+
for i in range(7):
|
221 |
+
nb_values.append(np.count_nonzero(output['linear_preds'][0] == i+1))
|
222 |
+
|
223 |
+
score = sum(x * y for x, y in zip(scores_init, nb_values)) / sum(nb_values) / max(scores_init)
|
224 |
+
|
225 |
+
return img, labeled_img, nb_values, score
|
226 |
+
|
227 |
+
|
228 |
+
# Function that extract from outputs (from segment_group function) all dates/ all images
|
229 |
+
def values_from_outputs(outputs) :
|
230 |
+
months = []
|
231 |
+
imgs = []
|
232 |
+
imgs_label = []
|
233 |
+
nb_values = []
|
234 |
+
scores = []
|
235 |
+
|
236 |
+
for output in outputs:
|
237 |
+
img, labeled_img, nb_value, score = values_from_output(output[1])
|
238 |
+
months.append(output[0])
|
239 |
+
imgs.append(img)
|
240 |
+
imgs_label.append(labeled_img)
|
241 |
+
nb_values.append(nb_value)
|
242 |
+
scores.append(score)
|
243 |
+
|
244 |
+
return months, imgs, imgs_label, nb_values, scores
|
245 |
+
|
246 |
+
|
247 |
+
|
248 |
+
def plot_imgs_labels(months, imgs, imgs_label, nb_values, scores) :
|
249 |
+
|
250 |
+
fig2 = px.imshow(np.array(imgs), animation_frame=0, binary_string=True)
|
251 |
+
fig3 = px.imshow(np.array(imgs_label), animation_frame=0, binary_string=True)
|
252 |
+
|
253 |
+
# Scores
|
254 |
+
scatters = []
|
255 |
+
temp = []
|
256 |
+
for score in scores :
|
257 |
+
temp_score = []
|
258 |
+
temp_date = []
|
259 |
+
score = scores[i]
|
260 |
+
temp.append(score)
|
261 |
+
text_temp = ["" for i in temp]
|
262 |
+
text_temp[-1] = str(round(score,2))
|
263 |
+
scatters.append(go.Scatter(x=text_temp, y=temp, mode="lines+markers+text", marker_color="black", text = text_temp, textposition="top center"))
|
264 |
+
|
265 |
+
|
266 |
+
# Scores
|
267 |
+
fig = make_subplots(
|
268 |
+
rows=1, cols=4,
|
269 |
+
# specs=[[{"rowspan": 2}, {"rowspan": 2}, {"type": "pie"}, None]]
|
270 |
+
# row_heights=[0.8, 0.2],
|
271 |
+
column_widths = [0.6, 0.6,0.3, 0.3],
|
272 |
+
subplot_titles=("Localisation visualization", "labeled visualisation", "Segments repartition", "Biodiversity scores")
|
273 |
+
)
|
274 |
+
|
275 |
+
fig.add_trace(fig2["frames"][0]["data"][0], row=1, col=1)
|
276 |
+
fig.add_trace(fig3["frames"][0]["data"][0], row=1, col=2)
|
277 |
+
|
278 |
+
fig.add_trace(go.Pie(labels = class_names,
|
279 |
+
values = nb_values[0],
|
280 |
+
marker_colors = colors,
|
281 |
+
name="Segment repartition",
|
282 |
+
textposition='inside',
|
283 |
+
texttemplate = "%{percent:.0%}",
|
284 |
+
textfont_size=14
|
285 |
+
),
|
286 |
+
row=1, col=3)
|
287 |
+
|
288 |
+
|
289 |
+
fig.add_trace(scatters[0], row=1, col=4)
|
290 |
+
# fig.add_annotation(text='score:' + str(scores[0]),
|
291 |
+
# showarrow=False,
|
292 |
+
# row=2, col=2)
|
293 |
+
|
294 |
+
|
295 |
+
number_frames = len(imgs)
|
296 |
+
frames = [dict(
|
297 |
+
name = k,
|
298 |
+
data = [ fig2["frames"][k]["data"][0],
|
299 |
+
fig3["frames"][k]["data"][0],
|
300 |
+
go.Pie(labels = class_names,
|
301 |
+
values = nb_values[k],
|
302 |
+
marker_colors = colors,
|
303 |
+
name="Segment repartition",
|
304 |
+
textposition='inside',
|
305 |
+
texttemplate = "%{percent:.0%}",
|
306 |
+
textfont_size=14
|
307 |
+
),
|
308 |
+
scatters[k]
|
309 |
+
],
|
310 |
+
traces=[0, 1,2,3] # the elements of the list [0,1,2] give info on the traces in fig.data
|
311 |
+
# that are updated by the above three go.Scatter instances
|
312 |
+
) for k in range(number_frames)]
|
313 |
+
|
314 |
+
updatemenus = [dict(type='buttons',
|
315 |
+
buttons=[dict(label='Play',
|
316 |
+
method='animate',
|
317 |
+
args=[[f'{k}' for k in range(number_frames)],
|
318 |
+
dict(frame=dict(duration=500, redraw=False),
|
319 |
+
transition=dict(duration=0),
|
320 |
+
easing='linear',
|
321 |
+
fromcurrent=True,
|
322 |
+
mode='immediate'
|
323 |
+
)])],
|
324 |
+
direction= 'left',
|
325 |
+
pad=dict(r= 10, t=85),
|
326 |
+
showactive =True, x= 0.1, y= 0.13, xanchor= 'right', yanchor= 'top')
|
327 |
+
]
|
328 |
+
|
329 |
+
sliders = [{'yanchor': 'top',
|
330 |
+
'xanchor': 'left',
|
331 |
+
'currentvalue': {'font': {'size': 16}, 'prefix': 'Frame: ', 'visible': False, 'xanchor': 'right'},
|
332 |
+
'transition': {'duration': 500.0, 'easing': 'linear'},
|
333 |
+
'pad': {'b': 10, 't': 50},
|
334 |
+
'len': 0.9, 'x': 0.1, 'y': 0,
|
335 |
+
'steps': [{'args': [[k], {'frame': {'duration': 500.0, 'easing': 'linear', 'redraw': False},
|
336 |
+
'transition': {'duration': 0, 'easing': 'linear'}}],
|
337 |
+
'label': months[k], 'method': 'animate'} for k in range(number_frames)
|
338 |
+
]}]
|
339 |
+
|
340 |
+
|
341 |
+
fig.update(frames=frames)
|
342 |
+
|
343 |
+
for i,fr in enumerate(fig["frames"]):
|
344 |
+
fr.update(
|
345 |
+
layout={
|
346 |
+
"xaxis": {
|
347 |
+
"range": [0,imgs[0].shape[1]+i/100000]
|
348 |
+
},
|
349 |
+
"yaxis": {
|
350 |
+
"range": [imgs[0].shape[0]+i/100000,0]
|
351 |
+
},
|
352 |
+
})
|
353 |
+
|
354 |
+
fr.update(layout_title_text= months[i])
|
355 |
+
|
356 |
+
|
357 |
+
fig.update(layout_title_text= 'tot')
|
358 |
+
fig.update(
|
359 |
+
layout={
|
360 |
+
"xaxis": {
|
361 |
+
"range": [0,imgs[0].shape[1]+i/100000],
|
362 |
+
'showgrid': False, # thin lines in the background
|
363 |
+
'zeroline': False, # thick line at x=0
|
364 |
+
'visible': False, # numbers below
|
365 |
+
},
|
366 |
+
|
367 |
+
"yaxis": {
|
368 |
+
"range": [imgs[0].shape[0]+i/100000,0],
|
369 |
+
'showgrid': False, # thin lines in the background
|
370 |
+
'zeroline': False, # thick line at y=0
|
371 |
+
'visible': False,},
|
372 |
+
|
373 |
+
"xaxis3": {
|
374 |
+
"range": [0,len(scores)+1],
|
375 |
+
'autorange': False, # thin lines in the background
|
376 |
+
'showgrid': False, # thin lines in the background
|
377 |
+
'zeroline': False, # thick line at y=0
|
378 |
+
'visible': False
|
379 |
+
},
|
380 |
+
|
381 |
+
"yaxis3": {
|
382 |
+
"range": [0,1.5],
|
383 |
+
'autorange': False,
|
384 |
+
'showgrid': False, # thin lines in the background
|
385 |
+
'zeroline': False, # thick line at y=0
|
386 |
+
'visible': False # thin lines in the background
|
387 |
+
}
|
388 |
+
},
|
389 |
+
legend=dict(
|
390 |
+
yanchor="bottom",
|
391 |
+
y=0.99,
|
392 |
+
xanchor="center",
|
393 |
+
x=0.01
|
394 |
+
)
|
395 |
+
)
|
396 |
+
|
397 |
+
|
398 |
+
fig.update_layout(updatemenus=updatemenus,
|
399 |
+
sliders=sliders)
|
400 |
+
|
401 |
+
fig.update_layout(margin=dict(b=0, r=0))
|
402 |
+
|
403 |
+
# fig.show() #in jupyter notebook
|
404 |
+
|
405 |
+
return fig
|
406 |
+
|
407 |
+
|
408 |
+
|
409 |
+
# Last function (global one)
|
410 |
+
# how = 'month' or '2months' or 'year'
|
411 |
+
|
412 |
+
def segment_region(location, start_date, end_date, how = 'month'):
|
413 |
+
|
414 |
+
#extract the outputs for each image
|
415 |
+
outputs = segment_group(location, start_date, end_date, how = how)
|
416 |
+
|
417 |
+
#extract the intersting values from image
|
418 |
+
months, imgs, imgs_label, nb_values, scores = values_from_outputs(outputs)
|
419 |
+
|
420 |
+
#Create the figure
|
421 |
+
fig = plot_imgs_labels(months, imgs, imgs_label, nb_values, scores)
|
422 |
+
|
423 |
+
return fig
|
424 |
+
#normalize img
|
425 |
+
preprocess = T.Compose([
|
426 |
+
T.ToPILImage(),
|
427 |
+
T.Resize((320,320)),
|
428 |
+
# T.CenterCrop(224),
|
429 |
+
T.ToTensor(),
|
430 |
+
T.Normalize(
|
431 |
+
mean=[0.485, 0.456, 0.406],
|
432 |
+
std=[0.229, 0.224, 0.225]
|
433 |
+
)
|
434 |
+
])
|
435 |
+
|
436 |
+
# Function that look for img on EE and segment it
|
437 |
+
# -- 3 ways possible to avoid cloudy environment -- monthly / bi-monthly / yearly meaned img
|
438 |
+
|
439 |
+
def segment_loc(location, month, year, how = "month", month_end = '12', year_end = None) :
|
440 |
+
if how == 'month':
|
441 |
+
img = extract_img(location, year +'-'+ month +'-01', year +'-'+ month +'-28')
|
442 |
+
elif how == 'year' :
|
443 |
+
if year_end == None :
|
444 |
+
img = extract_img(location, year +'-'+ month +'-01', year +'-'+ month_end +'-28', width = 0.04 , len = 0.04)
|
445 |
+
else :
|
446 |
+
img = extract_img(location, year +'-'+ month +'-01', year_end +'-'+ month_end +'-28', width = 0.04 , len = 0.04)
|
447 |
+
|
448 |
+
|
449 |
+
img_test= transform_ee_img(img, max = 0.25)
|
450 |
+
|
451 |
+
# Preprocess opened img
|
452 |
+
x = preprocess(img_test)
|
453 |
+
x = torch.unsqueeze(x, dim=0).cpu()
|
454 |
+
# model=model.cpu()
|
455 |
+
|
456 |
+
with torch.no_grad():
|
457 |
+
feats, code = model.net(x)
|
458 |
+
linear_preds = model.linear_probe(x, code)
|
459 |
+
linear_preds = linear_preds.argmax(1)
|
460 |
+
outputs = {
|
461 |
+
'img': x[:model.cfg.n_images].detach().cpu(),
|
462 |
+
'linear_preds': linear_preds[:model.cfg.n_images].detach().cpu()
|
463 |
+
}
|
464 |
+
return outputs
|
465 |
+
|
466 |
+
|
467 |
+
# Function that look for all img on EE and extract all segments with the date as first output arg
|
468 |
+
|
469 |
+
def segment_group(location, start_date, end_date, how = 'month') :
|
470 |
+
outputs = []
|
471 |
+
st_month = int(start_date[5:7])
|
472 |
+
end_month = int(end_date[5:7])
|
473 |
+
|
474 |
+
st_year = int(start_date[0:4])
|
475 |
+
end_year = int(end_date[0:4])
|
476 |
+
|
477 |
+
|
478 |
+
|
479 |
+
for year in range(st_year, end_year+1) :
|
480 |
+
|
481 |
+
if year != end_year :
|
482 |
+
last = 12
|
483 |
+
else :
|
484 |
+
last = end_month
|
485 |
+
|
486 |
+
if year != st_year:
|
487 |
+
start = 1
|
488 |
+
else :
|
489 |
+
start = st_month
|
490 |
+
|
491 |
+
if how == 'month' :
|
492 |
+
for month in range(start, last + 1):
|
493 |
+
month_str = f"{month:0>2d}"
|
494 |
+
year_str = str(year)
|
495 |
+
|
496 |
+
outputs.append((year_str + '-' + month_str, segment_loc(location, month_str, year_str)))
|
497 |
+
|
498 |
+
elif how == 'year' :
|
499 |
+
outputs.append((str(year) + '-' + f"{start:0>2d}", segment_loc(location, f"{start:0>2d}", str(year), how = 'year', month_end=f"{last:0>2d}")))
|
500 |
+
|
501 |
+
elif how == '2months' :
|
502 |
+
for month in range(start, last + 1):
|
503 |
+
month_str = f"{month:0>2d}"
|
504 |
+
year_str = str(year)
|
505 |
+
month_end = (month) % 12 +1
|
506 |
+
if month_end < month :
|
507 |
+
year_end = year +1
|
508 |
+
else :
|
509 |
+
year_end = year
|
510 |
+
month_end= f"{month_end:0>2d}"
|
511 |
+
year_end = str(year_end)
|
512 |
+
|
513 |
+
outputs.append((year_str + '-' + month_str, segment_loc(location, month_str, year_str,how = 'year', month_end=month_end, year_end=year_end)))
|
514 |
+
|
515 |
+
|
516 |
+
return outputs
|
517 |
+
|
518 |
+
|
519 |
+
# Function that transforms an output to PIL images
|
520 |
+
|
521 |
+
def transform_to_pil(outputs,alpha=0.3):
|
522 |
+
# Transform img with torch
|
523 |
+
img = torch.moveaxis(prep_for_plot(outputs['img'][0]),-1,0)
|
524 |
+
img=T.ToPILImage()(img)
|
525 |
+
|
526 |
+
# Transform label by saving it then open it
|
527 |
+
# label = outputs['linear_preds'][0]
|
528 |
+
# plt.imsave('label.png',label,cmap=cmap)
|
529 |
+
# label = Image.open('label.png')
|
530 |
+
|
531 |
+
cmaplist = np.array([np.array(cmap(i)) for i in range(cmap.N)])
|
532 |
+
labels = np.array(outputs['linear_preds'][0])-1
|
533 |
+
label = T.ToPILImage()((cmaplist[labels]*255).astype(np.uint8))
|
534 |
+
|
535 |
+
|
536 |
+
# Overlay labels with img wit alpha
|
537 |
+
background = img.convert("RGBA")
|
538 |
+
overlay = label.convert("RGBA")
|
539 |
+
|
540 |
+
labeled_img = Image.blend(background, overlay, alpha)
|
541 |
+
|
542 |
+
return img, label, labeled_img
|
543 |
+
|
544 |
+
def values_from_output(output):
|
545 |
+
imgs = transform_to_pil(output,alpha = 0.3)
|
546 |
+
|
547 |
+
img = imgs[0]
|
548 |
+
img = np.array(img.convert('RGB'))
|
549 |
+
|
550 |
+
labeled_img = imgs[2]
|
551 |
+
labeled_img = np.array(labeled_img.convert('RGB'))
|
552 |
+
|
553 |
+
nb_values = []
|
554 |
+
for i in range(7):
|
555 |
+
nb_values.append(np.count_nonzero(output['linear_preds'][0] == i+1))
|
556 |
+
|
557 |
+
score = sum(x * y for x, y in zip(scores_init, nb_values)) / sum(nb_values) / max(scores_init)
|
558 |
+
|
559 |
+
return img, labeled_img, nb_values, score
|
560 |
+
|
561 |
+
|
562 |
+
# Function that extract labeled_img(PIL) and nb_values(number of pixels for each class) and the score for each observation
|
563 |
+
|
564 |
+
|
565 |
+
|
566 |
+
# Function that extract from outputs (from segment_group function) all dates/ all images
|
567 |
+
def values_from_outputs(outputs) :
|
568 |
+
months = []
|
569 |
+
imgs = []
|
570 |
+
imgs_label = []
|
571 |
+
nb_values = []
|
572 |
+
scores = []
|
573 |
+
|
574 |
+
for output in outputs:
|
575 |
+
img, labeled_img, nb_value, score = values_from_output(output[1])
|
576 |
+
months.append(output[0])
|
577 |
+
imgs.append(img)
|
578 |
+
imgs_label.append(labeled_img)
|
579 |
+
nb_values.append(nb_value)
|
580 |
+
scores.append(score)
|
581 |
+
|
582 |
+
return months, imgs, imgs_label, nb_values, scores
|
583 |
+
|
584 |
+
|
585 |
+
|
586 |
+
def plot_imgs_labels(months, imgs, imgs_label, nb_values, scores) :
|
587 |
+
|
588 |
+
fig2 = px.imshow(np.array(imgs), animation_frame=0, binary_string=True)
|
589 |
+
fig3 = px.imshow(np.array(imgs_label), animation_frame=0, binary_string=True)
|
590 |
+
|
591 |
+
# Scores
|
592 |
+
scatters = []
|
593 |
+
temp = []
|
594 |
+
for score in scores :
|
595 |
+
temp_score = []
|
596 |
+
temp_date = []
|
597 |
+
#score = scores[i]
|
598 |
+
temp.append(score)
|
599 |
+
n = len(temp)
|
600 |
+
text_temp = ["" for i in temp]
|
601 |
+
text_temp[-1] = str(round(score,2))
|
602 |
+
scatters.append(go.Scatter(x=[0,1], y=temp, mode="lines+markers+text", marker_color="black", text = text_temp, textposition="top center"))
|
603 |
+
print(text_temp)
|
604 |
+
|
605 |
+
# Scores
|
606 |
+
fig = make_subplots(
|
607 |
+
rows=1, cols=4,
|
608 |
+
specs=[[{"type": "image"},{"type": "image"}, {"type": "pie"}, {"type": "scatter"}]],
|
609 |
+
subplot_titles=("Localisation visualization", "Labeled visualisation", "Segments repartition", "Biodiversity scores")
|
610 |
+
)
|
611 |
+
|
612 |
+
fig.add_trace(fig2["frames"][0]["data"][0], row=1, col=1)
|
613 |
+
fig.add_trace(fig3["frames"][0]["data"][0], row=1, col=2)
|
614 |
+
|
615 |
+
fig.add_trace(go.Pie(labels = class_names,
|
616 |
+
values = nb_values[0],
|
617 |
+
marker_colors = colors,
|
618 |
+
name="Segment repartition",
|
619 |
+
textposition='inside',
|
620 |
+
texttemplate = "%{percent:.0%}",
|
621 |
+
textfont_size=14
|
622 |
+
),
|
623 |
+
row=1, col=3)
|
624 |
+
|
625 |
+
|
626 |
+
fig.add_trace(scatters[0], row=1, col=4)
|
627 |
+
fig.update_traces(showlegend=False, selector=dict(type='scatter'))
|
628 |
+
#fig.update_traces(, selector=dict(type='scatter'))
|
629 |
+
# fig.add_annotation(text='score:' + str(scores[0]),
|
630 |
+
# showarrow=False,
|
631 |
+
# row=2, col=2)
|
632 |
+
|
633 |
+
|
634 |
+
number_frames = len(imgs)
|
635 |
+
frames = [dict(
|
636 |
+
name = k,
|
637 |
+
data = [ fig2["frames"][k]["data"][0],
|
638 |
+
fig3["frames"][k]["data"][0],
|
639 |
+
go.Pie(labels = class_names,
|
640 |
+
values = nb_values[k],
|
641 |
+
marker_colors = colors,
|
642 |
+
name="Segment repartition",
|
643 |
+
textposition='inside',
|
644 |
+
texttemplate = "%{percent:.0%}",
|
645 |
+
textfont_size=14
|
646 |
+
),
|
647 |
+
scatters[k]
|
648 |
+
],
|
649 |
+
traces=[0, 1,2,3] # the elements of the list [0,1,2] give info on the traces in fig.data
|
650 |
+
# that are updated by the above three go.Scatter instances
|
651 |
+
) for k in range(number_frames)]
|
652 |
+
|
653 |
+
updatemenus = [dict(type='buttons',
|
654 |
+
buttons=[dict(label='Play',
|
655 |
+
method='animate',
|
656 |
+
args=[[f'{k}' for k in range(number_frames)],
|
657 |
+
dict(frame=dict(duration=500, redraw=False),
|
658 |
+
transition=dict(duration=0),
|
659 |
+
easing='linear',
|
660 |
+
fromcurrent=True,
|
661 |
+
mode='immediate'
|
662 |
+
)])],
|
663 |
+
direction= 'left',
|
664 |
+
pad=dict(r= 10, t=85),
|
665 |
+
showactive =True, x= 0.1, y= 0.13, xanchor= 'right', yanchor= 'top')
|
666 |
+
]
|
667 |
+
|
668 |
+
sliders = [{'yanchor': 'top',
|
669 |
+
'xanchor': 'left',
|
670 |
+
'currentvalue': {'font': {'size': 16}, 'prefix': 'Frame: ', 'visible': False, 'xanchor': 'right'},
|
671 |
+
'transition': {'duration': 500.0, 'easing': 'linear'},
|
672 |
+
'pad': {'b': 10, 't': 50},
|
673 |
+
'len': 0.9, 'x': 0.1, 'y': 0,
|
674 |
+
'steps': [{'args': [[k], {'frame': {'duration': 500.0, 'easing': 'linear', 'redraw': False},
|
675 |
+
'transition': {'duration': 0, 'easing': 'linear'}}],
|
676 |
+
'label': months[k], 'method': 'animate'} for k in range(number_frames)
|
677 |
+
]}]
|
678 |
+
|
679 |
+
|
680 |
+
fig.update(frames=frames)
|
681 |
+
|
682 |
+
for i,fr in enumerate(fig["frames"]):
|
683 |
+
fr.update(
|
684 |
+
layout={
|
685 |
+
"xaxis": {
|
686 |
+
"range": [0,imgs[0].shape[1]+i/100000]
|
687 |
+
},
|
688 |
+
"yaxis": {
|
689 |
+
"range": [imgs[0].shape[0]+i/100000,0]
|
690 |
+
},
|
691 |
+
})
|
692 |
+
|
693 |
+
fr.update(layout_title_text= months[i])
|
694 |
+
|
695 |
+
|
696 |
+
fig.update(layout_title_text= months[0])
|
697 |
+
fig.update(
|
698 |
+
layout={
|
699 |
+
"xaxis": {
|
700 |
+
"range": [0,imgs[0].shape[1]+i/100000],
|
701 |
+
'showgrid': False, # thin lines in the background
|
702 |
+
'zeroline': False, # thick line at x=0
|
703 |
+
'visible': False, # numbers below
|
704 |
+
},
|
705 |
+
|
706 |
+
"yaxis": {
|
707 |
+
"range": [imgs[0].shape[0]+i/100000,0],
|
708 |
+
'showgrid': False, # thin lines in the background
|
709 |
+
'zeroline': False, # thick line at y=0
|
710 |
+
'visible': False,},
|
711 |
+
|
712 |
+
"xaxis2": {
|
713 |
+
"range": [0,imgs[0].shape[1]+i/100000],
|
714 |
+
'showgrid': False, # thin lines in the background
|
715 |
+
'zeroline': False, # thick line at x=0
|
716 |
+
'visible': False, # numbers below
|
717 |
+
},
|
718 |
+
|
719 |
+
"yaxis2": {
|
720 |
+
"range": [imgs[0].shape[0]+i/100000,0],
|
721 |
+
'showgrid': False, # thin lines in the background
|
722 |
+
'zeroline': False, # thick line at y=0
|
723 |
+
'visible': False,},
|
724 |
+
|
725 |
+
|
726 |
+
"xaxis3": {
|
727 |
+
"range": [0,len(scores)+1],
|
728 |
+
'autorange': False, # thin lines in the background
|
729 |
+
'showgrid': False, # thin lines in the background
|
730 |
+
'zeroline': False, # thick line at y=0
|
731 |
+
'visible': False
|
732 |
+
},
|
733 |
+
|
734 |
+
"yaxis3": {
|
735 |
+
"range": [0,1.5],
|
736 |
+
'autorange': False,
|
737 |
+
'showgrid': False, # thin lines in the background
|
738 |
+
'zeroline': False, # thick line at y=0
|
739 |
+
'visible': False # thin lines in the background
|
740 |
+
}
|
741 |
+
}
|
742 |
+
)
|
743 |
+
|
744 |
+
|
745 |
+
fig.update_layout(updatemenus=updatemenus,
|
746 |
+
sliders=sliders,
|
747 |
+
legend=dict(
|
748 |
+
yanchor= 'top',
|
749 |
+
xanchor= 'left',
|
750 |
+
orientation="h")
|
751 |
+
)
|
752 |
+
|
753 |
+
|
754 |
+
fig.update_layout(margin=dict(b=0, r=0))
|
755 |
+
|
756 |
+
# fig.show() #in jupyter notebook
|
757 |
+
|
758 |
+
return fig
|
759 |
+
|
760 |
+
|
761 |
+
|
762 |
+
# Last function (global one)
|
763 |
+
# how = 'month' or '2months' or 'year'
|
764 |
+
|
765 |
+
def segment_region(latitude, longitude, start_date, end_date, how = 'month'):
|
766 |
+
location = [float(latitude),float(longitude)]
|
767 |
+
how = how[0]
|
768 |
+
#extract the outputs for each image
|
769 |
+
outputs = segment_group(location, start_date, end_date, how = how)
|
770 |
+
|
771 |
+
#extract the intersting values from image
|
772 |
+
months, imgs, imgs_label, nb_values, scores = values_from_outputs(outputs)
|
773 |
+
|
774 |
+
|
775 |
+
#Create the figure
|
776 |
+
fig = plot_imgs_labels(months, imgs, imgs_label, nb_values, scores)
|
777 |
+
|
778 |
+
return fig
|
biomap/train.py
ADDED
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from utils import *
|
2 |
+
from modules import *
|
3 |
+
from data import *
|
4 |
+
from torch.utils.data import DataLoader
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from datetime import datetime
|
7 |
+
import hydra
|
8 |
+
from omegaconf import DictConfig, OmegaConf
|
9 |
+
import pytorch_lightning as pl
|
10 |
+
from pytorch_lightning import Trainer
|
11 |
+
from pytorch_lightning.loggers import TensorBoardLogger
|
12 |
+
from pytorch_lightning.utilities.seed import seed_everything
|
13 |
+
import torch.multiprocessing
|
14 |
+
import seaborn as sns
|
15 |
+
from pytorch_lightning.callbacks import ModelCheckpoint
|
16 |
+
import sys
|
17 |
+
import pdb
|
18 |
+
import matplotlib as mpl
|
19 |
+
from skimage import measure
|
20 |
+
from scipy.stats import mode as statsmode
|
21 |
+
from collections import OrderedDict
|
22 |
+
import unet
|
23 |
+
import pdb
|
24 |
+
|
25 |
+
torch.multiprocessing.set_sharing_strategy("file_system")
|
26 |
+
colors = ("red", "palegreen", "green", "steelblue", "blue", "yellow", "lightgrey")
|
27 |
+
class_names = (
|
28 |
+
"Buildings",
|
29 |
+
"Cultivation",
|
30 |
+
"Natural green",
|
31 |
+
"Wetland",
|
32 |
+
"Water",
|
33 |
+
"Infrastructure",
|
34 |
+
"Background",
|
35 |
+
)
|
36 |
+
bounds = list(np.arange(len(class_names) + 1) + 1)
|
37 |
+
cmap = mpl.colors.ListedColormap(colors)
|
38 |
+
norm = mpl.colors.BoundaryNorm(bounds, cmap.N)
|
39 |
+
|
40 |
+
|
41 |
+
def retouch_label(pred_label, true_label):
|
42 |
+
retouched_label = pred_label + 0
|
43 |
+
blobs = measure.label(retouched_label)
|
44 |
+
for idx in np.unique(blobs):
|
45 |
+
# most frequent label class in this blob
|
46 |
+
retouched_label[blobs == idx] = statsmode(true_label[blobs == idx])[0][0]
|
47 |
+
return retouched_label
|
48 |
+
|
49 |
+
|
50 |
+
def get_class_labels(dataset_name):
|
51 |
+
if dataset_name.startswith("cityscapes"):
|
52 |
+
return [
|
53 |
+
"road",
|
54 |
+
"sidewalk",
|
55 |
+
"parking",
|
56 |
+
"rail track",
|
57 |
+
"building",
|
58 |
+
"wall",
|
59 |
+
"fence",
|
60 |
+
"guard rail",
|
61 |
+
"bridge",
|
62 |
+
"tunnel",
|
63 |
+
"pole",
|
64 |
+
"polegroup",
|
65 |
+
"traffic light",
|
66 |
+
"traffic sign",
|
67 |
+
"vegetation",
|
68 |
+
"terrain",
|
69 |
+
"sky",
|
70 |
+
"person",
|
71 |
+
"rider",
|
72 |
+
"car",
|
73 |
+
"truck",
|
74 |
+
"bus",
|
75 |
+
"caravan",
|
76 |
+
"trailer",
|
77 |
+
"train",
|
78 |
+
"motorcycle",
|
79 |
+
"bicycle",
|
80 |
+
]
|
81 |
+
elif dataset_name == "cocostuff27":
|
82 |
+
return [
|
83 |
+
"electronic",
|
84 |
+
"appliance",
|
85 |
+
"food",
|
86 |
+
"furniture",
|
87 |
+
"indoor",
|
88 |
+
"kitchen",
|
89 |
+
"accessory",
|
90 |
+
"animal",
|
91 |
+
"outdoor",
|
92 |
+
"person",
|
93 |
+
"sports",
|
94 |
+
"vehicle",
|
95 |
+
"ceiling",
|
96 |
+
"floor",
|
97 |
+
"food",
|
98 |
+
"furniture",
|
99 |
+
"rawmaterial",
|
100 |
+
"textile",
|
101 |
+
"wall",
|
102 |
+
"window",
|
103 |
+
"building",
|
104 |
+
"ground",
|
105 |
+
"plant",
|
106 |
+
"sky",
|
107 |
+
"solid",
|
108 |
+
"structural",
|
109 |
+
"water",
|
110 |
+
]
|
111 |
+
elif dataset_name == "voc":
|
112 |
+
return [
|
113 |
+
"background",
|
114 |
+
"aeroplane",
|
115 |
+
"bicycle",
|
116 |
+
"bird",
|
117 |
+
"boat",
|
118 |
+
"bottle",
|
119 |
+
"bus",
|
120 |
+
"car",
|
121 |
+
"cat",
|
122 |
+
"chair",
|
123 |
+
"cow",
|
124 |
+
"diningtable",
|
125 |
+
"dog",
|
126 |
+
"horse",
|
127 |
+
"motorbike",
|
128 |
+
"person",
|
129 |
+
"pottedplant",
|
130 |
+
"sheep",
|
131 |
+
"sofa",
|
132 |
+
"train",
|
133 |
+
"tvmonitor",
|
134 |
+
]
|
135 |
+
elif dataset_name == "potsdam":
|
136 |
+
return ["roads and cars", "buildings and clutter", "trees and vegetation"]
|
137 |
+
else:
|
138 |
+
raise ValueError("Unknown Dataset {}".format(dataset_name))
|
139 |
+
|
140 |
+
|
141 |
+
@hydra.main(config_path="configs", config_name="train_config.yml")
|
142 |
+
def my_app(cfg: DictConfig) -> None:
|
143 |
+
OmegaConf.set_struct(cfg, False)
|
144 |
+
print(OmegaConf.to_yaml(cfg))
|
145 |
+
pytorch_data_dir = cfg.pytorch_data_dir
|
146 |
+
data_dir = join(cfg.output_root, "data")
|
147 |
+
log_dir = join(cfg.output_root, "logs")
|
148 |
+
checkpoint_dir = join(cfg.output_root, "checkpoints")
|
149 |
+
|
150 |
+
prefix = "{}/{}_{}".format(cfg.log_dir, cfg.dataset_name, cfg.experiment_name)
|
151 |
+
name = "{}_date_{}".format(prefix, datetime.now().strftime("%b%d_%H-%M-%S"))
|
152 |
+
cfg.full_name = prefix
|
153 |
+
|
154 |
+
os.makedirs(data_dir, exist_ok=True)
|
155 |
+
os.makedirs(log_dir, exist_ok=True)
|
156 |
+
os.makedirs(checkpoint_dir, exist_ok=True)
|
157 |
+
|
158 |
+
seed_everything(seed=0)
|
159 |
+
|
160 |
+
print(data_dir)
|
161 |
+
print(cfg.output_root)
|
162 |
+
|
163 |
+
geometric_transforms = T.Compose(
|
164 |
+
[T.RandomHorizontalFlip(), T.RandomResizedCrop(size=cfg.res, scale=(0.8, 1.0))]
|
165 |
+
)
|
166 |
+
photometric_transforms = T.Compose(
|
167 |
+
[
|
168 |
+
T.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
|
169 |
+
T.RandomGrayscale(0.2),
|
170 |
+
T.RandomApply([T.GaussianBlur((5, 5))]),
|
171 |
+
]
|
172 |
+
)
|
173 |
+
|
174 |
+
sys.stdout.flush()
|
175 |
+
|
176 |
+
train_dataset = ContrastiveSegDataset(
|
177 |
+
pytorch_data_dir=pytorch_data_dir,
|
178 |
+
dataset_name=cfg.dataset_name,
|
179 |
+
crop_type=cfg.crop_type,
|
180 |
+
image_set="train",
|
181 |
+
transform=get_transform(cfg.res, False, cfg.loader_crop_type),
|
182 |
+
target_transform=get_transform(cfg.res, True, cfg.loader_crop_type),
|
183 |
+
cfg=cfg,
|
184 |
+
aug_geometric_transform=geometric_transforms,
|
185 |
+
aug_photometric_transform=photometric_transforms,
|
186 |
+
num_neighbors=cfg.num_neighbors,
|
187 |
+
mask=True,
|
188 |
+
pos_images=True,
|
189 |
+
pos_labels=True,
|
190 |
+
)
|
191 |
+
|
192 |
+
if cfg.dataset_name == "voc":
|
193 |
+
val_loader_crop = None
|
194 |
+
else:
|
195 |
+
val_loader_crop = "center"
|
196 |
+
|
197 |
+
val_dataset = ContrastiveSegDataset(
|
198 |
+
pytorch_data_dir=pytorch_data_dir,
|
199 |
+
dataset_name=cfg.dataset_name,
|
200 |
+
crop_type=None,
|
201 |
+
image_set="val",
|
202 |
+
transform=get_transform(320, False, val_loader_crop),
|
203 |
+
target_transform=get_transform(320, True, val_loader_crop),
|
204 |
+
mask=True,
|
205 |
+
cfg=cfg,
|
206 |
+
)
|
207 |
+
|
208 |
+
# val_dataset = MaterializedDataset(val_dataset)
|
209 |
+
train_loader = DataLoader(
|
210 |
+
train_dataset,
|
211 |
+
cfg.batch_size,
|
212 |
+
shuffle=True,
|
213 |
+
num_workers=cfg.num_workers,
|
214 |
+
pin_memory=True,
|
215 |
+
)
|
216 |
+
|
217 |
+
if cfg.submitting_to_aml:
|
218 |
+
val_batch_size = 16
|
219 |
+
else:
|
220 |
+
val_batch_size = cfg.batch_size
|
221 |
+
|
222 |
+
val_loader = DataLoader(
|
223 |
+
val_dataset,
|
224 |
+
val_batch_size,
|
225 |
+
shuffle=False,
|
226 |
+
num_workers=cfg.num_workers,
|
227 |
+
pin_memory=True,
|
228 |
+
)
|
229 |
+
|
230 |
+
model = LitUnsupervisedSegmenter(train_dataset.n_classes, cfg)
|
231 |
+
|
232 |
+
tb_logger = TensorBoardLogger(join(log_dir, name), default_hp_metric=False)
|
233 |
+
|
234 |
+
if cfg.submitting_to_aml:
|
235 |
+
gpu_args = dict(gpus=1, val_check_interval=250)
|
236 |
+
|
237 |
+
if gpu_args["val_check_interval"] > len(train_loader):
|
238 |
+
gpu_args.pop("val_check_interval")
|
239 |
+
|
240 |
+
else:
|
241 |
+
gpu_args = dict(gpus=-1, accelerator="ddp", val_check_interval=cfg.val_freq)
|
242 |
+
# gpu_args = dict(gpus=1, accelerator='ddp', val_check_interval=cfg.val_freq)
|
243 |
+
|
244 |
+
if gpu_args["val_check_interval"] > len(train_loader) // 4:
|
245 |
+
gpu_args.pop("val_check_interval")
|
246 |
+
|
247 |
+
trainer = Trainer(
|
248 |
+
log_every_n_steps=cfg.scalar_log_freq,
|
249 |
+
logger=tb_logger,
|
250 |
+
max_steps=cfg.max_steps,
|
251 |
+
callbacks=[
|
252 |
+
ModelCheckpoint(
|
253 |
+
dirpath=join(checkpoint_dir, name),
|
254 |
+
every_n_train_steps=400,
|
255 |
+
save_top_k=2,
|
256 |
+
monitor="test/cluster/mIoU",
|
257 |
+
mode="max",
|
258 |
+
)
|
259 |
+
],
|
260 |
+
**gpu_args
|
261 |
+
)
|
262 |
+
trainer.fit(model, train_loader, val_loader)
|
263 |
+
|
264 |
+
|
265 |
+
if __name__ == "__main__":
|
266 |
+
prep_args()
|
267 |
+
my_app()
|
biomap/unet.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import pandas as pd
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
from torch.utils.data import DataLoader
|
7 |
+
import torch.nn as nn
|
8 |
+
from collections import defaultdict
|
9 |
+
import torchvision
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from torch.utils.data.sampler import Sampler
|
12 |
+
|
13 |
+
class Block(nn.Module):
|
14 |
+
def __init__(self, in_ch, out_ch, padding='same'):
|
15 |
+
super().__init__()
|
16 |
+
self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=padding)
|
17 |
+
self.relu = nn.ReLU()
|
18 |
+
self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=padding)
|
19 |
+
|
20 |
+
def forward(self, x):
|
21 |
+
return self.conv2(self.relu(self.conv1(x)))
|
22 |
+
|
23 |
+
|
24 |
+
class Encoder(nn.Module):
|
25 |
+
def __init__(self, chs=(3,32,64,128,256)):
|
26 |
+
super().__init__()
|
27 |
+
self.enc_blocks = nn.ModuleList([Block(chs[i], chs[i+1]) for i in range(len(chs)-1)])
|
28 |
+
self.pool = nn.MaxPool2d(2)
|
29 |
+
|
30 |
+
def forward(self, x):
|
31 |
+
ftrs = []
|
32 |
+
for block in self.enc_blocks:
|
33 |
+
x = block(x)
|
34 |
+
ftrs.append(x)
|
35 |
+
x = self.pool(x)
|
36 |
+
return ftrs
|
37 |
+
|
38 |
+
|
39 |
+
class Decoder(nn.Module):
|
40 |
+
def __init__(self, chs=(256,128, 64, 32), aux_ch=70):
|
41 |
+
super().__init__()
|
42 |
+
upchs = tuple([chs[i]+aux_ch if i == 0 else chs[i] for i in range(len(chs))])
|
43 |
+
self.chs = chs
|
44 |
+
self.upchs = upchs
|
45 |
+
self.upconvs = nn.ModuleList([nn.ConvTranspose2d(upchs[i], upchs[i+1], 2, 2) for i in range(len(upchs)-1)])
|
46 |
+
self.dec_blocks = nn.ModuleList([Block(chs[i], chs[i+1]) for i in range(len(chs)-1)])
|
47 |
+
|
48 |
+
def forward(self, x, encoder_features):
|
49 |
+
for i in range(len(self.chs)-1):
|
50 |
+
# pdb.set_trace()
|
51 |
+
x = self.upconvs[i](x)
|
52 |
+
enc_ftrs = self.crop(encoder_features[i], x)
|
53 |
+
x = torch.cat([x, enc_ftrs], dim=1)
|
54 |
+
x = self.dec_blocks[i](x)
|
55 |
+
return x
|
56 |
+
|
57 |
+
def crop(self, enc_ftrs, x):
|
58 |
+
_, _, H, W = x.shape
|
59 |
+
enc_ftrs = torchvision.transforms.CenterCrop([H, W])(enc_ftrs)
|
60 |
+
return enc_ftrs
|
61 |
+
|
62 |
+
|
63 |
+
class AuxUNet(nn.Module):
|
64 |
+
# UNet with auxiliary feature at the bottom
|
65 |
+
def __init__(self, enc_chs=(3,32,64,128,256), dec_chs=(256,128, 64, 32), aux_ch=70, num_class=7, retain_dim=False, out_sz=(224,224)):
|
66 |
+
super().__init__()
|
67 |
+
self.encoder = Encoder(enc_chs)
|
68 |
+
self.decoder = Decoder(dec_chs, aux_ch)
|
69 |
+
self.head = nn.Conv2d(dec_chs[-1], num_class, 1)
|
70 |
+
self.retain_dim = retain_dim
|
71 |
+
|
72 |
+
def forward(self, x, aux):
|
73 |
+
# aux: auxiliary feature at the bottom
|
74 |
+
enc_ftrs = self.encoder(x)
|
75 |
+
enc_ftrs[-1] = torch.cat((enc_ftrs[-1], aux), 1)
|
76 |
+
out = self.decoder(enc_ftrs[::-1][0], enc_ftrs[::-1][1:])
|
77 |
+
out = self.head(out)
|
78 |
+
if self.retain_dim:
|
79 |
+
out = F.interpolate(out, out_sz)
|
80 |
+
return out
|
biomap/utils.py
ADDED
@@ -0,0 +1,390 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import collections
|
2 |
+
import os
|
3 |
+
from os.path import join
|
4 |
+
import io
|
5 |
+
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
import numpy as np
|
8 |
+
import torch.multiprocessing
|
9 |
+
import torch.nn as nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
import wget
|
12 |
+
from PIL import Image
|
13 |
+
from scipy.optimize import linear_sum_assignment
|
14 |
+
from torch._six import string_classes
|
15 |
+
from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format
|
16 |
+
from torchmetrics import Metric
|
17 |
+
from torchvision import models
|
18 |
+
from torchvision import transforms as T
|
19 |
+
from torch.utils.tensorboard.summary import hparams
|
20 |
+
import matplotlib as mpl
|
21 |
+
torch.multiprocessing.set_sharing_strategy("file_system")
|
22 |
+
colors = ("red", "palegreen", "green", "steelblue", "blue", "yellow", "lightgrey")
|
23 |
+
class_names = (
|
24 |
+
"Buildings",
|
25 |
+
"Cultivation",
|
26 |
+
"Natural green",
|
27 |
+
"Wetland",
|
28 |
+
"Water",
|
29 |
+
"Infrastructure",
|
30 |
+
"Background",
|
31 |
+
)
|
32 |
+
bounds = list(np.arange(len(class_names) + 1) + 1)
|
33 |
+
cmap = mpl.colors.ListedColormap(colors)
|
34 |
+
norm = mpl.colors.BoundaryNorm(bounds, cmap.N)
|
35 |
+
|
36 |
+
def compute_biodiv_score(image):
|
37 |
+
"""Compute the biodiversity score of an image
|
38 |
+
|
39 |
+
Args:
|
40 |
+
image (_type_): _description_
|
41 |
+
|
42 |
+
Returns:
|
43 |
+
biodiversity_score: the biodiversity score associated to the landscape of the image
|
44 |
+
"""
|
45 |
+
pix = np.array(image.getdata())
|
46 |
+
return np.mean(pix)
|
47 |
+
|
48 |
+
import cv2
|
49 |
+
def create_video(array_images, output_path="output.mp4"):
|
50 |
+
height, width, layers = array_images[0].shape
|
51 |
+
size = (width,height)
|
52 |
+
|
53 |
+
fourcc = cv2.VideoWriter_fourcc(*'VP90')
|
54 |
+
out = cv2.VideoWriter('output.mp4', fourcc, 2, size)
|
55 |
+
|
56 |
+
for i in range(len(array_images)):
|
57 |
+
out.write(array_images[i])
|
58 |
+
out.release()
|
59 |
+
return out
|
60 |
+
|
61 |
+
|
62 |
+
|
63 |
+
def transform_to_pil(outputs, alpha=0.3):
|
64 |
+
"""Turn an ouput into a PIL
|
65 |
+
|
66 |
+
Args:
|
67 |
+
outputs (_type_): _description_
|
68 |
+
alpha (float, optional): _description_. Defaults to 0.3.
|
69 |
+
|
70 |
+
Returns:
|
71 |
+
_type_: _description_
|
72 |
+
"""
|
73 |
+
|
74 |
+
# Transform img with torch
|
75 |
+
img = torch.moveaxis(prep_for_plot(outputs["img"][0]), -1, 0)
|
76 |
+
img = T.ToPILImage()(img)
|
77 |
+
# Transform label by saving it then open it
|
78 |
+
label = outputs["linear_preds"][0].numpy()
|
79 |
+
# image_label = Image.fromarray(label, mode="P")
|
80 |
+
plt.imsave("output/label.png", label, cmap=cmap)
|
81 |
+
image_label = Image.open("output/label.png")
|
82 |
+
# Overlay labels with img wit alpha
|
83 |
+
background = img.convert("RGBA")
|
84 |
+
overlay = image_label.convert("RGBA")
|
85 |
+
labeled_img = Image.blend(background, overlay, alpha)
|
86 |
+
labeled_img = labeled_img.convert("RGB")
|
87 |
+
return img, image_label, labeled_img
|
88 |
+
|
89 |
+
|
90 |
+
def prep_for_plot(img, rescale=True, resize=None):
|
91 |
+
if resize is not None:
|
92 |
+
img = F.interpolate(img.unsqueeze(0), resize, mode="bilinear")
|
93 |
+
else:
|
94 |
+
img = img.unsqueeze(0)
|
95 |
+
|
96 |
+
plot_img = unnorm(img).squeeze(0).cpu().permute(1, 2, 0)
|
97 |
+
if rescale:
|
98 |
+
plot_img = (plot_img - plot_img.min()) / (plot_img.max() - plot_img.min())
|
99 |
+
return plot_img
|
100 |
+
|
101 |
+
|
102 |
+
def add_plot(writer, name, step):
|
103 |
+
buf = io.BytesIO()
|
104 |
+
plt.savefig(buf, format='jpeg', dpi=100)
|
105 |
+
buf.seek(0)
|
106 |
+
image = Image.open(buf)
|
107 |
+
image = T.ToTensor()(image)
|
108 |
+
writer.add_image(name, image, step)
|
109 |
+
plt.clf()
|
110 |
+
plt.close()
|
111 |
+
|
112 |
+
|
113 |
+
@torch.jit.script
|
114 |
+
def shuffle(x):
|
115 |
+
return x[torch.randperm(x.shape[0])]
|
116 |
+
|
117 |
+
|
118 |
+
def add_hparams_fixed(writer, hparam_dict, metric_dict, global_step):
|
119 |
+
exp, ssi, sei = hparams(hparam_dict, metric_dict)
|
120 |
+
writer.file_writer.add_summary(exp)
|
121 |
+
writer.file_writer.add_summary(ssi)
|
122 |
+
writer.file_writer.add_summary(sei)
|
123 |
+
for k, v in metric_dict.items():
|
124 |
+
writer.add_scalar(k, v, global_step)
|
125 |
+
|
126 |
+
|
127 |
+
@torch.jit.script
|
128 |
+
def resize(classes: torch.Tensor, size: int):
|
129 |
+
return F.interpolate(classes, (size, size), mode="bilinear", align_corners=False)
|
130 |
+
|
131 |
+
|
132 |
+
def one_hot_feats(labels, n_classes):
|
133 |
+
return F.one_hot(labels, n_classes).permute(0, 3, 1, 2).to(torch.float32)
|
134 |
+
|
135 |
+
|
136 |
+
def load_model(model_type, data_dir):
|
137 |
+
if model_type == "robust_resnet50":
|
138 |
+
model = models.resnet50(pretrained=False)
|
139 |
+
model_file = join(data_dir, 'imagenet_l2_3_0.pt')
|
140 |
+
if not os.path.exists(model_file):
|
141 |
+
wget.download("http://6.869.csail.mit.edu/fa19/psets19/pset6/imagenet_l2_3_0.pt",
|
142 |
+
model_file)
|
143 |
+
model_weights = torch.load(model_file)
|
144 |
+
model_weights_modified = {name.split('model.')[1]: value for name, value in model_weights['model'].items() if
|
145 |
+
'model' in name}
|
146 |
+
model.load_state_dict(model_weights_modified)
|
147 |
+
model = nn.Sequential(*list(model.children())[:-1])
|
148 |
+
elif model_type == "densecl":
|
149 |
+
model = models.resnet50(pretrained=False)
|
150 |
+
model_file = join(data_dir, 'densecl_r50_coco_1600ep.pth')
|
151 |
+
if not os.path.exists(model_file):
|
152 |
+
wget.download("https://cloudstor.aarnet.edu.au/plus/s/3GapXiWuVAzdKwJ/download",
|
153 |
+
model_file)
|
154 |
+
model_weights = torch.load(model_file)
|
155 |
+
# model_weights_modified = {name.split('model.')[1]: value for name, value in model_weights['model'].items() if
|
156 |
+
# 'model' in name}
|
157 |
+
model.load_state_dict(model_weights['state_dict'], strict=False)
|
158 |
+
model = nn.Sequential(*list(model.children())[:-1])
|
159 |
+
elif model_type == "resnet50":
|
160 |
+
model = models.resnet50(pretrained=True)
|
161 |
+
model = nn.Sequential(*list(model.children())[:-1])
|
162 |
+
elif model_type == "mocov2":
|
163 |
+
model = models.resnet50(pretrained=False)
|
164 |
+
model_file = join(data_dir, 'moco_v2_800ep_pretrain.pth.tar')
|
165 |
+
if not os.path.exists(model_file):
|
166 |
+
wget.download("https://dl.fbaipublicfiles.com/moco/moco_checkpoints/"
|
167 |
+
"moco_v2_800ep/moco_v2_800ep_pretrain.pth.tar", model_file)
|
168 |
+
checkpoint = torch.load(model_file)
|
169 |
+
# rename moco pre-trained keys
|
170 |
+
state_dict = checkpoint['state_dict']
|
171 |
+
for k in list(state_dict.keys()):
|
172 |
+
# retain only encoder_q up to before the embedding layer
|
173 |
+
if k.startswith('module.encoder_q') and not k.startswith('module.encoder_q.fc'):
|
174 |
+
# remove prefix
|
175 |
+
state_dict[k[len("module.encoder_q."):]] = state_dict[k]
|
176 |
+
# delete renamed or unused k
|
177 |
+
del state_dict[k]
|
178 |
+
msg = model.load_state_dict(state_dict, strict=False)
|
179 |
+
assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
|
180 |
+
model = nn.Sequential(*list(model.children())[:-1])
|
181 |
+
elif model_type == "densenet121":
|
182 |
+
model = models.densenet121(pretrained=True)
|
183 |
+
model = nn.Sequential(*list(model.children())[:-1] + [nn.AdaptiveAvgPool2d((1, 1))])
|
184 |
+
elif model_type == "vgg11":
|
185 |
+
model = models.vgg11(pretrained=True)
|
186 |
+
model = nn.Sequential(*list(model.children())[:-1] + [nn.AdaptiveAvgPool2d((1, 1))])
|
187 |
+
else:
|
188 |
+
raise ValueError("No model: {} found".format(model_type))
|
189 |
+
|
190 |
+
model.eval()
|
191 |
+
model.cuda()
|
192 |
+
return model
|
193 |
+
|
194 |
+
|
195 |
+
class UnNormalize(object):
|
196 |
+
def __init__(self, mean, std):
|
197 |
+
self.mean = mean
|
198 |
+
self.std = std
|
199 |
+
|
200 |
+
def __call__(self, image):
|
201 |
+
image2 = torch.clone(image)
|
202 |
+
for t, m, s in zip(image2, self.mean, self.std):
|
203 |
+
t.mul_(s).add_(m)
|
204 |
+
return image2
|
205 |
+
|
206 |
+
|
207 |
+
normalize = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
208 |
+
unnorm = UnNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
209 |
+
|
210 |
+
|
211 |
+
class ToTargetTensor(object):
|
212 |
+
def __call__(self, target):
|
213 |
+
return torch.as_tensor(np.array(target), dtype=torch.int64).unsqueeze(0)
|
214 |
+
|
215 |
+
|
216 |
+
def prep_args():
|
217 |
+
import sys
|
218 |
+
|
219 |
+
old_args = sys.argv
|
220 |
+
new_args = [old_args.pop(0)]
|
221 |
+
while len(old_args) > 0:
|
222 |
+
arg = old_args.pop(0)
|
223 |
+
if len(arg.split("=")) == 2:
|
224 |
+
new_args.append(arg)
|
225 |
+
elif arg.startswith("--"):
|
226 |
+
new_args.append(arg[2:] + "=" + old_args.pop(0))
|
227 |
+
else:
|
228 |
+
raise ValueError("Unexpected arg style {}".format(arg))
|
229 |
+
sys.argv = new_args
|
230 |
+
|
231 |
+
|
232 |
+
def get_transform(res, is_label, crop_type):
|
233 |
+
if crop_type == "center":
|
234 |
+
cropper = T.CenterCrop(res)
|
235 |
+
elif crop_type == "random":
|
236 |
+
cropper = T.RandomCrop(res)
|
237 |
+
elif crop_type is None:
|
238 |
+
cropper = T.Lambda(lambda x: x)
|
239 |
+
res = (res, res)
|
240 |
+
else:
|
241 |
+
raise ValueError("Unknown Cropper {}".format(crop_type))
|
242 |
+
if is_label:
|
243 |
+
return T.Compose([T.Resize(res, Image.NEAREST),
|
244 |
+
cropper,
|
245 |
+
ToTargetTensor()])
|
246 |
+
else:
|
247 |
+
return T.Compose([T.Resize(res, Image.NEAREST),
|
248 |
+
cropper,
|
249 |
+
T.ToTensor(),
|
250 |
+
normalize])
|
251 |
+
|
252 |
+
|
253 |
+
def _remove_axes(ax):
|
254 |
+
ax.xaxis.set_major_formatter(plt.NullFormatter())
|
255 |
+
ax.yaxis.set_major_formatter(plt.NullFormatter())
|
256 |
+
ax.set_xticks([])
|
257 |
+
ax.set_yticks([])
|
258 |
+
|
259 |
+
|
260 |
+
def remove_axes(axes):
|
261 |
+
if len(axes.shape) == 2:
|
262 |
+
for ax1 in axes:
|
263 |
+
for ax in ax1:
|
264 |
+
_remove_axes(ax)
|
265 |
+
else:
|
266 |
+
for ax in axes:
|
267 |
+
_remove_axes(ax)
|
268 |
+
|
269 |
+
|
270 |
+
class UnsupervisedMetrics(Metric):
|
271 |
+
def __init__(self, prefix: str, n_classes: int, extra_clusters: int, compute_hungarian: bool,
|
272 |
+
dist_sync_on_step=True):
|
273 |
+
# call `self.add_state`for every internal state that is needed for the metrics computations
|
274 |
+
# dist_reduce_fx indicates the function that should be used to reduce
|
275 |
+
# state from multiple processes
|
276 |
+
super().__init__(dist_sync_on_step=dist_sync_on_step)
|
277 |
+
|
278 |
+
self.n_classes = n_classes
|
279 |
+
self.extra_clusters = extra_clusters
|
280 |
+
self.compute_hungarian = compute_hungarian
|
281 |
+
self.prefix = prefix
|
282 |
+
self.add_state("stats",
|
283 |
+
default=torch.zeros(n_classes + self.extra_clusters, n_classes, dtype=torch.int64),
|
284 |
+
dist_reduce_fx="sum")
|
285 |
+
|
286 |
+
def update(self, preds: torch.Tensor, target: torch.Tensor):
|
287 |
+
with torch.no_grad():
|
288 |
+
actual = target.reshape(-1)
|
289 |
+
preds = preds.reshape(-1)
|
290 |
+
mask = (actual >= 0) & (actual < self.n_classes) & (preds >= 0) & (preds < self.n_classes)
|
291 |
+
actual = actual[mask]
|
292 |
+
preds = preds[mask]
|
293 |
+
self.stats += torch.bincount(
|
294 |
+
(self.n_classes + self.extra_clusters) * actual + preds,
|
295 |
+
minlength=self.n_classes * (self.n_classes + self.extra_clusters)) \
|
296 |
+
.reshape(self.n_classes, self.n_classes + self.extra_clusters).t().to(self.stats.device)
|
297 |
+
|
298 |
+
def map_clusters(self, clusters):
|
299 |
+
if self.extra_clusters == 0:
|
300 |
+
return torch.tensor(self.assignments[1])[clusters]
|
301 |
+
else:
|
302 |
+
missing = sorted(list(set(range(self.n_classes + self.extra_clusters)) - set(self.assignments[0])))
|
303 |
+
cluster_to_class = self.assignments[1]
|
304 |
+
for missing_entry in missing:
|
305 |
+
if missing_entry == cluster_to_class.shape[0]:
|
306 |
+
cluster_to_class = np.append(cluster_to_class, -1)
|
307 |
+
else:
|
308 |
+
cluster_to_class = np.insert(cluster_to_class, missing_entry + 1, -1)
|
309 |
+
cluster_to_class = torch.tensor(cluster_to_class)
|
310 |
+
return cluster_to_class[clusters]
|
311 |
+
|
312 |
+
def compute(self):
|
313 |
+
if self.compute_hungarian:
|
314 |
+
self.assignments = linear_sum_assignment(self.stats.detach().cpu(), maximize=True)
|
315 |
+
# print(self.assignments)
|
316 |
+
if self.extra_clusters == 0:
|
317 |
+
self.histogram = self.stats[np.argsort(self.assignments[1]), :]
|
318 |
+
if self.extra_clusters > 0:
|
319 |
+
self.assignments_t = linear_sum_assignment(self.stats.detach().cpu().t(), maximize=True)
|
320 |
+
histogram = self.stats[self.assignments_t[1], :]
|
321 |
+
missing = list(set(range(self.n_classes + self.extra_clusters)) - set(self.assignments[0]))
|
322 |
+
new_row = self.stats[missing, :].sum(0, keepdim=True)
|
323 |
+
histogram = torch.cat([histogram, new_row], axis=0)
|
324 |
+
new_col = torch.zeros(self.n_classes + 1, 1, device=histogram.device)
|
325 |
+
self.histogram = torch.cat([histogram, new_col], axis=1)
|
326 |
+
else:
|
327 |
+
self.assignments = (torch.arange(self.n_classes).unsqueeze(1),
|
328 |
+
torch.arange(self.n_classes).unsqueeze(1))
|
329 |
+
self.histogram = self.stats
|
330 |
+
|
331 |
+
tp = torch.diag(self.histogram)
|
332 |
+
fp = torch.sum(self.histogram, dim=0) - tp
|
333 |
+
fn = torch.sum(self.histogram, dim=1) - tp
|
334 |
+
|
335 |
+
iou = tp / (tp + fp + fn)
|
336 |
+
prc = tp / (tp + fn)
|
337 |
+
opc = torch.sum(tp) / torch.sum(self.histogram)
|
338 |
+
|
339 |
+
metric_dict = {self.prefix + "mIoU": iou[~torch.isnan(iou)].mean().item(),
|
340 |
+
self.prefix + "Accuracy": opc.item()}
|
341 |
+
return {k: 100 * v for k, v in metric_dict.items()}
|
342 |
+
|
343 |
+
|
344 |
+
def flexible_collate(batch):
|
345 |
+
r"""Puts each data field into a tensor with outer dimension batch size"""
|
346 |
+
|
347 |
+
elem = batch[0]
|
348 |
+
elem_type = type(elem)
|
349 |
+
if isinstance(elem, torch.Tensor):
|
350 |
+
out = None
|
351 |
+
if torch.utils.data.get_worker_info() is not None:
|
352 |
+
# If we're in a background process, concatenate directly into a
|
353 |
+
# shared memory tensor to avoid an extra copy
|
354 |
+
numel = sum([x.numel() for x in batch])
|
355 |
+
storage = elem.storage()._new_shared(numel)
|
356 |
+
out = elem.new(storage)
|
357 |
+
try:
|
358 |
+
return torch.stack(batch, 0, out=out)
|
359 |
+
except RuntimeError:
|
360 |
+
return batch
|
361 |
+
elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
|
362 |
+
and elem_type.__name__ != 'string_':
|
363 |
+
if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
|
364 |
+
# array of string classes and object
|
365 |
+
if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
|
366 |
+
raise TypeError(default_collate_err_msg_format.format(elem.dtype))
|
367 |
+
|
368 |
+
return flexible_collate([torch.as_tensor(b) for b in batch])
|
369 |
+
elif elem.shape == (): # scalars
|
370 |
+
return torch.as_tensor(batch)
|
371 |
+
elif isinstance(elem, float):
|
372 |
+
return torch.tensor(batch, dtype=torch.float64)
|
373 |
+
elif isinstance(elem, int):
|
374 |
+
return torch.tensor(batch)
|
375 |
+
elif isinstance(elem, string_classes):
|
376 |
+
return batch
|
377 |
+
elif isinstance(elem, collections.abc.Mapping):
|
378 |
+
return {key: flexible_collate([d[key] for d in batch]) for key in elem}
|
379 |
+
elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
|
380 |
+
return elem_type(*(flexible_collate(samples) for samples in zip(*batch)))
|
381 |
+
elif isinstance(elem, collections.abc.Sequence):
|
382 |
+
# check to make sure that the elements in batch have consistent size
|
383 |
+
it = iter(batch)
|
384 |
+
elem_size = len(next(it))
|
385 |
+
if not all(len(elem) == elem_size for elem in it):
|
386 |
+
raise RuntimeError('each element in list of batch should be of equal size')
|
387 |
+
transposed = zip(*batch)
|
388 |
+
return [flexible_collate(samples) for samples in transposed]
|
389 |
+
|
390 |
+
raise TypeError(default_collate_err_msg_format.format(elem_type))
|
biomap/utils_gee.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import io
|
2 |
+
import requests
|
3 |
+
import ee
|
4 |
+
import numpy as np
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
|
7 |
+
#Initialize
|
8 |
+
service_account = '[email protected]'
|
9 |
+
credentials = ee.ServiceAccountCredentials(service_account, '.private-key.json')
|
10 |
+
ee.Initialize(credentials)
|
11 |
+
|
12 |
+
#delete clouds
|
13 |
+
def maskS2clouds(image):
|
14 |
+
qa = image.select('QA60');
|
15 |
+
|
16 |
+
# // Bits 10 and 11 are clouds and cirrus, respectively.
|
17 |
+
cloudBitMask = 1 << 10;
|
18 |
+
cirrusBitMask = 1 << 11;
|
19 |
+
|
20 |
+
# // Both flags should be set to zero, indicating clear conditions.
|
21 |
+
mask = (qa.bitwiseAnd(cloudBitMask).eq(0))and(qa.bitwiseAnd(cirrusBitMask).eq(0))
|
22 |
+
|
23 |
+
return image.updateMask(mask).divide(10000);
|
24 |
+
|
25 |
+
|
26 |
+
#find ee_img
|
27 |
+
def extract_ee_img(location,start_date,end_date, width = 0.01 , len = 0.01) :
|
28 |
+
"""Extract the earth engine image
|
29 |
+
|
30 |
+
Args:
|
31 |
+
location (list[float]):
|
32 |
+
start_date (str): the start date for finding an image
|
33 |
+
end_date (str): the end date for finding an image
|
34 |
+
width (float, optional): _description_. Defaults to 0.01.
|
35 |
+
len (float, optional): _description_. Defaults to 0.01.
|
36 |
+
|
37 |
+
Returns:
|
38 |
+
_type_: _description_
|
39 |
+
"""
|
40 |
+
# define the polygone
|
41 |
+
polygone =[[[float(location[0])-0.01,float(location[1])+0.01],
|
42 |
+
[float(location[0])-0.01,float(location[1])-0.01],
|
43 |
+
[float(location[0])+0.01,float(location[1])-0.01],
|
44 |
+
[float(location[0])+0.01,float(location[1])+0.01],
|
45 |
+
]]
|
46 |
+
|
47 |
+
#define the ee geometry
|
48 |
+
geometry = ee.Geometry.Polygon(polygone, None, False);
|
49 |
+
|
50 |
+
#extract the dataset
|
51 |
+
dataset = ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED')\
|
52 |
+
.filterDate(start_date, end_date)\
|
53 |
+
.filter(ee.Filter.lt('CLOUDY_PIXEL_PERCENTAGE',1))\
|
54 |
+
.map(maskS2clouds)
|
55 |
+
return dataset.mean(), geometry
|
56 |
+
|
57 |
+
|
58 |
+
|
59 |
+
# Get URL
|
60 |
+
def get_url(ee_img, geometry, scale=5):
|
61 |
+
"""Get the url of a dataset and a geometry
|
62 |
+
|
63 |
+
Args:
|
64 |
+
ee_img (ee.ImageCollection: meta data on the image
|
65 |
+
geometry (ee.Geometry.Polygon): geometry of the desired landscape
|
66 |
+
scale (int, optional): _description_. Defaults to 5.
|
67 |
+
|
68 |
+
Returns:
|
69 |
+
str: the url to use to ask the server
|
70 |
+
"""
|
71 |
+
region = geometry
|
72 |
+
|
73 |
+
# collectionList = ee_img.toList(ee_img.size())
|
74 |
+
# collectionSize = collectionList.size().getInfo()
|
75 |
+
# for i in xrange(collectionSize):
|
76 |
+
# ee.batch.Export.image.toDrive(
|
77 |
+
# image = ee.Image(collectionList.get(i)).clip(rectangle),
|
78 |
+
# fileNamePrefix = 'foo' + str(i + 1),
|
79 |
+
# dimensions = '128x128').start()
|
80 |
+
|
81 |
+
url = ee_img.getDownloadURL({
|
82 |
+
# 'min': 0.0,
|
83 |
+
# 'max': 0.3,
|
84 |
+
'bands': ['B4', 'B3', 'B2'],
|
85 |
+
'region' : region,
|
86 |
+
'scale' : scale,
|
87 |
+
'format' : 'NPY'
|
88 |
+
})
|
89 |
+
|
90 |
+
return url
|
91 |
+
|
92 |
+
def extract_np_from_url(url):
|
93 |
+
"""extract a numpy array based on a url
|
94 |
+
|
95 |
+
Args:
|
96 |
+
url (str): _description_
|
97 |
+
|
98 |
+
Returns:
|
99 |
+
numpyarray: response from earth engine as numpy
|
100 |
+
"""
|
101 |
+
#get the response from url
|
102 |
+
response = requests.get(url)
|
103 |
+
|
104 |
+
#transform it into numpy
|
105 |
+
data = np.load(io.BytesIO(response.content))
|
106 |
+
|
107 |
+
#transform numpy of tuples to 3D numpy
|
108 |
+
temp1 = []
|
109 |
+
|
110 |
+
for x in data:
|
111 |
+
temp2 = []
|
112 |
+
for y in x :
|
113 |
+
temp2.append([z for z in y])
|
114 |
+
temp1.append(temp2)
|
115 |
+
|
116 |
+
data = np.array(temp1)
|
117 |
+
|
118 |
+
return data
|
119 |
+
|
120 |
+
#Fonction globale
|
121 |
+
def extract_img(location,start_date,end_date, width = 0.01 , len = 0.01,scale=5):
|
122 |
+
"""Extract an image of the landscape at the selected longitude and latitude with the selected width and length
|
123 |
+
|
124 |
+
Args:
|
125 |
+
location (list[float]): [latitude of the center of the landscape, longitude of the center of the landscape]
|
126 |
+
start_date (str): the start date
|
127 |
+
end_date (str): _description_
|
128 |
+
width (float, optional): _description_. Defaults to 0.01.
|
129 |
+
len (float, optional): _description_. Defaults to 0.01.
|
130 |
+
scale (int, optional): _description_. Defaults to 5.
|
131 |
+
|
132 |
+
Returns:
|
133 |
+
img: image as numpy array
|
134 |
+
"""
|
135 |
+
ee_img, geometry = extract_ee_img(location, width,start_date,end_date , len)
|
136 |
+
url = get_url(ee_img, geometry, scale)
|
137 |
+
img = extract_np_from_url(url)
|
138 |
+
|
139 |
+
return img
|
140 |
+
|
141 |
+
# transform img from numpy to PIL
|
142 |
+
def transform_ee_img(img, min = 0, max=0.3):
|
143 |
+
"""Transform an img from numpy to PIL
|
144 |
+
|
145 |
+
Args:
|
146 |
+
img (numpy array): the original image as a numpy array
|
147 |
+
min (int, optional): _description_. Defaults to 0.
|
148 |
+
max (float, optional): _description_. Defaults to 0.3.
|
149 |
+
|
150 |
+
Returns:
|
151 |
+
img_test: a PIL image
|
152 |
+
"""
|
153 |
+
img_test=img
|
154 |
+
img_test=np.minimum(img_test*255/max,np.ones(img.shape)*255)
|
155 |
+
img_test=np.uint8((np.rint(img_test)).astype(int))
|
156 |
+
plt.imshow(img_test)
|
157 |
+
return img_test
|
poetry.lock
ADDED
@@ -0,0 +1,1625 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[[package]]
|
2 |
+
name = "absl-py"
|
3 |
+
version = "1.4.0"
|
4 |
+
description = "Abseil Python Common Libraries, see https://github.com/abseil/abseil-py."
|
5 |
+
category = "main"
|
6 |
+
optional = false
|
7 |
+
python-versions = ">=3.6"
|
8 |
+
|
9 |
+
[[package]]
|
10 |
+
name = "aiofiles"
|
11 |
+
version = "23.1.0"
|
12 |
+
description = "File support for asyncio."
|
13 |
+
category = "main"
|
14 |
+
optional = false
|
15 |
+
python-versions = ">=3.7,<4.0"
|
16 |
+
|
17 |
+
[[package]]
|
18 |
+
name = "aiohttp"
|
19 |
+
version = "3.8.4"
|
20 |
+
description = "Async http client/server framework (asyncio)"
|
21 |
+
category = "main"
|
22 |
+
optional = false
|
23 |
+
python-versions = ">=3.6"
|
24 |
+
|
25 |
+
[package.dependencies]
|
26 |
+
aiosignal = ">=1.1.2"
|
27 |
+
async-timeout = ">=4.0.0a3,<5.0"
|
28 |
+
attrs = ">=17.3.0"
|
29 |
+
charset-normalizer = ">=2.0,<4.0"
|
30 |
+
frozenlist = ">=1.1.1"
|
31 |
+
multidict = ">=4.5,<7.0"
|
32 |
+
yarl = ">=1.0,<2.0"
|
33 |
+
|
34 |
+
[package.extras]
|
35 |
+
speedups = ["aiodns", "brotli", "cchardet"]
|
36 |
+
|
37 |
+
[[package]]
|
38 |
+
name = "aiosignal"
|
39 |
+
version = "1.3.1"
|
40 |
+
description = "aiosignal: a list of registered asynchronous callbacks"
|
41 |
+
category = "main"
|
42 |
+
optional = false
|
43 |
+
python-versions = ">=3.7"
|
44 |
+
|
45 |
+
[package.dependencies]
|
46 |
+
frozenlist = ">=1.1.0"
|
47 |
+
|
48 |
+
[[package]]
|
49 |
+
name = "altair"
|
50 |
+
version = "4.2.2"
|
51 |
+
description = "Altair: A declarative statistical visualization library for Python."
|
52 |
+
category = "main"
|
53 |
+
optional = false
|
54 |
+
python-versions = ">=3.7"
|
55 |
+
|
56 |
+
[package.dependencies]
|
57 |
+
entrypoints = "*"
|
58 |
+
jinja2 = "*"
|
59 |
+
jsonschema = ">=3.0"
|
60 |
+
numpy = "*"
|
61 |
+
pandas = ">=0.18"
|
62 |
+
toolz = "*"
|
63 |
+
|
64 |
+
[package.extras]
|
65 |
+
dev = ["black", "docutils", "ipython", "flake8", "pytest", "sphinx", "mistune (<2.0.0)", "m2r", "vega-datasets", "recommonmark"]
|
66 |
+
|
67 |
+
[[package]]
|
68 |
+
name = "antlr4-python3-runtime"
|
69 |
+
version = "4.9.3"
|
70 |
+
description = "ANTLR 4.9.3 runtime for Python 3.7"
|
71 |
+
category = "main"
|
72 |
+
optional = false
|
73 |
+
python-versions = "*"
|
74 |
+
|
75 |
+
[[package]]
|
76 |
+
name = "anyio"
|
77 |
+
version = "3.6.2"
|
78 |
+
description = "High level compatibility layer for multiple asynchronous event loop implementations"
|
79 |
+
category = "main"
|
80 |
+
optional = false
|
81 |
+
python-versions = ">=3.6.2"
|
82 |
+
|
83 |
+
[package.dependencies]
|
84 |
+
idna = ">=2.8"
|
85 |
+
sniffio = ">=1.1"
|
86 |
+
|
87 |
+
[package.extras]
|
88 |
+
doc = ["packaging", "sphinx-rtd-theme", "sphinx-autodoc-typehints (>=1.2.0)"]
|
89 |
+
test = ["coverage[toml] (>=4.5)", "hypothesis (>=4.0)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "contextlib2", "uvloop (<0.15)", "mock (>=4)", "uvloop (>=0.15)"]
|
90 |
+
trio = ["trio (>=0.16,<0.22)"]
|
91 |
+
|
92 |
+
[[package]]
|
93 |
+
name = "async-timeout"
|
94 |
+
version = "4.0.2"
|
95 |
+
description = "Timeout context manager for asyncio programs"
|
96 |
+
category = "main"
|
97 |
+
optional = false
|
98 |
+
python-versions = ">=3.6"
|
99 |
+
|
100 |
+
[[package]]
|
101 |
+
name = "attrs"
|
102 |
+
version = "19.3.0"
|
103 |
+
description = "Classes Without Boilerplate"
|
104 |
+
category = "main"
|
105 |
+
optional = false
|
106 |
+
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
|
107 |
+
|
108 |
+
[package.extras]
|
109 |
+
azure-pipelines = ["coverage", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "zope.interface", "pytest-azurepipelines"]
|
110 |
+
dev = ["coverage", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "zope.interface", "sphinx", "pre-commit"]
|
111 |
+
docs = ["sphinx", "zope.interface"]
|
112 |
+
tests = ["coverage", "hypothesis", "pympler", "pytest (>=4.3.0)", "six", "zope.interface"]
|
113 |
+
|
114 |
+
[[package]]
|
115 |
+
name = "cachetools"
|
116 |
+
version = "5.3.0"
|
117 |
+
description = "Extensible memoizing collections and decorators"
|
118 |
+
category = "main"
|
119 |
+
optional = false
|
120 |
+
python-versions = "~=3.7"
|
121 |
+
|
122 |
+
[[package]]
|
123 |
+
name = "certifi"
|
124 |
+
version = "2022.12.7"
|
125 |
+
description = "Python package for providing Mozilla's CA Bundle."
|
126 |
+
category = "main"
|
127 |
+
optional = false
|
128 |
+
python-versions = ">=3.6"
|
129 |
+
|
130 |
+
[[package]]
|
131 |
+
name = "charset-normalizer"
|
132 |
+
version = "3.1.0"
|
133 |
+
description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet."
|
134 |
+
category = "main"
|
135 |
+
optional = false
|
136 |
+
python-versions = ">=3.7.0"
|
137 |
+
|
138 |
+
[[package]]
|
139 |
+
name = "click"
|
140 |
+
version = "8.1.3"
|
141 |
+
description = "Composable command line interface toolkit"
|
142 |
+
category = "main"
|
143 |
+
optional = false
|
144 |
+
python-versions = ">=3.7"
|
145 |
+
|
146 |
+
[package.dependencies]
|
147 |
+
colorama = {version = "*", markers = "platform_system == \"Windows\""}
|
148 |
+
|
149 |
+
[[package]]
|
150 |
+
name = "colorama"
|
151 |
+
version = "0.4.6"
|
152 |
+
description = "Cross-platform colored terminal text."
|
153 |
+
category = "main"
|
154 |
+
optional = false
|
155 |
+
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7"
|
156 |
+
|
157 |
+
[[package]]
|
158 |
+
name = "contourpy"
|
159 |
+
version = "1.0.7"
|
160 |
+
description = "Python library for calculating contours of 2D quadrilateral grids"
|
161 |
+
category = "main"
|
162 |
+
optional = false
|
163 |
+
python-versions = ">=3.8"
|
164 |
+
|
165 |
+
[package.dependencies]
|
166 |
+
numpy = ">=1.16"
|
167 |
+
|
168 |
+
[package.extras]
|
169 |
+
bokeh = ["bokeh", "chromedriver", "selenium"]
|
170 |
+
docs = ["furo", "sphinx-copybutton"]
|
171 |
+
mypy = ["contourpy", "docutils-stubs", "mypy (==0.991)", "types-pillow"]
|
172 |
+
test = ["matplotlib", "pillow", "pytest"]
|
173 |
+
test-no-images = ["pytest"]
|
174 |
+
|
175 |
+
[[package]]
|
176 |
+
name = "cycler"
|
177 |
+
version = "0.11.0"
|
178 |
+
description = "Composable style cycles"
|
179 |
+
category = "main"
|
180 |
+
optional = false
|
181 |
+
python-versions = ">=3.6"
|
182 |
+
|
183 |
+
[[package]]
|
184 |
+
name = "earthengine-api"
|
185 |
+
version = "0.1.338"
|
186 |
+
description = "Earth Engine Python API"
|
187 |
+
category = "main"
|
188 |
+
optional = false
|
189 |
+
python-versions = "*"
|
190 |
+
|
191 |
+
[package.dependencies]
|
192 |
+
google-api-python-client = ">=1.12.1"
|
193 |
+
google-auth = ">=1.4.1"
|
194 |
+
google-auth-httplib2 = ">=0.0.3"
|
195 |
+
google-cloud-storage = "*"
|
196 |
+
httplib2 = ">=0.9.2,<1dev"
|
197 |
+
requests = "*"
|
198 |
+
|
199 |
+
[[package]]
|
200 |
+
name = "ee-extra"
|
201 |
+
version = "0.0.15"
|
202 |
+
description = "A ninja Python package behind rgee, rgeeExtra and eemont."
|
203 |
+
category = "main"
|
204 |
+
optional = false
|
205 |
+
python-versions = "*"
|
206 |
+
|
207 |
+
[package.dependencies]
|
208 |
+
earthengine-api = "*"
|
209 |
+
|
210 |
+
[[package]]
|
211 |
+
name = "entrypoints"
|
212 |
+
version = "0.4"
|
213 |
+
description = "Discover and load entry points from installed packages."
|
214 |
+
category = "main"
|
215 |
+
optional = false
|
216 |
+
python-versions = ">=3.6"
|
217 |
+
|
218 |
+
[[package]]
|
219 |
+
name = "fastapi"
|
220 |
+
version = "0.95.1"
|
221 |
+
description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production"
|
222 |
+
category = "main"
|
223 |
+
optional = false
|
224 |
+
python-versions = ">=3.7"
|
225 |
+
|
226 |
+
[package.dependencies]
|
227 |
+
pydantic = ">=1.6.2,<1.7 || >1.7,<1.7.1 || >1.7.1,<1.7.2 || >1.7.2,<1.7.3 || >1.7.3,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0"
|
228 |
+
starlette = ">=0.26.1,<0.27.0"
|
229 |
+
|
230 |
+
[package.extras]
|
231 |
+
all = ["email-validator (>=1.1.1)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "python-multipart (>=0.0.5)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"]
|
232 |
+
dev = ["pre-commit (>=2.17.0,<3.0.0)", "ruff (==0.0.138)", "uvicorn[standard] (>=0.12.0,<0.21.0)"]
|
233 |
+
doc = ["mdx-include (>=1.4.1,<2.0.0)", "mkdocs-markdownextradata-plugin (>=0.1.7,<0.3.0)", "mkdocs-material (>=8.1.4,<9.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "pyyaml (>=5.3.1,<7.0.0)", "typer-cli (>=0.0.13,<0.0.14)", "typer[all] (>=0.6.1,<0.8.0)"]
|
234 |
+
test = ["anyio[trio] (>=3.2.1,<4.0.0)", "black (==23.1.0)", "coverage[toml] (>=6.5.0,<8.0)", "databases[sqlite] (>=0.3.2,<0.7.0)", "email-validator (>=1.1.1,<2.0.0)", "flask (>=1.1.2,<3.0.0)", "httpx (>=0.23.0,<0.24.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.982)", "orjson (>=3.2.1,<4.0.0)", "passlib[bcrypt] (>=1.7.2,<2.0.0)", "peewee (>=3.13.3,<4.0.0)", "pytest (>=7.1.3,<8.0.0)", "python-jose[cryptography] (>=3.3.0,<4.0.0)", "python-multipart (>=0.0.5,<0.0.7)", "pyyaml (>=5.3.1,<7.0.0)", "ruff (==0.0.138)", "sqlalchemy (>=1.3.18,<1.4.43)", "types-orjson (==3.6.2)", "types-ujson (==5.7.0.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0,<6.0.0)"]
|
235 |
+
|
236 |
+
[[package]]
|
237 |
+
name = "ffmpy"
|
238 |
+
version = "0.3.0"
|
239 |
+
description = "A simple Python wrapper for ffmpeg"
|
240 |
+
category = "main"
|
241 |
+
optional = false
|
242 |
+
python-versions = "*"
|
243 |
+
|
244 |
+
[[package]]
|
245 |
+
name = "filelock"
|
246 |
+
version = "3.11.0"
|
247 |
+
description = "A platform independent file lock."
|
248 |
+
category = "main"
|
249 |
+
optional = false
|
250 |
+
python-versions = ">=3.7"
|
251 |
+
|
252 |
+
[package.extras]
|
253 |
+
docs = ["furo (>=2023.3.27)", "sphinx-autodoc-typehints (>=1.22,!=1.23.4)", "sphinx (>=6.1.3)"]
|
254 |
+
testing = ["covdefaults (>=2.3)", "coverage (>=7.2.2)", "diff-cover (>=7.5)", "pytest-cov (>=4)", "pytest-mock (>=3.10)", "pytest-timeout (>=2.1)", "pytest (>=7.2.2)"]
|
255 |
+
|
256 |
+
[[package]]
|
257 |
+
name = "fonttools"
|
258 |
+
version = "4.39.3"
|
259 |
+
description = "Tools to manipulate font files"
|
260 |
+
category = "main"
|
261 |
+
optional = false
|
262 |
+
python-versions = ">=3.8"
|
263 |
+
|
264 |
+
[package.extras]
|
265 |
+
all = ["fs (>=2.2.0,<3)", "lxml (>=4.0,<5)", "zopfli (>=0.1.4)", "lz4 (>=1.7.4.2)", "matplotlib", "sympy", "skia-pathops (>=0.5.0)", "uharfbuzz (>=0.23.0)", "brotlicffi (>=0.8.0)", "scipy", "brotli (>=1.0.1)", "munkres", "unicodedata2 (>=15.0.0)", "xattr"]
|
266 |
+
graphite = ["lz4 (>=1.7.4.2)"]
|
267 |
+
interpolatable = ["scipy", "munkres"]
|
268 |
+
lxml = ["lxml (>=4.0,<5)"]
|
269 |
+
pathops = ["skia-pathops (>=0.5.0)"]
|
270 |
+
plot = ["matplotlib"]
|
271 |
+
repacker = ["uharfbuzz (>=0.23.0)"]
|
272 |
+
symfont = ["sympy"]
|
273 |
+
type1 = ["xattr"]
|
274 |
+
ufo = ["fs (>=2.2.0,<3)"]
|
275 |
+
unicode = ["unicodedata2 (>=15.0.0)"]
|
276 |
+
woff = ["zopfli (>=0.1.4)", "brotlicffi (>=0.8.0)", "brotli (>=1.0.1)"]
|
277 |
+
|
278 |
+
[[package]]
|
279 |
+
name = "frozenlist"
|
280 |
+
version = "1.3.3"
|
281 |
+
description = "A list-like structure which implements collections.abc.MutableSequence"
|
282 |
+
category = "main"
|
283 |
+
optional = false
|
284 |
+
python-versions = ">=3.7"
|
285 |
+
|
286 |
+
[[package]]
|
287 |
+
name = "fsspec"
|
288 |
+
version = "2023.4.0"
|
289 |
+
description = "File-system specification"
|
290 |
+
category = "main"
|
291 |
+
optional = false
|
292 |
+
python-versions = ">=3.8"
|
293 |
+
|
294 |
+
[package.dependencies]
|
295 |
+
aiohttp = {version = "<4.0.0a0 || >4.0.0a0,<4.0.0a1 || >4.0.0a1", optional = true, markers = "extra == \"http\""}
|
296 |
+
requests = {version = "*", optional = true, markers = "extra == \"http\""}
|
297 |
+
|
298 |
+
[package.extras]
|
299 |
+
abfs = ["adlfs"]
|
300 |
+
adl = ["adlfs"]
|
301 |
+
arrow = ["pyarrow (>=1)"]
|
302 |
+
dask = ["dask", "distributed"]
|
303 |
+
devel = ["pytest", "pytest-cov"]
|
304 |
+
dropbox = ["dropboxdrivefs", "requests", "dropbox"]
|
305 |
+
full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "dask", "distributed", "dropbox", "dropboxdrivefs", "fusepy", "gcsfs", "libarchive-c", "ocifs", "panel", "paramiko", "pyarrow (>=1)", "pygit2", "requests", "s3fs", "smbprotocol", "tqdm"]
|
306 |
+
fuse = ["fusepy"]
|
307 |
+
gcs = ["gcsfs"]
|
308 |
+
git = ["pygit2"]
|
309 |
+
github = ["requests"]
|
310 |
+
gs = ["gcsfs"]
|
311 |
+
gui = ["panel"]
|
312 |
+
hdfs = ["pyarrow (>=1)"]
|
313 |
+
http = ["requests", "aiohttp (!=4.0.0a0,!=4.0.0a1)"]
|
314 |
+
libarchive = ["libarchive-c"]
|
315 |
+
oci = ["ocifs"]
|
316 |
+
s3 = ["s3fs"]
|
317 |
+
sftp = ["paramiko"]
|
318 |
+
smb = ["smbprotocol"]
|
319 |
+
ssh = ["paramiko"]
|
320 |
+
tqdm = ["tqdm"]
|
321 |
+
|
322 |
+
[[package]]
|
323 |
+
name = "google-api-core"
|
324 |
+
version = "2.11.0"
|
325 |
+
description = "Google API client core library"
|
326 |
+
category = "main"
|
327 |
+
optional = false
|
328 |
+
python-versions = ">=3.7"
|
329 |
+
|
330 |
+
[package.dependencies]
|
331 |
+
google-auth = ">=2.14.1,<3.0dev"
|
332 |
+
googleapis-common-protos = ">=1.56.2,<2.0dev"
|
333 |
+
protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev"
|
334 |
+
requests = ">=2.18.0,<3.0.0dev"
|
335 |
+
|
336 |
+
[package.extras]
|
337 |
+
grpc = ["grpcio (>=1.33.2,<2.0dev)", "grpcio-status (>=1.33.2,<2.0dev)", "grpcio (>=1.49.1,<2.0dev)", "grpcio-status (>=1.49.1,<2.0dev)"]
|
338 |
+
grpcgcp = ["grpcio-gcp (>=0.2.2,<1.0dev)"]
|
339 |
+
grpcio-gcp = ["grpcio-gcp (>=0.2.2,<1.0dev)"]
|
340 |
+
|
341 |
+
[[package]]
|
342 |
+
name = "google-api-python-client"
|
343 |
+
version = "2.85.0"
|
344 |
+
description = "Google API Client Library for Python"
|
345 |
+
category = "main"
|
346 |
+
optional = false
|
347 |
+
python-versions = ">=3.7"
|
348 |
+
|
349 |
+
[package.dependencies]
|
350 |
+
google-api-core = ">=1.31.5,<2.0.0 || >2.3.0,<3.0.0dev"
|
351 |
+
google-auth = ">=1.19.0,<3.0.0dev"
|
352 |
+
google-auth-httplib2 = ">=0.1.0"
|
353 |
+
httplib2 = ">=0.15.0,<1dev"
|
354 |
+
uritemplate = ">=3.0.1,<5"
|
355 |
+
|
356 |
+
[[package]]
|
357 |
+
name = "google-auth"
|
358 |
+
version = "2.17.3"
|
359 |
+
description = "Google Authentication Library"
|
360 |
+
category = "main"
|
361 |
+
optional = false
|
362 |
+
python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*"
|
363 |
+
|
364 |
+
[package.dependencies]
|
365 |
+
cachetools = ">=2.0.0,<6.0"
|
366 |
+
pyasn1-modules = ">=0.2.1"
|
367 |
+
rsa = {version = ">=3.1.4,<5", markers = "python_version >= \"3.6\""}
|
368 |
+
six = ">=1.9.0"
|
369 |
+
|
370 |
+
[package.extras]
|
371 |
+
aiohttp = ["requests (>=2.20.0,<3.0.0dev)", "aiohttp (>=3.6.2,<4.0.0dev)"]
|
372 |
+
enterprise_cert = ["cryptography (==36.0.2)", "pyopenssl (==22.0.0)"]
|
373 |
+
pyopenssl = ["pyopenssl (>=20.0.0)", "cryptography (>=38.0.3)"]
|
374 |
+
reauth = ["pyu2f (>=0.1.5)"]
|
375 |
+
requests = ["requests (>=2.20.0,<3.0.0dev)"]
|
376 |
+
|
377 |
+
[[package]]
|
378 |
+
name = "google-auth-httplib2"
|
379 |
+
version = "0.1.0"
|
380 |
+
description = "Google Authentication Library: httplib2 transport"
|
381 |
+
category = "main"
|
382 |
+
optional = false
|
383 |
+
python-versions = "*"
|
384 |
+
|
385 |
+
[package.dependencies]
|
386 |
+
google-auth = "*"
|
387 |
+
httplib2 = ">=0.15.0"
|
388 |
+
six = "*"
|
389 |
+
|
390 |
+
[[package]]
|
391 |
+
name = "google-auth-oauthlib"
|
392 |
+
version = "0.4.6"
|
393 |
+
description = "Google Authentication Library"
|
394 |
+
category = "main"
|
395 |
+
optional = false
|
396 |
+
python-versions = ">=3.6"
|
397 |
+
|
398 |
+
[package.dependencies]
|
399 |
+
google-auth = ">=1.0.0"
|
400 |
+
requests-oauthlib = ">=0.7.0"
|
401 |
+
|
402 |
+
[package.extras]
|
403 |
+
tool = ["click (>=6.0.0)"]
|
404 |
+
|
405 |
+
[[package]]
|
406 |
+
name = "google-cloud-core"
|
407 |
+
version = "2.3.2"
|
408 |
+
description = "Google Cloud API client core library"
|
409 |
+
category = "main"
|
410 |
+
optional = false
|
411 |
+
python-versions = ">=3.7"
|
412 |
+
|
413 |
+
[package.dependencies]
|
414 |
+
google-api-core = ">=1.31.6,<2.0.0 || >2.3.0,<3.0.0dev"
|
415 |
+
google-auth = ">=1.25.0,<3.0dev"
|
416 |
+
|
417 |
+
[package.extras]
|
418 |
+
grpc = ["grpcio (>=1.38.0,<2.0dev)"]
|
419 |
+
|
420 |
+
[[package]]
|
421 |
+
name = "google-cloud-storage"
|
422 |
+
version = "2.8.0"
|
423 |
+
description = "Google Cloud Storage API client library"
|
424 |
+
category = "main"
|
425 |
+
optional = false
|
426 |
+
python-versions = ">=3.7"
|
427 |
+
|
428 |
+
[package.dependencies]
|
429 |
+
google-api-core = ">=1.31.5,<2.0.0 || >2.3.0,<3.0.0dev"
|
430 |
+
google-auth = ">=1.25.0,<3.0dev"
|
431 |
+
google-cloud-core = ">=2.3.0,<3.0dev"
|
432 |
+
google-resumable-media = ">=2.3.2"
|
433 |
+
requests = ">=2.18.0,<3.0.0dev"
|
434 |
+
|
435 |
+
[package.extras]
|
436 |
+
protobuf = ["protobuf (<5.0.0dev)"]
|
437 |
+
|
438 |
+
[[package]]
|
439 |
+
name = "google-crc32c"
|
440 |
+
version = "1.5.0"
|
441 |
+
description = "A python wrapper of the C library 'Google CRC32C'"
|
442 |
+
category = "main"
|
443 |
+
optional = false
|
444 |
+
python-versions = ">=3.7"
|
445 |
+
|
446 |
+
[package.extras]
|
447 |
+
testing = ["pytest"]
|
448 |
+
|
449 |
+
[[package]]
|
450 |
+
name = "google-resumable-media"
|
451 |
+
version = "2.4.1"
|
452 |
+
description = "Utilities for Google Media Downloads and Resumable Uploads"
|
453 |
+
category = "main"
|
454 |
+
optional = false
|
455 |
+
python-versions = ">= 3.7"
|
456 |
+
|
457 |
+
[package.dependencies]
|
458 |
+
google-crc32c = ">=1.0,<2.0dev"
|
459 |
+
|
460 |
+
[package.extras]
|
461 |
+
aiohttp = ["aiohttp (>=3.6.2,<4.0.0dev)"]
|
462 |
+
requests = ["requests (>=2.18.0,<3.0.0dev)"]
|
463 |
+
|
464 |
+
[[package]]
|
465 |
+
name = "googleapis-common-protos"
|
466 |
+
version = "1.59.0"
|
467 |
+
description = "Common protobufs used in Google APIs"
|
468 |
+
category = "main"
|
469 |
+
optional = false
|
470 |
+
python-versions = ">=3.7"
|
471 |
+
|
472 |
+
[package.dependencies]
|
473 |
+
protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0dev"
|
474 |
+
|
475 |
+
[package.extras]
|
476 |
+
grpc = ["grpcio (>=1.44.0,<2.0.0dev)"]
|
477 |
+
|
478 |
+
[[package]]
|
479 |
+
name = "gradio"
|
480 |
+
version = "3.27.0"
|
481 |
+
description = "Python library for easily interacting with trained machine learning models"
|
482 |
+
category = "main"
|
483 |
+
optional = false
|
484 |
+
python-versions = ">=3.7"
|
485 |
+
|
486 |
+
[package.dependencies]
|
487 |
+
aiofiles = "*"
|
488 |
+
aiohttp = "*"
|
489 |
+
altair = ">=4.2.0"
|
490 |
+
fastapi = "*"
|
491 |
+
ffmpy = "*"
|
492 |
+
gradio-client = ">=0.1.3"
|
493 |
+
httpx = "*"
|
494 |
+
huggingface-hub = ">=0.13.0"
|
495 |
+
jinja2 = "*"
|
496 |
+
markdown-it-py = {version = ">=2.0.0", extras = ["linkify"]}
|
497 |
+
markupsafe = "*"
|
498 |
+
matplotlib = "*"
|
499 |
+
mdit-py-plugins = "<=0.3.3"
|
500 |
+
numpy = "*"
|
501 |
+
orjson = "*"
|
502 |
+
pandas = "*"
|
503 |
+
pillow = "*"
|
504 |
+
pydantic = "*"
|
505 |
+
pydub = "*"
|
506 |
+
python-multipart = "*"
|
507 |
+
pyyaml = "*"
|
508 |
+
requests = "*"
|
509 |
+
semantic-version = "*"
|
510 |
+
typing-extensions = "*"
|
511 |
+
uvicorn = "*"
|
512 |
+
websockets = ">=10.0"
|
513 |
+
|
514 |
+
[[package]]
|
515 |
+
name = "gradio-client"
|
516 |
+
version = "0.1.3"
|
517 |
+
description = "Python library for easily interacting with trained machine learning models"
|
518 |
+
category = "main"
|
519 |
+
optional = false
|
520 |
+
python-versions = ">=3.7"
|
521 |
+
|
522 |
+
[package.dependencies]
|
523 |
+
fsspec = "*"
|
524 |
+
httpx = "*"
|
525 |
+
huggingface-hub = ">=0.13.0"
|
526 |
+
packaging = "*"
|
527 |
+
requests = "*"
|
528 |
+
typing-extensions = "*"
|
529 |
+
websockets = "*"
|
530 |
+
|
531 |
+
[[package]]
|
532 |
+
name = "grpcio"
|
533 |
+
version = "1.53.0"
|
534 |
+
description = "HTTP/2-based RPC framework"
|
535 |
+
category = "main"
|
536 |
+
optional = false
|
537 |
+
python-versions = ">=3.7"
|
538 |
+
|
539 |
+
[package.extras]
|
540 |
+
protobuf = ["grpcio-tools (>=1.53.0)"]
|
541 |
+
|
542 |
+
[[package]]
|
543 |
+
name = "h11"
|
544 |
+
version = "0.14.0"
|
545 |
+
description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1"
|
546 |
+
category = "main"
|
547 |
+
optional = false
|
548 |
+
python-versions = ">=3.7"
|
549 |
+
|
550 |
+
[[package]]
|
551 |
+
name = "httpcore"
|
552 |
+
version = "0.17.0"
|
553 |
+
description = "A minimal low-level HTTP client."
|
554 |
+
category = "main"
|
555 |
+
optional = false
|
556 |
+
python-versions = ">=3.7"
|
557 |
+
|
558 |
+
[package.dependencies]
|
559 |
+
anyio = ">=3.0,<5.0"
|
560 |
+
certifi = "*"
|
561 |
+
h11 = ">=0.13,<0.15"
|
562 |
+
sniffio = ">=1.0.0,<2.0.0"
|
563 |
+
|
564 |
+
[package.extras]
|
565 |
+
http2 = ["h2 (>=3,<5)"]
|
566 |
+
socks = ["socksio (>=1.0.0,<2.0.0)"]
|
567 |
+
|
568 |
+
[[package]]
|
569 |
+
name = "httplib2"
|
570 |
+
version = "0.22.0"
|
571 |
+
description = "A comprehensive HTTP client library."
|
572 |
+
category = "main"
|
573 |
+
optional = false
|
574 |
+
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
|
575 |
+
|
576 |
+
[package.dependencies]
|
577 |
+
pyparsing = {version = ">=2.4.2,<3.0.0 || >3.0.0,<3.0.1 || >3.0.1,<3.0.2 || >3.0.2,<3.0.3 || >3.0.3,<4", markers = "python_version > \"3.0\""}
|
578 |
+
|
579 |
+
[[package]]
|
580 |
+
name = "httpx"
|
581 |
+
version = "0.24.0"
|
582 |
+
description = "The next generation HTTP client."
|
583 |
+
category = "main"
|
584 |
+
optional = false
|
585 |
+
python-versions = ">=3.7"
|
586 |
+
|
587 |
+
[package.dependencies]
|
588 |
+
certifi = "*"
|
589 |
+
httpcore = ">=0.15.0,<0.18.0"
|
590 |
+
idna = "*"
|
591 |
+
sniffio = "*"
|
592 |
+
|
593 |
+
[package.extras]
|
594 |
+
brotli = ["brotli", "brotlicffi"]
|
595 |
+
cli = ["click (>=8.0.0,<9.0.0)", "pygments (>=2.0.0,<3.0.0)", "rich (>=10,<14)"]
|
596 |
+
http2 = ["h2 (>=3,<5)"]
|
597 |
+
socks = ["socksio (>=1.0.0,<2.0.0)"]
|
598 |
+
|
599 |
+
[[package]]
|
600 |
+
name = "huggingface-hub"
|
601 |
+
version = "0.13.4"
|
602 |
+
description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
|
603 |
+
category = "main"
|
604 |
+
optional = false
|
605 |
+
python-versions = ">=3.7.0"
|
606 |
+
|
607 |
+
[package.dependencies]
|
608 |
+
filelock = "*"
|
609 |
+
packaging = ">=20.9"
|
610 |
+
pyyaml = ">=5.1"
|
611 |
+
requests = "*"
|
612 |
+
tqdm = ">=4.42.1"
|
613 |
+
typing-extensions = ">=3.7.4.3"
|
614 |
+
|
615 |
+
[package.extras]
|
616 |
+
all = ["InquirerPy (==0.3.4)", "jedi", "jinja2", "pytest", "pytest-cov", "pytest-env", "pytest-xdist", "soundfile", "pillow", "black (>=23.1,<24.0)", "ruff (>=0.0.241)", "mypy (==0.982)", "types-pyyaml", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"]
|
617 |
+
cli = ["InquirerPy (==0.3.4)"]
|
618 |
+
dev = ["InquirerPy (==0.3.4)", "jedi", "jinja2", "pytest", "pytest-cov", "pytest-env", "pytest-xdist", "soundfile", "pillow", "black (>=23.1,<24.0)", "ruff (>=0.0.241)", "mypy (==0.982)", "types-pyyaml", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"]
|
619 |
+
fastai = ["toml", "fastai (>=2.4)", "fastcore (>=1.3.27)"]
|
620 |
+
quality = ["black (>=23.1,<24.0)", "ruff (>=0.0.241)", "mypy (==0.982)"]
|
621 |
+
tensorflow = ["tensorflow", "pydot", "graphviz"]
|
622 |
+
testing = ["InquirerPy (==0.3.4)", "jedi", "jinja2", "pytest", "pytest-cov", "pytest-env", "pytest-xdist", "soundfile", "pillow"]
|
623 |
+
torch = ["torch"]
|
624 |
+
typing = ["types-pyyaml", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"]
|
625 |
+
|
626 |
+
[[package]]
|
627 |
+
name = "hydra-client"
|
628 |
+
version = "0.5.1"
|
629 |
+
description = "Client library for ORY Hydra (OAuth2 and OpenID Connect provider)"
|
630 |
+
category = "main"
|
631 |
+
optional = false
|
632 |
+
python-versions = ">=3.7,<4.0"
|
633 |
+
|
634 |
+
[package.dependencies]
|
635 |
+
attrs = ">=19.2,<20.0"
|
636 |
+
python-dateutil = ">=2.8,<3.0"
|
637 |
+
requests = ">=2.21,<3.0"
|
638 |
+
requests-oauthlib = ">=1.0,<2.0"
|
639 |
+
|
640 |
+
[[package]]
|
641 |
+
name = "hydra-core"
|
642 |
+
version = "1.3.1"
|
643 |
+
description = "A framework for elegantly configuring complex applications"
|
644 |
+
category = "main"
|
645 |
+
optional = false
|
646 |
+
python-versions = "*"
|
647 |
+
|
648 |
+
[package.dependencies]
|
649 |
+
antlr4-python3-runtime = ">=4.9.0,<4.10.0"
|
650 |
+
omegaconf = ">=2.2,<2.4"
|
651 |
+
packaging = "*"
|
652 |
+
|
653 |
+
[[package]]
|
654 |
+
name = "idna"
|
655 |
+
version = "3.4"
|
656 |
+
description = "Internationalized Domain Names in Applications (IDNA)"
|
657 |
+
category = "main"
|
658 |
+
optional = false
|
659 |
+
python-versions = ">=3.5"
|
660 |
+
|
661 |
+
[[package]]
|
662 |
+
name = "jinja2"
|
663 |
+
version = "3.1.2"
|
664 |
+
description = "A very fast and expressive template engine."
|
665 |
+
category = "main"
|
666 |
+
optional = false
|
667 |
+
python-versions = ">=3.7"
|
668 |
+
|
669 |
+
[package.dependencies]
|
670 |
+
MarkupSafe = ">=2.0"
|
671 |
+
|
672 |
+
[package.extras]
|
673 |
+
i18n = ["Babel (>=2.7)"]
|
674 |
+
|
675 |
+
[[package]]
|
676 |
+
name = "jsonschema"
|
677 |
+
version = "4.17.3"
|
678 |
+
description = "An implementation of JSON Schema validation for Python"
|
679 |
+
category = "main"
|
680 |
+
optional = false
|
681 |
+
python-versions = ">=3.7"
|
682 |
+
|
683 |
+
[package.dependencies]
|
684 |
+
attrs = ">=17.4.0"
|
685 |
+
pyrsistent = ">=0.14.0,<0.17.0 || >0.17.0,<0.17.1 || >0.17.1,<0.17.2 || >0.17.2"
|
686 |
+
|
687 |
+
[package.extras]
|
688 |
+
format = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3987", "uri-template", "webcolors (>=1.11)"]
|
689 |
+
format-nongpl = ["fqdn", "idna", "isoduration", "jsonpointer (>1.13)", "rfc3339-validator", "rfc3986-validator (>0.1.0)", "uri-template", "webcolors (>=1.11)"]
|
690 |
+
|
691 |
+
[[package]]
|
692 |
+
name = "kiwisolver"
|
693 |
+
version = "1.4.4"
|
694 |
+
description = "A fast implementation of the Cassowary constraint solver"
|
695 |
+
category = "main"
|
696 |
+
optional = false
|
697 |
+
python-versions = ">=3.7"
|
698 |
+
|
699 |
+
[[package]]
|
700 |
+
name = "lightning-utilities"
|
701 |
+
version = "0.8.0"
|
702 |
+
description = "PyTorch Lightning Sample project."
|
703 |
+
category = "main"
|
704 |
+
optional = false
|
705 |
+
python-versions = ">=3.7"
|
706 |
+
|
707 |
+
[package.dependencies]
|
708 |
+
packaging = ">=17.1"
|
709 |
+
typing-extensions = "*"
|
710 |
+
|
711 |
+
[package.extras]
|
712 |
+
cli = ["fire"]
|
713 |
+
docs = ["sphinx (>=4.0,<5.0)"]
|
714 |
+
test = ["coverage (==6.5.0)"]
|
715 |
+
typing = ["mypy (>=1.0.0)"]
|
716 |
+
|
717 |
+
[[package]]
|
718 |
+
name = "linkify-it-py"
|
719 |
+
version = "2.0.0"
|
720 |
+
description = "Links recognition library with FULL unicode support."
|
721 |
+
category = "main"
|
722 |
+
optional = false
|
723 |
+
python-versions = ">=3.6"
|
724 |
+
|
725 |
+
[package.dependencies]
|
726 |
+
uc-micro-py = "*"
|
727 |
+
|
728 |
+
[package.extras]
|
729 |
+
benchmark = ["pytest", "pytest-benchmark"]
|
730 |
+
dev = ["pre-commit", "isort", "flake8", "black"]
|
731 |
+
doc = ["sphinx", "sphinx-book-theme", "myst-parser"]
|
732 |
+
test = ["coverage", "pytest", "pytest-cov"]
|
733 |
+
|
734 |
+
[[package]]
|
735 |
+
name = "markdown"
|
736 |
+
version = "3.4.3"
|
737 |
+
description = "Python implementation of John Gruber's Markdown."
|
738 |
+
category = "main"
|
739 |
+
optional = false
|
740 |
+
python-versions = ">=3.7"
|
741 |
+
|
742 |
+
[package.extras]
|
743 |
+
testing = ["coverage", "pyyaml"]
|
744 |
+
|
745 |
+
[[package]]
|
746 |
+
name = "markdown-it-py"
|
747 |
+
version = "2.2.0"
|
748 |
+
description = "Python port of markdown-it. Markdown parsing, done right!"
|
749 |
+
category = "main"
|
750 |
+
optional = false
|
751 |
+
python-versions = ">=3.7"
|
752 |
+
|
753 |
+
[package.dependencies]
|
754 |
+
linkify-it-py = {version = ">=1,<3", optional = true, markers = "extra == \"linkify\""}
|
755 |
+
mdurl = ">=0.1,<1.0"
|
756 |
+
|
757 |
+
[package.extras]
|
758 |
+
benchmarking = ["psutil", "pytest", "pytest-benchmark"]
|
759 |
+
code_style = ["pre-commit (>=3.0,<4.0)"]
|
760 |
+
compare = ["commonmark (>=0.9,<1.0)", "markdown (>=3.4,<4.0)", "mistletoe (>=1.0,<2.0)", "mistune (>=2.0,<3.0)", "panflute (>=2.3,<3.0)"]
|
761 |
+
linkify = ["linkify-it-py (>=1,<3)"]
|
762 |
+
plugins = ["mdit-py-plugins"]
|
763 |
+
profiling = ["gprof2dot"]
|
764 |
+
rtd = ["attrs", "myst-parser", "pyyaml", "sphinx", "sphinx-copybutton", "sphinx-design", "sphinx-book-theme"]
|
765 |
+
testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"]
|
766 |
+
|
767 |
+
[[package]]
|
768 |
+
name = "markupsafe"
|
769 |
+
version = "2.1.2"
|
770 |
+
description = "Safely add untrusted strings to HTML/XML markup."
|
771 |
+
category = "main"
|
772 |
+
optional = false
|
773 |
+
python-versions = ">=3.7"
|
774 |
+
|
775 |
+
[[package]]
|
776 |
+
name = "matplotlib"
|
777 |
+
version = "3.7.1"
|
778 |
+
description = "Python plotting package"
|
779 |
+
category = "main"
|
780 |
+
optional = false
|
781 |
+
python-versions = ">=3.8"
|
782 |
+
|
783 |
+
[package.dependencies]
|
784 |
+
contourpy = ">=1.0.1"
|
785 |
+
cycler = ">=0.10"
|
786 |
+
fonttools = ">=4.22.0"
|
787 |
+
kiwisolver = ">=1.0.1"
|
788 |
+
numpy = ">=1.20"
|
789 |
+
packaging = ">=20.0"
|
790 |
+
pillow = ">=6.2.0"
|
791 |
+
pyparsing = ">=2.3.1"
|
792 |
+
python-dateutil = ">=2.7"
|
793 |
+
setuptools_scm = ">=7"
|
794 |
+
|
795 |
+
[[package]]
|
796 |
+
name = "mdit-py-plugins"
|
797 |
+
version = "0.3.3"
|
798 |
+
description = "Collection of plugins for markdown-it-py"
|
799 |
+
category = "main"
|
800 |
+
optional = false
|
801 |
+
python-versions = ">=3.7"
|
802 |
+
|
803 |
+
[package.dependencies]
|
804 |
+
markdown-it-py = ">=1.0.0,<3.0.0"
|
805 |
+
|
806 |
+
[package.extras]
|
807 |
+
code_style = ["pre-commit"]
|
808 |
+
rtd = ["attrs", "myst-parser (>=0.16.1,<0.17.0)", "sphinx-book-theme (>=0.1.0,<0.2.0)"]
|
809 |
+
testing = ["coverage", "pytest", "pytest-cov", "pytest-regressions"]
|
810 |
+
|
811 |
+
[[package]]
|
812 |
+
name = "mdurl"
|
813 |
+
version = "0.1.2"
|
814 |
+
description = "Markdown URL utilities"
|
815 |
+
category = "main"
|
816 |
+
optional = false
|
817 |
+
python-versions = ">=3.7"
|
818 |
+
|
819 |
+
[[package]]
|
820 |
+
name = "multidict"
|
821 |
+
version = "6.0.4"
|
822 |
+
description = "multidict implementation"
|
823 |
+
category = "main"
|
824 |
+
optional = false
|
825 |
+
python-versions = ">=3.7"
|
826 |
+
|
827 |
+
[[package]]
|
828 |
+
name = "numpy"
|
829 |
+
version = "1.24.2"
|
830 |
+
description = "Fundamental package for array computing in Python"
|
831 |
+
category = "main"
|
832 |
+
optional = false
|
833 |
+
python-versions = ">=3.8"
|
834 |
+
|
835 |
+
[[package]]
|
836 |
+
name = "nvidia-cublas-cu11"
|
837 |
+
version = "11.10.3.66"
|
838 |
+
description = "CUBLAS native runtime libraries"
|
839 |
+
category = "main"
|
840 |
+
optional = false
|
841 |
+
python-versions = ">=3"
|
842 |
+
|
843 |
+
[[package]]
|
844 |
+
name = "nvidia-cuda-nvrtc-cu11"
|
845 |
+
version = "11.7.99"
|
846 |
+
description = "NVRTC native runtime libraries"
|
847 |
+
category = "main"
|
848 |
+
optional = false
|
849 |
+
python-versions = ">=3"
|
850 |
+
|
851 |
+
[[package]]
|
852 |
+
name = "nvidia-cuda-runtime-cu11"
|
853 |
+
version = "11.7.99"
|
854 |
+
description = "CUDA Runtime native Libraries"
|
855 |
+
category = "main"
|
856 |
+
optional = false
|
857 |
+
python-versions = ">=3"
|
858 |
+
|
859 |
+
[[package]]
|
860 |
+
name = "nvidia-cudnn-cu11"
|
861 |
+
version = "8.5.0.96"
|
862 |
+
description = "cuDNN runtime libraries"
|
863 |
+
category = "main"
|
864 |
+
optional = false
|
865 |
+
python-versions = ">=3"
|
866 |
+
|
867 |
+
[[package]]
|
868 |
+
name = "oauthlib"
|
869 |
+
version = "3.2.2"
|
870 |
+
description = "A generic, spec-compliant, thorough implementation of the OAuth request-signing logic"
|
871 |
+
category = "main"
|
872 |
+
optional = false
|
873 |
+
python-versions = ">=3.6"
|
874 |
+
|
875 |
+
[package.extras]
|
876 |
+
rsa = ["cryptography (>=3.0.0)"]
|
877 |
+
signals = ["blinker (>=1.4.0)"]
|
878 |
+
signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"]
|
879 |
+
|
880 |
+
[[package]]
|
881 |
+
name = "omegaconf"
|
882 |
+
version = "2.3.0"
|
883 |
+
description = "A flexible configuration library"
|
884 |
+
category = "main"
|
885 |
+
optional = false
|
886 |
+
python-versions = ">=3.6"
|
887 |
+
|
888 |
+
[package.dependencies]
|
889 |
+
antlr4-python3-runtime = ">=4.9.0,<4.10.0"
|
890 |
+
PyYAML = ">=5.1.0"
|
891 |
+
|
892 |
+
[[package]]
|
893 |
+
name = "opencv-python"
|
894 |
+
version = "4.7.0.72"
|
895 |
+
description = "Wrapper package for OpenCV python bindings."
|
896 |
+
category = "main"
|
897 |
+
optional = false
|
898 |
+
python-versions = ">=3.6"
|
899 |
+
|
900 |
+
[package.dependencies]
|
901 |
+
numpy = [
|
902 |
+
{version = ">=1.21.2", markers = "python_version >= \"3.10\""},
|
903 |
+
{version = ">=1.21.4", markers = "python_version >= \"3.10\" and platform_system == \"Darwin\""},
|
904 |
+
{version = ">=1.22.0", markers = "python_version >= \"3.11\""},
|
905 |
+
{version = ">=1.19.3", markers = "python_version >= \"3.6\" and platform_system == \"Linux\" and platform_machine == \"aarch64\" or python_version >= \"3.9\""},
|
906 |
+
{version = ">=1.17.0", markers = "python_version >= \"3.7\""},
|
907 |
+
{version = ">=1.17.3", markers = "python_version >= \"3.8\""},
|
908 |
+
]
|
909 |
+
|
910 |
+
[[package]]
|
911 |
+
name = "orjson"
|
912 |
+
version = "3.8.10"
|
913 |
+
description = "Fast, correct Python JSON library supporting dataclasses, datetimes, and numpy"
|
914 |
+
category = "main"
|
915 |
+
optional = false
|
916 |
+
python-versions = ">= 3.7"
|
917 |
+
|
918 |
+
[[package]]
|
919 |
+
name = "packaging"
|
920 |
+
version = "23.1"
|
921 |
+
description = "Core utilities for Python packages"
|
922 |
+
category = "main"
|
923 |
+
optional = false
|
924 |
+
python-versions = ">=3.7"
|
925 |
+
|
926 |
+
[[package]]
|
927 |
+
name = "pandas"
|
928 |
+
version = "2.0.0"
|
929 |
+
description = "Powerful data structures for data analysis, time series, and statistics"
|
930 |
+
category = "main"
|
931 |
+
optional = false
|
932 |
+
python-versions = ">=3.8"
|
933 |
+
|
934 |
+
[package.dependencies]
|
935 |
+
numpy = [
|
936 |
+
{version = ">=1.21.0", markers = "python_version >= \"3.10\""},
|
937 |
+
{version = ">=1.23.2", markers = "python_version >= \"3.11\""},
|
938 |
+
]
|
939 |
+
python-dateutil = ">=2.8.2"
|
940 |
+
pytz = ">=2020.1"
|
941 |
+
tzdata = ">=2022.1"
|
942 |
+
|
943 |
+
[package.extras]
|
944 |
+
all = ["beautifulsoup4 (>=4.9.3)", "bottleneck (>=1.3.2)", "brotlipy (>=0.7.0)", "fastparquet (>=0.6.3)", "fsspec (>=2021.07.0)", "gcsfs (>=2021.07.0)", "html5lib (>=1.1)", "hypothesis (>=6.34.2)", "jinja2 (>=3.0.0)", "lxml (>=4.6.3)", "matplotlib (>=3.6.1)", "numba (>=0.53.1)", "numexpr (>=2.7.3)", "odfpy (>=1.4.1)", "openpyxl (>=3.0.7)", "pandas-gbq (>=0.15.0)", "psycopg2 (>=2.8.6)", "pyarrow (>=7.0.0)", "pymysql (>=1.0.2)", "PyQt5 (>=5.15.1)", "pyreadstat (>=1.1.2)", "pytest (>=7.0.0)", "pytest-xdist (>=2.2.0)", "pytest-asyncio (>=0.17.0)", "python-snappy (>=0.6.0)", "pyxlsb (>=1.0.8)", "qtpy (>=2.2.0)", "scipy (>=1.7.1)", "s3fs (>=2021.08.0)", "SQLAlchemy (>=1.4.16)", "tables (>=3.6.1)", "tabulate (>=0.8.9)", "xarray (>=0.21.0)", "xlrd (>=2.0.1)", "xlsxwriter (>=1.4.3)", "zstandard (>=0.15.2)"]
|
945 |
+
aws = ["s3fs (>=2021.08.0)"]
|
946 |
+
clipboard = ["PyQt5 (>=5.15.1)", "qtpy (>=2.2.0)"]
|
947 |
+
compression = ["brotlipy (>=0.7.0)", "python-snappy (>=0.6.0)", "zstandard (>=0.15.2)"]
|
948 |
+
computation = ["scipy (>=1.7.1)", "xarray (>=0.21.0)"]
|
949 |
+
excel = ["odfpy (>=1.4.1)", "openpyxl (>=3.0.7)", "pyxlsb (>=1.0.8)", "xlrd (>=2.0.1)", "xlsxwriter (>=1.4.3)"]
|
950 |
+
feather = ["pyarrow (>=7.0.0)"]
|
951 |
+
fss = ["fsspec (>=2021.07.0)"]
|
952 |
+
gcp = ["gcsfs (>=2021.07.0)", "pandas-gbq (>=0.15.0)"]
|
953 |
+
hdf5 = ["tables (>=3.6.1)"]
|
954 |
+
html = ["beautifulsoup4 (>=4.9.3)", "html5lib (>=1.1)", "lxml (>=4.6.3)"]
|
955 |
+
mysql = ["SQLAlchemy (>=1.4.16)", "pymysql (>=1.0.2)"]
|
956 |
+
output_formatting = ["jinja2 (>=3.0.0)", "tabulate (>=0.8.9)"]
|
957 |
+
parquet = ["pyarrow (>=7.0.0)"]
|
958 |
+
performance = ["bottleneck (>=1.3.2)", "numba (>=0.53.1)", "numexpr (>=2.7.1)"]
|
959 |
+
plot = ["matplotlib (>=3.6.1)"]
|
960 |
+
postgresql = ["SQLAlchemy (>=1.4.16)", "psycopg2 (>=2.8.6)"]
|
961 |
+
spss = ["pyreadstat (>=1.1.2)"]
|
962 |
+
sql-other = ["SQLAlchemy (>=1.4.16)"]
|
963 |
+
test = ["hypothesis (>=6.34.2)", "pytest (>=7.0.0)", "pytest-xdist (>=2.2.0)", "pytest-asyncio (>=0.17.0)"]
|
964 |
+
xml = ["lxml (>=4.6.3)"]
|
965 |
+
|
966 |
+
[[package]]
|
967 |
+
name = "pillow"
|
968 |
+
version = "8.4.0"
|
969 |
+
description = "Python Imaging Library (Fork)"
|
970 |
+
category = "main"
|
971 |
+
optional = false
|
972 |
+
python-versions = ">=3.6"
|
973 |
+
|
974 |
+
[[package]]
|
975 |
+
name = "plotly"
|
976 |
+
version = "5.14.1"
|
977 |
+
description = "An open-source, interactive data visualization library for Python"
|
978 |
+
category = "main"
|
979 |
+
optional = false
|
980 |
+
python-versions = ">=3.6"
|
981 |
+
|
982 |
+
[package.dependencies]
|
983 |
+
packaging = "*"
|
984 |
+
tenacity = ">=6.2.0"
|
985 |
+
|
986 |
+
[[package]]
|
987 |
+
name = "protobuf"
|
988 |
+
version = "3.20.3"
|
989 |
+
description = "Protocol Buffers"
|
990 |
+
category = "main"
|
991 |
+
optional = false
|
992 |
+
python-versions = ">=3.7"
|
993 |
+
|
994 |
+
[[package]]
|
995 |
+
name = "pyasn1"
|
996 |
+
version = "0.4.8"
|
997 |
+
description = "ASN.1 types and codecs"
|
998 |
+
category = "main"
|
999 |
+
optional = false
|
1000 |
+
python-versions = "*"
|
1001 |
+
|
1002 |
+
[[package]]
|
1003 |
+
name = "pyasn1-modules"
|
1004 |
+
version = "0.2.8"
|
1005 |
+
description = "A collection of ASN.1-based protocols modules."
|
1006 |
+
category = "main"
|
1007 |
+
optional = false
|
1008 |
+
python-versions = "*"
|
1009 |
+
|
1010 |
+
[package.dependencies]
|
1011 |
+
pyasn1 = ">=0.4.6,<0.5.0"
|
1012 |
+
|
1013 |
+
[[package]]
|
1014 |
+
name = "pydantic"
|
1015 |
+
version = "1.10.7"
|
1016 |
+
description = "Data validation and settings management using python type hints"
|
1017 |
+
category = "main"
|
1018 |
+
optional = false
|
1019 |
+
python-versions = ">=3.7"
|
1020 |
+
|
1021 |
+
[package.dependencies]
|
1022 |
+
typing-extensions = ">=4.2.0"
|
1023 |
+
|
1024 |
+
[package.extras]
|
1025 |
+
dotenv = ["python-dotenv (>=0.10.4)"]
|
1026 |
+
email = ["email-validator (>=1.0.3)"]
|
1027 |
+
|
1028 |
+
[[package]]
|
1029 |
+
name = "pydub"
|
1030 |
+
version = "0.25.1"
|
1031 |
+
description = "Manipulate audio with an simple and easy high level interface"
|
1032 |
+
category = "main"
|
1033 |
+
optional = false
|
1034 |
+
python-versions = "*"
|
1035 |
+
|
1036 |
+
[[package]]
|
1037 |
+
name = "pyparsing"
|
1038 |
+
version = "3.0.9"
|
1039 |
+
description = "pyparsing module - Classes and methods to define and execute parsing grammars"
|
1040 |
+
category = "main"
|
1041 |
+
optional = false
|
1042 |
+
python-versions = ">=3.6.8"
|
1043 |
+
|
1044 |
+
[package.extras]
|
1045 |
+
diagrams = ["railroad-diagrams", "jinja2"]
|
1046 |
+
|
1047 |
+
[[package]]
|
1048 |
+
name = "pyrsistent"
|
1049 |
+
version = "0.19.3"
|
1050 |
+
description = "Persistent/Functional/Immutable data structures"
|
1051 |
+
category = "main"
|
1052 |
+
optional = false
|
1053 |
+
python-versions = ">=3.7"
|
1054 |
+
|
1055 |
+
[[package]]
|
1056 |
+
name = "python-dateutil"
|
1057 |
+
version = "2.8.2"
|
1058 |
+
description = "Extensions to the standard Python datetime module"
|
1059 |
+
category = "main"
|
1060 |
+
optional = false
|
1061 |
+
python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7"
|
1062 |
+
|
1063 |
+
[package.dependencies]
|
1064 |
+
six = ">=1.5"
|
1065 |
+
|
1066 |
+
[[package]]
|
1067 |
+
name = "python-multipart"
|
1068 |
+
version = "0.0.6"
|
1069 |
+
description = "A streaming multipart parser for Python"
|
1070 |
+
category = "main"
|
1071 |
+
optional = false
|
1072 |
+
python-versions = ">=3.7"
|
1073 |
+
|
1074 |
+
[package.extras]
|
1075 |
+
dev = ["atomicwrites (==1.2.1)", "attrs (==19.2.0)", "coverage (==6.5.0)", "hatch", "invoke (==1.7.3)", "more-itertools (==4.3.0)", "pbr (==4.3.0)", "pluggy (==1.0.0)", "py (==1.11.0)", "pytest-cov (==4.0.0)", "pytest-timeout (==2.1.0)", "pytest (==7.2.0)", "pyyaml (==5.1)"]
|
1076 |
+
|
1077 |
+
[[package]]
|
1078 |
+
name = "pytorch-lightning"
|
1079 |
+
version = "1.9.0"
|
1080 |
+
description = "PyTorch Lightning is the lightweight PyTorch wrapper for ML researchers. Scale your models. Write less boilerplate."
|
1081 |
+
category = "main"
|
1082 |
+
optional = false
|
1083 |
+
python-versions = ">=3.7"
|
1084 |
+
|
1085 |
+
[package.dependencies]
|
1086 |
+
fsspec = {version = ">2021.06.0", extras = ["http"]}
|
1087 |
+
lightning-utilities = ">=0.4.2"
|
1088 |
+
numpy = ">=1.17.2"
|
1089 |
+
packaging = ">=17.1"
|
1090 |
+
PyYAML = ">=5.4"
|
1091 |
+
torch = ">=1.10.0"
|
1092 |
+
torchmetrics = ">=0.7.0"
|
1093 |
+
tqdm = ">=4.57.0"
|
1094 |
+
typing-extensions = ">=4.0.0"
|
1095 |
+
|
1096 |
+
[package.extras]
|
1097 |
+
all = ["matplotlib (>3.1)", "omegaconf (>=2.0.5)", "hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.18.0)", "rich (>=10.14.0,!=10.15.0.a)", "tensorboardX (>=2.2)", "fairscale (>=0.4.5)", "deepspeed (>=0.6.0)", "horovod (>=0.21.2,!=0.24.0)", "torchvision (>=0.11.1)", "gym[classic_control] (>=0.17.0)", "ipython[all] (<8.7.1)", "hivemind (==1.1.5)"]
|
1098 |
+
deepspeed = ["deepspeed (>=0.6.0)"]
|
1099 |
+
dev = ["matplotlib (>3.1)", "omegaconf (>=2.0.5)", "hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.18.0)", "rich (>=10.14.0,!=10.15.0.a)", "tensorboardX (>=2.2)", "fairscale (>=0.4.5)", "deepspeed (>=0.6.0)", "horovod (>=0.21.2,!=0.24.0)", "torchvision (>=0.11.1)", "gym[classic_control] (>=0.17.0)", "ipython[all] (<8.7.1)", "coverage (==6.5.0)", "codecov (==2.1.12)", "pytest (==7.2.0)", "pytest-cov (==4.0.0)", "pytest-forked (==1.4.0)", "pytest-rerunfailures (==10.3)", "pre-commit (==2.20.0)", "cloudpickle (>=1.3)", "scikit-learn (>0.22.1)", "onnxruntime (<1.14.0)", "psutil (<5.9.5)", "pandas (>1.0)", "fastapi (<0.87.0)", "uvicorn (<0.19.1)", "tensorboard (>=2.9.1)", "protobuf (<=3.20.1)", "hivemind (==1.1.5)"]
|
1100 |
+
examples = ["torchvision (>=0.11.1)", "gym[classic_control] (>=0.17.0)", "ipython[all] (<8.7.1)"]
|
1101 |
+
extra = ["matplotlib (>3.1)", "omegaconf (>=2.0.5)", "hydra-core (>=1.0.5)", "jsonargparse[signatures] (>=4.18.0)", "rich (>=10.14.0,!=10.15.0.a)", "tensorboardX (>=2.2)"]
|
1102 |
+
fairscale = ["fairscale (>=0.4.5)"]
|
1103 |
+
hivemind = ["hivemind (==1.1.5)"]
|
1104 |
+
horovod = ["horovod (>=0.21.2,!=0.24.0)"]
|
1105 |
+
strategies = ["fairscale (>=0.4.5)", "deepspeed (>=0.6.0)", "horovod (>=0.21.2,!=0.24.0)", "hivemind (==1.1.5)"]
|
1106 |
+
test = ["coverage (==6.5.0)", "codecov (==2.1.12)", "pytest (==7.2.0)", "pytest-cov (==4.0.0)", "pytest-forked (==1.4.0)", "pytest-rerunfailures (==10.3)", "pre-commit (==2.20.0)", "cloudpickle (>=1.3)", "scikit-learn (>0.22.1)", "onnxruntime (<1.14.0)", "psutil (<5.9.5)", "pandas (>1.0)", "fastapi (<0.87.0)", "uvicorn (<0.19.1)", "tensorboard (>=2.9.1)", "protobuf (<=3.20.1)"]
|
1107 |
+
|
1108 |
+
[[package]]
|
1109 |
+
name = "pytz"
|
1110 |
+
version = "2023.3"
|
1111 |
+
description = "World timezone definitions, modern and historical"
|
1112 |
+
category = "main"
|
1113 |
+
optional = false
|
1114 |
+
python-versions = "*"
|
1115 |
+
|
1116 |
+
[[package]]
|
1117 |
+
name = "pyyaml"
|
1118 |
+
version = "6.0"
|
1119 |
+
description = "YAML parser and emitter for Python"
|
1120 |
+
category = "main"
|
1121 |
+
optional = false
|
1122 |
+
python-versions = ">=3.6"
|
1123 |
+
|
1124 |
+
[[package]]
|
1125 |
+
name = "requests"
|
1126 |
+
version = "2.28.2"
|
1127 |
+
description = "Python HTTP for Humans."
|
1128 |
+
category = "main"
|
1129 |
+
optional = false
|
1130 |
+
python-versions = ">=3.7, <4"
|
1131 |
+
|
1132 |
+
[package.dependencies]
|
1133 |
+
certifi = ">=2017.4.17"
|
1134 |
+
charset-normalizer = ">=2,<4"
|
1135 |
+
idna = ">=2.5,<4"
|
1136 |
+
urllib3 = ">=1.21.1,<1.27"
|
1137 |
+
|
1138 |
+
[package.extras]
|
1139 |
+
socks = ["PySocks (>=1.5.6,!=1.5.7)"]
|
1140 |
+
use_chardet_on_py3 = ["chardet (>=3.0.2,<6)"]
|
1141 |
+
|
1142 |
+
[[package]]
|
1143 |
+
name = "requests-oauthlib"
|
1144 |
+
version = "1.3.1"
|
1145 |
+
description = "OAuthlib authentication support for Requests."
|
1146 |
+
category = "main"
|
1147 |
+
optional = false
|
1148 |
+
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*"
|
1149 |
+
|
1150 |
+
[package.dependencies]
|
1151 |
+
oauthlib = ">=3.0.0"
|
1152 |
+
requests = ">=2.0.0"
|
1153 |
+
|
1154 |
+
[package.extras]
|
1155 |
+
rsa = ["oauthlib[signedtoken] (>=3.0.0)"]
|
1156 |
+
|
1157 |
+
[[package]]
|
1158 |
+
name = "rsa"
|
1159 |
+
version = "4.9"
|
1160 |
+
description = "Pure-Python RSA implementation"
|
1161 |
+
category = "main"
|
1162 |
+
optional = false
|
1163 |
+
python-versions = ">=3.6,<4"
|
1164 |
+
|
1165 |
+
[package.dependencies]
|
1166 |
+
pyasn1 = ">=0.1.3"
|
1167 |
+
|
1168 |
+
[[package]]
|
1169 |
+
name = "scipy"
|
1170 |
+
version = "1.10.1"
|
1171 |
+
description = "Fundamental algorithms for scientific computing in Python"
|
1172 |
+
category = "main"
|
1173 |
+
optional = false
|
1174 |
+
python-versions = "<3.12,>=3.8"
|
1175 |
+
|
1176 |
+
[package.dependencies]
|
1177 |
+
numpy = ">=1.19.5,<1.27.0"
|
1178 |
+
|
1179 |
+
[package.extras]
|
1180 |
+
test = ["pytest", "pytest-cov", "pytest-timeout", "pytest-xdist", "asv", "mpmath", "gmpy2", "threadpoolctl", "scikit-umfpack", "pooch"]
|
1181 |
+
doc = ["sphinx (!=4.1.0)", "pydata-sphinx-theme (==0.9.0)", "sphinx-design (>=0.2.0)", "matplotlib (>2)", "numpydoc"]
|
1182 |
+
dev = ["mypy", "typing-extensions", "pycodestyle", "flake8", "rich-click", "click", "doit (>=0.36.0)", "pydevtool"]
|
1183 |
+
|
1184 |
+
[[package]]
|
1185 |
+
name = "seaborn"
|
1186 |
+
version = "0.12.2"
|
1187 |
+
description = "Statistical data visualization"
|
1188 |
+
category = "main"
|
1189 |
+
optional = false
|
1190 |
+
python-versions = ">=3.7"
|
1191 |
+
|
1192 |
+
[package.dependencies]
|
1193 |
+
matplotlib = ">=3.1,<3.6.1 || >3.6.1"
|
1194 |
+
numpy = ">=1.17,<1.24.0 || >1.24.0"
|
1195 |
+
pandas = ">=0.25"
|
1196 |
+
|
1197 |
+
[package.extras]
|
1198 |
+
dev = ["pytest", "pytest-cov", "pytest-xdist", "flake8", "mypy", "pandas-stubs", "pre-commit", "flit"]
|
1199 |
+
docs = ["numpydoc", "nbconvert", "ipykernel", "sphinx-copybutton", "sphinx-issues", "sphinx-design", "pyyaml", "pydata_sphinx_theme (==0.10.0rc2)"]
|
1200 |
+
stats = ["scipy (>=1.3)", "statsmodels (>=0.10)"]
|
1201 |
+
|
1202 |
+
[[package]]
|
1203 |
+
name = "semantic-version"
|
1204 |
+
version = "2.10.0"
|
1205 |
+
description = "A library implementing the 'SemVer' scheme."
|
1206 |
+
category = "main"
|
1207 |
+
optional = false
|
1208 |
+
python-versions = ">=2.7"
|
1209 |
+
|
1210 |
+
[package.extras]
|
1211 |
+
dev = ["Django (>=1.11)", "nose2", "tox", "check-manifest", "coverage", "flake8", "wheel", "zest.releaser", "readme-renderer (<25.0)", "colorama (<=0.4.1)"]
|
1212 |
+
doc = ["sphinx", "sphinx-rtd-theme"]
|
1213 |
+
|
1214 |
+
[[package]]
|
1215 |
+
name = "setuptools-scm"
|
1216 |
+
version = "7.1.0"
|
1217 |
+
description = "the blessed package to manage your versions by scm tags"
|
1218 |
+
category = "main"
|
1219 |
+
optional = false
|
1220 |
+
python-versions = ">=3.7"
|
1221 |
+
|
1222 |
+
[package.dependencies]
|
1223 |
+
packaging = ">=20.0"
|
1224 |
+
tomli = {version = ">=1.0.0", markers = "python_version < \"3.11\""}
|
1225 |
+
typing-extensions = "*"
|
1226 |
+
|
1227 |
+
[package.extras]
|
1228 |
+
test = ["pytest (>=6.2)", "virtualenv (>20)"]
|
1229 |
+
toml = ["setuptools (>=42)"]
|
1230 |
+
|
1231 |
+
[[package]]
|
1232 |
+
name = "six"
|
1233 |
+
version = "1.16.0"
|
1234 |
+
description = "Python 2 and 3 compatibility utilities"
|
1235 |
+
category = "main"
|
1236 |
+
optional = false
|
1237 |
+
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*"
|
1238 |
+
|
1239 |
+
[[package]]
|
1240 |
+
name = "sniffio"
|
1241 |
+
version = "1.3.0"
|
1242 |
+
description = "Sniff out which async library your code is running under"
|
1243 |
+
category = "main"
|
1244 |
+
optional = false
|
1245 |
+
python-versions = ">=3.7"
|
1246 |
+
|
1247 |
+
[[package]]
|
1248 |
+
name = "starlette"
|
1249 |
+
version = "0.26.1"
|
1250 |
+
description = "The little ASGI library that shines."
|
1251 |
+
category = "main"
|
1252 |
+
optional = false
|
1253 |
+
python-versions = ">=3.7"
|
1254 |
+
|
1255 |
+
[package.dependencies]
|
1256 |
+
anyio = ">=3.4.0,<5"
|
1257 |
+
|
1258 |
+
[package.extras]
|
1259 |
+
full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart", "pyyaml"]
|
1260 |
+
|
1261 |
+
[[package]]
|
1262 |
+
name = "tenacity"
|
1263 |
+
version = "8.2.2"
|
1264 |
+
description = "Retry code until it succeeds"
|
1265 |
+
category = "main"
|
1266 |
+
optional = false
|
1267 |
+
python-versions = ">=3.6"
|
1268 |
+
|
1269 |
+
[package.extras]
|
1270 |
+
doc = ["reno", "sphinx", "tornado (>=4.5)"]
|
1271 |
+
|
1272 |
+
[[package]]
|
1273 |
+
name = "tensorboard"
|
1274 |
+
version = "2.11.2"
|
1275 |
+
description = "TensorBoard lets you watch Tensors Flow"
|
1276 |
+
category = "main"
|
1277 |
+
optional = false
|
1278 |
+
python-versions = ">=3.7"
|
1279 |
+
|
1280 |
+
[package.dependencies]
|
1281 |
+
absl-py = ">=0.4"
|
1282 |
+
google-auth = ">=1.6.3,<3"
|
1283 |
+
google-auth-oauthlib = ">=0.4.1,<0.5"
|
1284 |
+
grpcio = ">=1.24.3"
|
1285 |
+
markdown = ">=2.6.8"
|
1286 |
+
numpy = ">=1.12.0"
|
1287 |
+
protobuf = ">=3.9.2,<4"
|
1288 |
+
requests = ">=2.21.0,<3"
|
1289 |
+
tensorboard-data-server = ">=0.6.0,<0.7.0"
|
1290 |
+
tensorboard-plugin-wit = ">=1.6.0"
|
1291 |
+
werkzeug = ">=1.0.1"
|
1292 |
+
|
1293 |
+
[[package]]
|
1294 |
+
name = "tensorboard-data-server"
|
1295 |
+
version = "0.6.1"
|
1296 |
+
description = "Fast data loading for TensorBoard"
|
1297 |
+
category = "main"
|
1298 |
+
optional = false
|
1299 |
+
python-versions = ">=3.6"
|
1300 |
+
|
1301 |
+
[[package]]
|
1302 |
+
name = "tensorboard-plugin-wit"
|
1303 |
+
version = "1.8.1"
|
1304 |
+
description = "What-If Tool TensorBoard plugin."
|
1305 |
+
category = "main"
|
1306 |
+
optional = false
|
1307 |
+
python-versions = "*"
|
1308 |
+
|
1309 |
+
[[package]]
|
1310 |
+
name = "tomli"
|
1311 |
+
version = "2.0.1"
|
1312 |
+
description = "A lil' TOML parser"
|
1313 |
+
category = "main"
|
1314 |
+
optional = false
|
1315 |
+
python-versions = ">=3.7"
|
1316 |
+
|
1317 |
+
[[package]]
|
1318 |
+
name = "toolz"
|
1319 |
+
version = "0.12.0"
|
1320 |
+
description = "List processing tools and functional utilities"
|
1321 |
+
category = "main"
|
1322 |
+
optional = false
|
1323 |
+
python-versions = ">=3.5"
|
1324 |
+
|
1325 |
+
[[package]]
|
1326 |
+
name = "torch"
|
1327 |
+
version = "1.13.1"
|
1328 |
+
description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
|
1329 |
+
category = "main"
|
1330 |
+
optional = false
|
1331 |
+
python-versions = ">=3.7.0"
|
1332 |
+
|
1333 |
+
[package.dependencies]
|
1334 |
+
nvidia-cublas-cu11 = {version = "11.10.3.66", markers = "platform_system == \"Linux\""}
|
1335 |
+
nvidia-cuda-nvrtc-cu11 = {version = "11.7.99", markers = "platform_system == \"Linux\""}
|
1336 |
+
nvidia-cuda-runtime-cu11 = {version = "11.7.99", markers = "platform_system == \"Linux\""}
|
1337 |
+
nvidia-cudnn-cu11 = {version = "8.5.0.96", markers = "platform_system == \"Linux\""}
|
1338 |
+
typing-extensions = "*"
|
1339 |
+
|
1340 |
+
[package.extras]
|
1341 |
+
opt-einsum = ["opt-einsum (>=3.3)"]
|
1342 |
+
|
1343 |
+
[[package]]
|
1344 |
+
name = "torchmetrics"
|
1345 |
+
version = "0.11.0"
|
1346 |
+
description = "PyTorch native Metrics"
|
1347 |
+
category = "main"
|
1348 |
+
optional = false
|
1349 |
+
python-versions = ">=3.7"
|
1350 |
+
|
1351 |
+
[package.dependencies]
|
1352 |
+
numpy = ">=1.17.2"
|
1353 |
+
packaging = "*"
|
1354 |
+
torch = ">=1.8.1"
|
1355 |
+
|
1356 |
+
[package.extras]
|
1357 |
+
all = ["pystoi", "torchvision (>=0.8)", "pycocotools", "scipy", "torch-fidelity", "torchvision", "lpips", "pytorch-lightning (>=1.5)", "transformers (>=4.10.0)", "regex (>=2021.9.24)", "nltk (>=3.6)", "tqdm (>=4.41.0)"]
|
1358 |
+
audio = ["pystoi"]
|
1359 |
+
detection = ["torchvision (>=0.8)", "pycocotools"]
|
1360 |
+
docs = ["sphinx-autodoc-typehints (>=1.0)", "nbsphinx (>=0.8)", "docutils (>=0.16)", "sphinx-togglebutton (>=0.2)", "pandoc (>=1.0)", "myst-parser", "sphinx-paramlinks (>=0.5.1)", "sphinxcontrib-fulltoc (>=1.0)", "sphinxcontrib-mockautodoc", "sphinx-copybutton (>=0.3)", "sphinx (>=4.0,<5.0)"]
|
1361 |
+
image = ["scipy", "torch-fidelity", "torchvision", "lpips"]
|
1362 |
+
integrate = ["pytorch-lightning (>=1.5)"]
|
1363 |
+
multimodal = ["transformers (>=4.10.0)"]
|
1364 |
+
test = ["types-protobuf", "rouge-score (>=0.0.4)", "bert-score (==0.3.10)", "requests", "mir-eval (>=0.6)", "jiwer (>=2.3.0)", "scikit-learn (>1.0,<1.1.1)", "check-manifest", "types-tabulate", "pytest-timeout", "types-emoji", "pycocotools", "coverage (>5.2)", "pytest (>=6.0.0,<7.0.0)", "types-six", "kornia (>=0.6.7)", "phmdoctest (>=1.1.1)", "pandas", "pytest-cov (>2.10)", "cloudpickle (>=1.3)", "pre-commit (>=1.0)", "scipy", "psutil", "mypy (==0.982)", "types-requests", "pytest-rerunfailures (>=10.0)", "types-pyyaml", "types-setuptools", "sacrebleu (>=2.0.0)", "netcal", "pytorch-msssim (==0.2.1)", "transformers (>4.4.0)", "fast-bss-eval (>=0.1.0)", "fire", "scikit-image (>0.17.1)", "dython", "torch-complex", "pytest-doctestplus (>=0.9.0)", "huggingface-hub (<0.7)", "pypesq (>1.2)"]
|
1365 |
+
text = ["regex (>=2021.9.24)", "nltk (>=3.6)", "tqdm (>=4.41.0)"]
|
1366 |
+
|
1367 |
+
[[package]]
|
1368 |
+
name = "torchvision"
|
1369 |
+
version = "0.14.1"
|
1370 |
+
description = "image and video datasets and models for torch deep learning"
|
1371 |
+
category = "main"
|
1372 |
+
optional = false
|
1373 |
+
python-versions = ">=3.7"
|
1374 |
+
|
1375 |
+
[package.dependencies]
|
1376 |
+
numpy = "*"
|
1377 |
+
pillow = ">=5.3.0,<8.3.0 || >=8.4.0"
|
1378 |
+
requests = "*"
|
1379 |
+
torch = "1.13.1"
|
1380 |
+
typing-extensions = "*"
|
1381 |
+
|
1382 |
+
[package.extras]
|
1383 |
+
scipy = ["scipy"]
|
1384 |
+
|
1385 |
+
[[package]]
|
1386 |
+
name = "tqdm"
|
1387 |
+
version = "4.65.0"
|
1388 |
+
description = "Fast, Extensible Progress Meter"
|
1389 |
+
category = "main"
|
1390 |
+
optional = false
|
1391 |
+
python-versions = ">=3.7"
|
1392 |
+
|
1393 |
+
[package.dependencies]
|
1394 |
+
colorama = {version = "*", markers = "platform_system == \"Windows\""}
|
1395 |
+
|
1396 |
+
[package.extras]
|
1397 |
+
dev = ["py-make (>=0.1.0)", "twine", "wheel"]
|
1398 |
+
notebook = ["ipywidgets (>=6)"]
|
1399 |
+
slack = ["slack-sdk"]
|
1400 |
+
telegram = ["requests"]
|
1401 |
+
|
1402 |
+
[[package]]
|
1403 |
+
name = "typing-extensions"
|
1404 |
+
version = "4.5.0"
|
1405 |
+
description = "Backported and Experimental Type Hints for Python 3.7+"
|
1406 |
+
category = "main"
|
1407 |
+
optional = false
|
1408 |
+
python-versions = ">=3.7"
|
1409 |
+
|
1410 |
+
[[package]]
|
1411 |
+
name = "tzdata"
|
1412 |
+
version = "2023.3"
|
1413 |
+
description = "Provider of IANA time zone data"
|
1414 |
+
category = "main"
|
1415 |
+
optional = false
|
1416 |
+
python-versions = ">=2"
|
1417 |
+
|
1418 |
+
[[package]]
|
1419 |
+
name = "uc-micro-py"
|
1420 |
+
version = "1.0.1"
|
1421 |
+
description = "Micro subset of unicode data files for linkify-it-py projects."
|
1422 |
+
category = "main"
|
1423 |
+
optional = false
|
1424 |
+
python-versions = ">=3.6"
|
1425 |
+
|
1426 |
+
[package.extras]
|
1427 |
+
test = ["coverage", "pytest", "pytest-cov"]
|
1428 |
+
|
1429 |
+
[[package]]
|
1430 |
+
name = "uritemplate"
|
1431 |
+
version = "4.1.1"
|
1432 |
+
description = "Implementation of RFC 6570 URI Templates"
|
1433 |
+
category = "main"
|
1434 |
+
optional = false
|
1435 |
+
python-versions = ">=3.6"
|
1436 |
+
|
1437 |
+
[[package]]
|
1438 |
+
name = "urllib3"
|
1439 |
+
version = "1.26.15"
|
1440 |
+
description = "HTTP library with thread-safe connection pooling, file post, and more."
|
1441 |
+
category = "main"
|
1442 |
+
optional = false
|
1443 |
+
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*"
|
1444 |
+
|
1445 |
+
[package.extras]
|
1446 |
+
brotli = ["brotlicffi (>=0.8.0)", "brotli (>=1.0.9)", "brotlipy (>=0.6.0)"]
|
1447 |
+
secure = ["pyOpenSSL (>=0.14)", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "certifi", "urllib3-secure-extra", "ipaddress"]
|
1448 |
+
socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"]
|
1449 |
+
|
1450 |
+
[[package]]
|
1451 |
+
name = "uvicorn"
|
1452 |
+
version = "0.21.1"
|
1453 |
+
description = "The lightning-fast ASGI server."
|
1454 |
+
category = "main"
|
1455 |
+
optional = false
|
1456 |
+
python-versions = ">=3.7"
|
1457 |
+
|
1458 |
+
[package.dependencies]
|
1459 |
+
click = ">=7.0"
|
1460 |
+
h11 = ">=0.8"
|
1461 |
+
|
1462 |
+
[package.extras]
|
1463 |
+
standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.14.0,!=0.15.0,!=0.15.1)", "watchfiles (>=0.13)", "websockets (>=10.4)"]
|
1464 |
+
|
1465 |
+
[[package]]
|
1466 |
+
name = "websockets"
|
1467 |
+
version = "11.0.1"
|
1468 |
+
description = "An implementation of the WebSocket Protocol (RFC 6455 & 7692)"
|
1469 |
+
category = "main"
|
1470 |
+
optional = false
|
1471 |
+
python-versions = ">=3.7"
|
1472 |
+
|
1473 |
+
[[package]]
|
1474 |
+
name = "werkzeug"
|
1475 |
+
version = "2.2.3"
|
1476 |
+
description = "The comprehensive WSGI web application library."
|
1477 |
+
category = "main"
|
1478 |
+
optional = false
|
1479 |
+
python-versions = ">=3.7"
|
1480 |
+
|
1481 |
+
[package.dependencies]
|
1482 |
+
MarkupSafe = ">=2.1.1"
|
1483 |
+
|
1484 |
+
[package.extras]
|
1485 |
+
watchdog = ["watchdog"]
|
1486 |
+
|
1487 |
+
[[package]]
|
1488 |
+
name = "wget"
|
1489 |
+
version = "3.2"
|
1490 |
+
description = "pure python download utility"
|
1491 |
+
category = "main"
|
1492 |
+
optional = false
|
1493 |
+
python-versions = "*"
|
1494 |
+
|
1495 |
+
[[package]]
|
1496 |
+
name = "yarl"
|
1497 |
+
version = "1.8.2"
|
1498 |
+
description = "Yet another URL library"
|
1499 |
+
category = "main"
|
1500 |
+
optional = false
|
1501 |
+
python-versions = ">=3.7"
|
1502 |
+
|
1503 |
+
[package.dependencies]
|
1504 |
+
idna = ">=2.0"
|
1505 |
+
multidict = ">=4.0"
|
1506 |
+
|
1507 |
+
[metadata]
|
1508 |
+
lock-version = "1.1"
|
1509 |
+
python-versions = ">=3.10,<3.12"
|
1510 |
+
content-hash = "17cec1f61fed3b070c0b744eeecc9dbaed1ea06d758238ac84f108545ab14a21"
|
1511 |
+
|
1512 |
+
[metadata.files]
|
1513 |
+
absl-py = []
|
1514 |
+
aiofiles = []
|
1515 |
+
aiohttp = []
|
1516 |
+
aiosignal = []
|
1517 |
+
altair = []
|
1518 |
+
antlr4-python3-runtime = []
|
1519 |
+
anyio = []
|
1520 |
+
async-timeout = []
|
1521 |
+
attrs = []
|
1522 |
+
cachetools = []
|
1523 |
+
certifi = []
|
1524 |
+
charset-normalizer = []
|
1525 |
+
click = []
|
1526 |
+
colorama = []
|
1527 |
+
contourpy = []
|
1528 |
+
cycler = []
|
1529 |
+
earthengine-api = []
|
1530 |
+
ee-extra = []
|
1531 |
+
entrypoints = []
|
1532 |
+
fastapi = []
|
1533 |
+
ffmpy = []
|
1534 |
+
filelock = []
|
1535 |
+
fonttools = []
|
1536 |
+
frozenlist = []
|
1537 |
+
fsspec = []
|
1538 |
+
google-api-core = []
|
1539 |
+
google-api-python-client = []
|
1540 |
+
google-auth = []
|
1541 |
+
google-auth-httplib2 = []
|
1542 |
+
google-auth-oauthlib = []
|
1543 |
+
google-cloud-core = []
|
1544 |
+
google-cloud-storage = []
|
1545 |
+
google-crc32c = []
|
1546 |
+
google-resumable-media = []
|
1547 |
+
googleapis-common-protos = []
|
1548 |
+
gradio = []
|
1549 |
+
gradio-client = []
|
1550 |
+
grpcio = []
|
1551 |
+
h11 = []
|
1552 |
+
httpcore = []
|
1553 |
+
httplib2 = []
|
1554 |
+
httpx = []
|
1555 |
+
huggingface-hub = []
|
1556 |
+
hydra-client = []
|
1557 |
+
hydra-core = []
|
1558 |
+
idna = []
|
1559 |
+
jinja2 = []
|
1560 |
+
jsonschema = []
|
1561 |
+
kiwisolver = []
|
1562 |
+
lightning-utilities = []
|
1563 |
+
linkify-it-py = []
|
1564 |
+
markdown = []
|
1565 |
+
markdown-it-py = []
|
1566 |
+
markupsafe = []
|
1567 |
+
matplotlib = []
|
1568 |
+
mdit-py-plugins = []
|
1569 |
+
mdurl = []
|
1570 |
+
multidict = []
|
1571 |
+
numpy = []
|
1572 |
+
nvidia-cublas-cu11 = []
|
1573 |
+
nvidia-cuda-nvrtc-cu11 = []
|
1574 |
+
nvidia-cuda-runtime-cu11 = []
|
1575 |
+
nvidia-cudnn-cu11 = []
|
1576 |
+
oauthlib = []
|
1577 |
+
omegaconf = []
|
1578 |
+
opencv-python = []
|
1579 |
+
orjson = []
|
1580 |
+
packaging = []
|
1581 |
+
pandas = []
|
1582 |
+
pillow = []
|
1583 |
+
plotly = []
|
1584 |
+
protobuf = []
|
1585 |
+
pyasn1 = []
|
1586 |
+
pyasn1-modules = []
|
1587 |
+
pydantic = []
|
1588 |
+
pydub = []
|
1589 |
+
pyparsing = []
|
1590 |
+
pyrsistent = []
|
1591 |
+
python-dateutil = []
|
1592 |
+
python-multipart = []
|
1593 |
+
pytorch-lightning = []
|
1594 |
+
pytz = []
|
1595 |
+
pyyaml = []
|
1596 |
+
requests = []
|
1597 |
+
requests-oauthlib = []
|
1598 |
+
rsa = []
|
1599 |
+
scipy = []
|
1600 |
+
seaborn = []
|
1601 |
+
semantic-version = []
|
1602 |
+
setuptools-scm = []
|
1603 |
+
six = []
|
1604 |
+
sniffio = []
|
1605 |
+
starlette = []
|
1606 |
+
tenacity = []
|
1607 |
+
tensorboard = []
|
1608 |
+
tensorboard-data-server = []
|
1609 |
+
tensorboard-plugin-wit = []
|
1610 |
+
tomli = []
|
1611 |
+
toolz = []
|
1612 |
+
torch = []
|
1613 |
+
torchmetrics = []
|
1614 |
+
torchvision = []
|
1615 |
+
tqdm = []
|
1616 |
+
typing-extensions = []
|
1617 |
+
tzdata = []
|
1618 |
+
uc-micro-py = []
|
1619 |
+
uritemplate = []
|
1620 |
+
urllib3 = []
|
1621 |
+
uvicorn = []
|
1622 |
+
websockets = []
|
1623 |
+
werkzeug = []
|
1624 |
+
wget = []
|
1625 |
+
yarl = []
|
pyproject.toml
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[tool.poetry]
|
2 |
+
name = "cv_app"
|
3 |
+
version = "0.1.0"
|
4 |
+
description = ""
|
5 |
+
authors = ["Your Name <[email protected]>"]
|
6 |
+
|
7 |
+
[tool.poetry.dependencies]
|
8 |
+
python = ">=3.10,<3.12"
|
9 |
+
torch = "1.13.1"
|
10 |
+
tensorboard = "2.11.2"
|
11 |
+
pytorch-lightning = "1.9.0"
|
12 |
+
torchmetrics = "0.11.0"
|
13 |
+
Pillow = "8.4.0"
|
14 |
+
torchvision = "0.14.1"
|
15 |
+
matplotlib = "^3.7.1"
|
16 |
+
hydra-client = "0.5.1"
|
17 |
+
hydra-core = "1.3.1"
|
18 |
+
wget = "^3.2"
|
19 |
+
scipy = "^1.10.1"
|
20 |
+
seaborn = "^0.12.2"
|
21 |
+
earthengine-api = "0.1.338"
|
22 |
+
ee-extra = "0.0.15"
|
23 |
+
gradio = "^3.27.0"
|
24 |
+
opencv-python = "^4.7.0"
|
25 |
+
plotly = "^5.14.1"
|
26 |
+
|
27 |
+
[tool.poetry.dev-dependencies]
|
28 |
+
|
29 |
+
[build-system]
|
30 |
+
requires = ["poetry-core>=1.0.0"]
|
31 |
+
build-backend = "poetry.core.masonry.api"
|
requirements.txt
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
aiohttp==3.8.3
|
3 |
+
aiosignal==1.3.1
|
4 |
+
antlr4-python3-runtime==4.9.3
|
5 |
+
appdirs==1.4.4
|
6 |
+
argh==0.26.2
|
7 |
+
async-timeout==4.0.2
|
8 |
+
atomicwrites==1.4.0
|
9 |
+
attrs==19.3.0
|
10 |
+
backports.weakref==1.0.post1
|
11 |
+
bkcharts==0.2
|
12 |
+
black==19.10b0
|
13 |
+
boto==2.49.0
|
14 |
+
bqplot==0.12.36
|
15 |
+
branca==0.6.0
|
16 |
+
brotlipy==0.7.0
|
17 |
+
cachetools==5.3.0
|
18 |
+
certifi==2021.10.8
|
19 |
+
click==8.0.3
|
20 |
+
colour==0.1.5
|
21 |
+
comtypes==1.1.10
|
22 |
+
cycler==0.10.0
|
23 |
+
cytoolz==0.11.0
|
24 |
+
daal4py==2021.3.0
|
25 |
+
dask==2021.10.0
|
26 |
+
earthengine-api==0.1.338
|
27 |
+
ee-extra==0.0.15
|
28 |
+
eerepr==0.0.4
|
29 |
+
entrypoints==0.3
|
30 |
+
et-xmlfile==1.1.0
|
31 |
+
export==0.2.0
|
32 |
+
ffmpeg-python==0.2.0
|
33 |
+
folium==0.14.0
|
34 |
+
fonttools==4.25.0
|
35 |
+
frozenlist==1.3.3
|
36 |
+
gdown==4.6.0
|
37 |
+
geeadd==0.5.6
|
38 |
+
geemap==0.19.6
|
39 |
+
geocoder==1.38.1
|
40 |
+
geojson==3.0.0
|
41 |
+
google-api-core==2.11.0
|
42 |
+
google-api-python-client==2.74.0
|
43 |
+
google-auth==2.16.0
|
44 |
+
google-auth-httplib2==0.1.0
|
45 |
+
google-auth-oauthlib==0.4.6
|
46 |
+
google-cloud-core==2.3.2
|
47 |
+
google-cloud-storage==2.7.0
|
48 |
+
google-crc32c==1.5.0
|
49 |
+
google-resumable-media==2.4.1
|
50 |
+
googleapis-common-protos==1.58.0
|
51 |
+
grpcio==1.51.1
|
52 |
+
httplib2==0.21.0
|
53 |
+
hydra-client==0.5.1
|
54 |
+
hydra-core==1.3.1
|
55 |
+
inflection==0.5.1
|
56 |
+
ipyevents==2.0.1
|
57 |
+
ipyfilechooser==0.6.0
|
58 |
+
ipyleaflet==0.17.2
|
59 |
+
ipytree==0.2.2
|
60 |
+
lightning-utilities==0.6.0.post0
|
61 |
+
llvmlite==0.37.0
|
62 |
+
locket==0.2.1
|
63 |
+
logzero==1.7.0
|
64 |
+
Markdown==3.4.1
|
65 |
+
mccabe==0.6.1
|
66 |
+
mkl-fft==1.3.1
|
67 |
+
mkl-service==2.4.0
|
68 |
+
mpmath==1.2.1
|
69 |
+
multidict==6.0.4
|
70 |
+
munkres==1.1.4
|
71 |
+
mypy-extensions==0.4.3
|
72 |
+
nltk==3.6.5
|
73 |
+
oauthlib==3.2.2
|
74 |
+
omegaconf==2.3.0
|
75 |
+
pathspec==0.7.0
|
76 |
+
patsy==0.5.2
|
77 |
+
pep8==1.7.1
|
78 |
+
Pillow==8.4.0
|
79 |
+
pkginfo==1.7.1
|
80 |
+
plotly==5.13.0
|
81 |
+
ply==3.11
|
82 |
+
protobuf==3.20.3
|
83 |
+
pyasn1==0.4.8
|
84 |
+
pyasn1-modules==0.2.8
|
85 |
+
pycosat==0.6.3
|
86 |
+
PyCRS==1.0.2
|
87 |
+
pycurl==7.44.1
|
88 |
+
pyls-spyder==0.4.0
|
89 |
+
pyperclip==1.8.2
|
90 |
+
pyreadline==2.1
|
91 |
+
pyshp==2.3.1
|
92 |
+
pytest==6.2.4
|
93 |
+
python-box==6.1.0
|
94 |
+
python-lsp-jsonrpc==1.0.0
|
95 |
+
python-lsp-server==1.2.4
|
96 |
+
pytorch-lightning==1.9.0
|
97 |
+
pytz==2021.3
|
98 |
+
PyYAML==6.0
|
99 |
+
ratelim==0.1.6
|
100 |
+
requests-oauthlib==1.3.1
|
101 |
+
rsa==4.9
|
102 |
+
sankee==0.2.1
|
103 |
+
scikit-image==0.18.3
|
104 |
+
scooby==0.7.1
|
105 |
+
simplegeneric==0.8.1
|
106 |
+
Sphinx==4.2.0
|
107 |
+
statsmodels==0.12.2
|
108 |
+
tables==3.6.1
|
109 |
+
tenacity==8.1.0
|
110 |
+
tensorboard==2.11.2
|
111 |
+
tensorboard-data-server==0.6.1
|
112 |
+
tensorboard-plugin-wit==1.8.1
|
113 |
+
terminado==0.9.4
|
114 |
+
torch==1.13.1
|
115 |
+
torchaudio==0.13.1
|
116 |
+
torchmetrics==0.11.0
|
117 |
+
torchvision==0.14.1
|
118 |
+
traittypes==0.2.1
|
119 |
+
typing_extensions==4.4.0
|
120 |
+
unicodecsv==0.14.1
|
121 |
+
uritemplate==4.1.1
|
122 |
+
urllib3==1.26.7
|
123 |
+
webencodings==0.5.1
|
124 |
+
wget==3.2
|
125 |
+
whitebox==2.2.0
|
126 |
+
whiteboxgui==2.2.0
|
127 |
+
win-unicode-console==0.5
|
128 |
+
wincertstore==0.2
|
129 |
+
xlwt==1.3.0
|
130 |
+
xyzservices==2022.9.0
|
131 |
+
yarl==1.8.2
|
132 |
+
zict==2.0.0
|
133 |
+
zope.event==4.5.0
|