Spaces:
Runtime error
Runtime error
anonymous
commited on
Commit
·
2a8678c
1
Parent(s):
4fcfd85
update
Browse files- src/img_util.py +3 -1
src/img_util.py
CHANGED
|
@@ -2,6 +2,8 @@ import einops
|
|
| 2 |
import torch
|
| 3 |
import torch.nn.functional as F
|
| 4 |
|
|
|
|
|
|
|
| 5 |
|
| 6 |
@torch.no_grad()
|
| 7 |
def find_flat_region(mask):
|
|
@@ -18,6 +20,6 @@ def find_flat_region(mask):
|
|
| 18 |
|
| 19 |
|
| 20 |
def numpy2tensor(img):
|
| 21 |
-
x0 = torch.from_numpy(img.copy()).float().
|
| 22 |
x0 = torch.stack([x0], dim=0)
|
| 23 |
return einops.rearrange(x0, 'b h w c -> b c h w').clone()
|
|
|
|
| 2 |
import torch
|
| 3 |
import torch.nn.functional as F
|
| 4 |
|
| 5 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 6 |
+
|
| 7 |
|
| 8 |
@torch.no_grad()
|
| 9 |
def find_flat_region(mask):
|
|
|
|
| 20 |
|
| 21 |
|
| 22 |
def numpy2tensor(img):
|
| 23 |
+
x0 = torch.from_numpy(img.copy()).float().to(device) / 255.0 * 2.0 - 1.
|
| 24 |
x0 = torch.stack([x0], dim=0)
|
| 25 |
return einops.rearrange(x0, 'b h w c -> b c h w').clone()
|