new `tensor_to_uint8_numpy_image` tensor util
Files changed:
- climategan/trainer.py  +10 -16
- climategan/tutils.py   +46 -4
climategan/trainer.py CHANGED

@@ -39,10 +39,10 @@ from climategan.tutils import (
     get_WGAN_gradient,
     lrgb2srgb,
     normalize,
-    normalize_tensor,
     print_num_parameters,
     shuffle_batch_tuple,
     srgb2lrgb,
+    tensor_to_uint8_numpy_image,
     vgg_preprocess,
     zero_grad,
 )
@@ -231,12 +231,15 @@ class Trainer:
         return_intermediates=False,
     ):
         """
-        Create a
+        Create a dictionary of events from a numpy or tensor,
         single or batch image data.

-        stores is a
+        stores is a dictionary of times for the Timer class.

         bin_value is used to binarize (or not) flood masks
+
+        all values in the output dictionary have 4 dimensions:
+            BxHxWxC if numpy else BxCxHxW
         """
         assert self.is_setup
         assert len(x.shape) in {3, 4}, f"Unknown Data shape {x.shape}"
@@ -316,21 +319,14 @@ class Trainer:
         with Timer(store=stores.get("numpy", [])):
             if "flood" not in ignore_event:
                 # normalize to 0-1
-                flood =
-                # convert to numpy
-                flood = flood.permute(0, 2, 3, 1).numpy()
+                flood = tensor_to_uint8_numpy_image(flood)
                 # convert to 0-255 uint8
-                flood = (flood * 255).astype(np.uint8)
                 output_data["flood"] = flood
             if "wildfire" not in ignore_event:
-                wildfire =
-                wildfire = wildfire.permute(0, 2, 3, 1).numpy()
-                wildfire = (wildfire * 255).astype(np.uint8)
+                wildfire = tensor_to_uint8_numpy_image(wildfire)
                 output_data["wildfire"] = wildfire
             if "smog" not in ignore_event:
-                smog =
-                smog = smog.permute(0, 2, 3, 1).numpy()
-                smog = (smog * 255).astype(np.uint8)
+                smog = tensor_to_uint8_numpy_image(smog)
                 output_data["smog"] = smog

             if return_intermediates:
@@ -338,9 +334,7 @@ class Trainer:
                 output_data["mask"] = (
                     ((mask > bin_value) * 255).cpu().numpy().astype(np.uint8)
                 )
-                output_data["depth"] = (
-                    normalize_tensor(depth).cpu().squeeze(1).numpy().astype(np.uint8) * 255
-                )
+                output_data["depth"] = tensor_to_uint8_numpy_image(depth)
                 output_data["segmentation"] = (
                     decode_segmap_merged_labels(segmentation, "r", False)
                     .cpu()
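For context, a minimal sketch of what this refactor amounts to at runtime: the hand-rolled permute / scale / cast chain versus the new helper. The variable name `painted`, the toy shape, and the assumption that `climategan.tutils` is importable are illustrative, not taken from the diff:

```python
import numpy as np
import torch

from climategan.tutils import tensor_to_uint8_numpy_image

painted = torch.rand(1, 3, 8, 8)  # assumed painter output in [0, 1], BxCxHxW

# Old path (repeated for flood / wildfire / smog): permute, scale, cast by hand.
old = painted.permute(0, 2, 3, 1).numpy()
old = (old * 255).astype(np.uint8)

# New path: one helper call. It re-normalizes per batch element, so exact values
# can differ slightly from the old scaling, but shape and dtype are the same.
new = tensor_to_uint8_numpy_image(painted)

assert old.shape == new.shape == (1, 8, 8, 3)
assert old.dtype == new.dtype == np.uint8
```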
climategan/tutils.py CHANGED

@@ -564,14 +564,29 @@ def lrgb2srgb(ims):
     return outs[0]


-def normalize(t, mini=0, maxi=1):
+def normalize(t, mini=0.0, maxi=1.0):
+    """
+    Normalizes a tensor to [0, 1].
+    If the tensor has more than 3 dimensions, the first one
+    is assumed to be the batch dimension and the tensor is
+    normalized per batch element, not across the batches.
+
+    Args:
+        t (torch.Tensor): Tensor to normalize
+        mini (float, optional): Min allowed value. Defaults to 0.
+        maxi (float, optional): Max allowed value. Defaults to 1.
+
+    Returns:
+        torch.Tensor: The normalized tensor
+    """
     if len(t.shape) == 3:
         return mini + (maxi - mini) * (t - t.min()) / (t.max() - t.min())

     batch_size = t.shape[0]
-    min_t = t.reshape(batch_size, -1).min(1)[0].reshape(batch_size,
+    extra_dims = [1] * (t.ndim - 1)
+    min_t = t.reshape(batch_size, -1).min(1)[0].reshape(batch_size, *extra_dims)
     t = t - min_t
-    max_t = t.reshape(batch_size, -1).max(1)[0].reshape(batch_size,
+    max_t = t.reshape(batch_size, -1).max(1)[0].reshape(batch_size, *extra_dims)
     t = t / max_t
     return mini + (maxi - mini) * t

@@ -644,7 +659,7 @@ def write_architecture(trainer):
         f.write(output)


-def rand_perlin_2d(shape, res, fade=lambda t: 6 * t
+def rand_perlin_2d(shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
     delta = (res[0] / shape[0], res[1] / shape[1])
     d = (shape[0] // res[0], shape[1] // res[1])

@@ -719,3 +734,30 @@ def tensor_ims_to_np_uint8s(ims):
         nps.append(n.astype(np.uint8))

     return nps[0] if len(nps) == 1 else nps
+
+
+def tensor_to_uint8_numpy_image(tensor):
+    """
+    Turns a BxCxHxW tensor into a numpy image:
+        * normalize
+        * to [0, 255]
+        * detach
+        * channels last
+        * to uint8
+        * to cpu
+        * to numpy
+
+    Args:
+        tensor (torch.Tensor): Tensor to transform
+
+    Returns:
+        np.array: BxHxWxC np.uint8 array in [0, 255]
+    """
+    return (
+        normalize(tensor, 0, 255)  # [0, 255]
+        .detach()  # detach from graph if needed
+        .permute(0, 2, 3, 1)  # BxHxWxC
+        .to(torch.uint8)  # uint8
+        .cpu()  # cpu
+        .numpy()  # numpy array
+    )
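To illustrate the per-batch-element behaviour described in the new `normalize` docstring, here is a small sketch; the toy values are made up and it assumes `climategan.tutils` is importable:

```python
import torch

from climategan.tutils import normalize, tensor_to_uint8_numpy_image

# Two 1-channel 2x2 maps with very different value ranges, stacked into a batch.
t = torch.stack(
    [
        torch.tensor([[[0.0, 1.0], [2.0, 4.0]]]),
        torch.tensor([[[100.0, 150.0], [200.0, 300.0]]]),
    ]
)

out = normalize(t)  # 4D input: each batch element is stretched to [0, 1] independently
print(out[0].min().item(), out[0].max().item())  # 0.0 1.0
print(out[1].min().item(), out[1].max().item())  # 0.0 1.0

# The uint8 helper builds on the same normalization, but targets [0, 255]
# and returns a channels-last numpy array.
img = tensor_to_uint8_numpy_image(t.repeat(1, 3, 1, 1))  # fake 3-channel input
print(img.shape, img.dtype)  # (2, 2, 2, 3) uint8
```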