mikonvergence committed
Commit 5318c78 · 1 Parent(s): 841fede

front end ready

Files changed (2)
  1. app.py +69 -0
  2. src/utils.py +170 -0
app.py ADDED
@@ -0,0 +1,69 @@
+import gradio as gr
+from src.utils import *
+
+theme = gr.themes.Soft(primary_hue="amber", secondary_hue="orange", font=[gr.themes.GoogleFont("Source Sans 3", weights=(400, 600)), 'arial'])
+
+with gr.Blocks(theme=theme) as demo:
+    with gr.Column(elem_classes="header"):
+        gr.Markdown("# 🗾 COP-GEN-Beta: Unified Generative Modelling of COPernicus Imagery Thumbnails")
+        gr.Markdown("### Miguel Espinosa, Valerio Marsocci, Yuru Jia, Elliot J. Crowley, Mikolaj Czerkawski")
+        gr.Markdown('[[Website](https://miquel-espinosa.github.io/cop-gen-beta/)] [[GitHub](https://github.com/miquel-espinosa/COP-GEN-Beta)] [[Model](https://huggingface.co/mespinosami/COP-GEN-Beta)] [[Dataset](https://huggingface.co/Major-TOM)]')
+
+    with gr.Column(elem_classes="abstract"):
+
+        with gr.Accordion("Abstract", open=False) as abstract:
+            gr.Markdown("In remote sensing, multi-modal data from various sensors capturing the same scene offers rich opportunities, but learning a unified representation across these modalities remains a significant challenge. Traditional methods have often been limited to single or dual-modality approaches. In this paper, we introduce COP-GEN-Beta, a generative diffusion model trained on optical, radar, and elevation data from the Major TOM dataset. What sets COP-GEN-Beta apart is its ability to map any subset of modalities to any other, enabling zero-shot modality translation after training. This is achieved through a sequence-based diffusion transformer, where each modality is controlled by its own timestep embedding. We extensively evaluate COP-GEN-Beta on thumbnail images from the Major TOM dataset, demonstrating its effectiveness in generating high-quality samples. Qualitative and quantitative evaluations validate the model's performance, highlighting its potential as a powerful pre-trained model for future remote sensing tasks.")
+
+        with gr.Accordion("Instructions", open=False) as instructions:
+            gr.Markdown("1. **Define input**: You can upload your thumbnails manually, or get a random sample from Major TOM by clicking the button.")
+            gr.Markdown("2. **Select conditions**: Each input image can be used as **conditioning** by selecting its `Active` checkbox. If no checkbox is selected, you will observe **unconditional generation**.")
+            gr.Markdown("3. **Generate**: Click the `Generate` button to synthesize the output. The outputs are shown below.")
+
+    with gr.Column():
+        with gr.Row():
+            gr.Markdown("## Inputs (Optional)")
+            load_button = gr.Button("Load a random sample from Major TOM 🗺", variant="secondary")
+        with gr.Row():
+            with gr.Column():
+                s2l1c_input = gr.Image(label="S2 L1C (Optical - Top of Atmosphere)", interactive=True)
+                s2l1c_active = gr.Checkbox(value=False, label="Active", interactive=True)
+            with gr.Column():
+                s2l2a_input = gr.Image(label="S2 L2A (Optical - Bottom of Atmosphere)", interactive=True)
+                s2l2a_active = gr.Checkbox(value=False, label="Active", interactive=True)
+            with gr.Column():
+                s1rtc_input = gr.Image(label="S1 RTC (SAR)", interactive=True)
+                s1rtc_active = gr.Checkbox(value=False, label="Active", interactive=True)
+            with gr.Column():
+                dem_input = gr.Image(label="DEM (Elevation)", interactive=True)
+                dem_active = gr.Checkbox(value=False, label="Active", interactive=True)
+
+        generate_button = gr.Button("Generate", variant="primary")
+
+        gr.Markdown("## Outputs")
+        with gr.Row():
+            s2l1c_output = gr.Image(label="S2 L1C (Optical - Top of Atmosphere)", interactive=False)
+            s2l2a_output = gr.Image(label="S2 L2A (Optical - Bottom of Atmosphere)", interactive=False)
+            s1rtc_output = gr.Image(label="S1 RTC (SAR)", interactive=False)
+            dem_output = gr.Image(label="DEM (Elevation)", interactive=False)
+
+    with gr.Accordion("Advanced Options", open=False) as advanced_options:
+        num_inference_steps_slider = gr.Slider(minimum=10, maximum=1000, step=10, value=50, label="Inference Steps")
+        guidance_scale_slider = gr.Slider(minimum=1.0, maximum=15.0, step=0.5, value=7.5, label="Guidance Scale")
+        with gr.Row():
+            seed_number = gr.Number(value=6378, label="Seed")
+            seed_checkbox = gr.Checkbox(value=True, label="Random")
+
+    load_button.click(
+        fn=sample_shuffle,
+        outputs=[s2l2a_input, s2l2a_active, s2l1c_input, s2l1c_active, s1rtc_input, s1rtc_active, dem_input, dem_active]  # matches sample_shuffle's (L2A, L1C, RTC, DEM) return order
+    )
+
+    generate_button.click(
+        # fn=generate_output,  # backend callback not yet implemented
+        inputs=[s2l1c_input, s2l2a_input, s1rtc_input, dem_input, num_inference_steps_slider, guidance_scale_slider, seed_number, seed_checkbox],
+        outputs=[s2l1c_output, s2l2a_output, s1rtc_output, dem_output],
+    )
+
+demo.launch()
+
+# demo.launch(share=True)  # alternative: launch with a public share link
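Note that in this commit the `Generate` button is wired up without a callback: `fn=generate_output` is left commented out, consistent with the commit message ("front end ready"). The sketch below is a purely hypothetical placeholder for that callback — the name `generate_output` is taken from the commented-out line, but the body is not part of the commit; it only matches the click() inputs/outputs above and echoes the conditioning images so the wiring can be smoke-tested. The actual COP-GEN-Beta inference would replace this body.

    import random

    def generate_output(s2l1c, s2l2a, s1rtc, dem,
                        num_inference_steps, guidance_scale, seed, use_random_seed):
        # Placeholder: a real implementation would sample the diffusion model here,
        # conditioning on whichever inputs are active.
        if use_random_seed:
            seed = random.randint(0, 2**32 - 1)  # ignore the fixed seed when 'Random' is ticked
        print(f'steps={num_inference_steps}, guidance={guidance_scale}, seed={seed}')
        return s2l1c, s2l2a, s1rtc, dem  # echo the inputs into the four output slots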
src/utils.py ADDED
@@ -0,0 +1,170 @@
+import os
+import pandas as pd
+
+# GLOBAL VARIABLES
+if os.path.isfile('data/s2l2a_metadata.parquet'):
+    l2a_meta_path = 'data/s2l2a_metadata.parquet'
+else:
+    DATASET_NAME = 'Major-TOM/Core-S2L2A'
+    l2a_meta_path = 'https://huggingface.co/datasets/{}/resolve/main/metadata.parquet'.format(DATASET_NAME)
+
+if os.path.isfile('data/s2l1c_metadata.parquet'):
+    l1c_meta_path = 'data/s2l1c_metadata.parquet'
+else:
+    DATASET_NAME = 'Major-TOM/Core-S2L1C'
+    l1c_meta_path = 'https://huggingface.co/datasets/{}/resolve/main/metadata.parquet'.format(DATASET_NAME)
+
+if os.path.isfile('data/s1rtc_metadata.parquet'):
+    rtc_meta_path = 'data/s1rtc_metadata.parquet'
+else:
+    DATASET_NAME = 'Major-TOM/Core-S1RTC'
+    rtc_meta_path = 'https://huggingface.co/datasets/{}/resolve/main/metadata.parquet'.format(DATASET_NAME)
+
+if os.path.isfile('data/dem_metadata.parquet'):
+    dem_meta_path = 'data/dem_metadata.parquet'
+else:
+    DATASET_NAME = 'Major-TOM/Core-DEM'
+    dem_meta_path = 'https://huggingface.co/datasets/{}/resolve/main/metadata.parquet'.format(DATASET_NAME)
+
+print('Loading Major TOM meta...')
+l2a_df = pd.read_parquet(l2a_meta_path)
+l1c_df = pd.read_parquet(l1c_meta_path)
+rtc_df = pd.read_parquet(rtc_meta_path)
+dem_df = pd.read_parquet(dem_meta_path)
+
+# skip samples with missing (nodata) parts
+l2a_df = l2a_df[l2a_df.nodata == 0]
+l1c_df = l1c_df[l1c_df.nodata == 0]
+rtc_df = rtc_df[rtc_df.nodata == 0]
+dem_df = dem_df[dem_df.nodata == 0]
+
+# keep grid cells present in all four modalities, drop duplicates, and extract the grid cell column only
+grid_cell_df = l2a_df[l2a_df.grid_cell.isin(l1c_df.grid_cell) & l2a_df.grid_cell.isin(rtc_df.grid_cell) & l2a_df.grid_cell.isin(dem_df.grid_cell)]
+grid_cell_df = grid_cell_df.drop_duplicates(subset=['grid_cell'])
+grid_cell_df = grid_cell_df.grid_cell
+print('[DONE]')
+
+import pyarrow.parquet as pq
+import fsspec
+from fsspec.parquet import open_parquet_file
+from io import BytesIO
+from PIL import Image
+import random
+
+def row2image(row, fullrow_read=True):
+    """
+    Extracts a thumbnail image from a specific row of a Parquet file.
+
+    Args:
+        row: A metadata row describing where the thumbnail lives. It is expected
+            to have attributes 'parquet_row' (the row group index within the
+            Parquet file) and 'parquet_url' (the URL or path of the Parquet file).
+        fullrow_read (bool, optional): Controls how the file is opened.
+            Defaults to True.
+            - If True, the whole Parquet file is opened via fsspec.
+            - If False, fsspec.parquet.open_parquet_file fetches only the 'thumbnail' column.
+
+    Returns:
+        PIL.Image.Image: The image decoded from the 'thumbnail' bytes of that row.
+    """
+    parquet_row = row.parquet_row
+    parquet_url = row.parquet_url
+
+    if fullrow_read:
+        # Option 1: open the entire Parquet file
+        f = fsspec.open(parquet_url)
+        temp_path = f.open()
+    else:
+        # Option 2: fetch only the 'thumbnail' column
+        temp_path = open_parquet_file(parquet_url, columns=["thumbnail"])
+
+    with pq.ParquetFile(temp_path) as pf:
+        first_row_group = pf.read_row_group(parquet_row, columns=['thumbnail'])
+
+    stream = BytesIO(first_row_group['thumbnail'][0].as_py())
+    return Image.open(stream)
+
+# Example usage (assuming 'dem_df' is a Pandas DataFrame with the required structure):
+# row2image(dem_df.iloc[1000])
+
+def get_rows(grid_cell):
+    """
+    Retrieves the first matching row from each modality DataFrame for a given 'grid_cell'.
+
+    Args:
+        grid_cell: The value to match against the 'grid_cell' column.
+
+    Returns:
+        tuple: The first matching row from each of the module-level DataFrames
+            l2a_df, l1c_df, rtc_df and dem_df, in that order. Each element is a
+            Pandas Series representing one row.
+    """
+    return l2a_df[l2a_df.grid_cell == grid_cell].iloc[0], \
+           l1c_df[l1c_df.grid_cell == grid_cell].iloc[0], \
+           rtc_df[rtc_df.grid_cell == grid_cell].iloc[0], \
+           dem_df[dem_df.grid_cell == grid_cell].iloc[0]
+
+def get_images(grid_cell):
+    """
+    Retrieves the thumbnail images for a given 'grid_cell' via get_rows and row2image.
+
+    Args:
+        grid_cell: The grid cell identifier to fetch images for.
+
+    Returns:
+        list: PIL.Image.Image objects, one per modality, extracted from the rows
+            returned by get_rows for the given grid cell.
+    """
+    img_rows = get_rows(grid_cell)
+
+    imgs = []
+    for row in img_rows:
+        imgs.append(row2image(row))
+
+    return imgs
+
+def resize_and_crop(images, image_size=(1068, 1068), crop_size=(256, 256)):
+    """
+    Resizes a list of images to a common size and crops the same random window from each.
+
+    Args:
+        images (list): A list of PIL.Image.Image objects to be processed.
+        image_size (tuple, optional): The target size (width, height) to resize the
+            images to. Defaults to (1068, 1068).
+        crop_size (tuple, optional): The size (width, height) of the random crop taken
+            from the resized images. Defaults to (256, 256).
+
+    Returns:
+        list: The resized and cropped images; sharing one crop window keeps the modalities spatially aligned.
+    """
+    left = random.randint(0, image_size[0] - crop_size[0])
+    top = random.randint(0, image_size[1] - crop_size[1])
+    right = left + crop_size[0]
+    bottom = top + crop_size[1]
+
+    return [img.resize(image_size).crop((left, top, right, bottom)) for img in images]
+
+def sample_shuffle(interface=True):
+    """
+    Randomly selects a 'grid_cell', retrieves the corresponding images, and optionally prepares them for the interface.
+
+    Args:
+        interface (bool, optional): If True, each image in the returned list is
+            followed by the boolean True, matching the (image, 'Active' checkbox)
+            component pairs expected by the Gradio front end. Defaults to True.
+
+    Returns:
+        list: If interface is False, the resized and cropped PIL.Image.Image objects;
+            if True, the same images interleaved with the boolean value True.
+    """
+    grid_cell = grid_cell_df.sample().iloc[0]
+
+    images = resize_and_crop(get_images(grid_cell))
+
+    if not interface:
+        return images
+    else:
+        out = []
+        for el in images:
+            out += [el, True]
+        return out
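As a quick sanity check of the pipeline above, the snippet below exercises the sampling path end to end. This is an illustrative sketch, not part of the commit: it assumes you run it from the repository root so `src.utils` resolves, and the first run downloads the four Major TOM metadata parquet files, which can take a while. The interleaved `True` flags correspond to the `Active` checkboxes in app.py, and the image order is (L2A, L1C, RTC, DEM), following get_rows.

    from src.utils import sample_shuffle

    out = sample_shuffle()   # [img, True, img, True, img, True, img, True]
    images = out[::2]        # four 256x256 thumbnails: L2A, L1C, RTC, DEM
    flags = out[1::2]
    for img in images:
        print(img.size)      # (256, 256)
    assert all(flags)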