Spaces:
Running
Running
gavinyuan
commited on
Commit
·
f057d66
1
Parent(s):
a104d3f
udpate: app.py import FSGenerator
Browse files- inference/tricks.py +6 -2
inference/tricks.py
CHANGED
|
@@ -138,10 +138,14 @@ class SoftErosion(nn.Module):
|
|
| 138 |
return x, mask
|
| 139 |
|
| 140 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
vgg_mean = torch.tensor([[[0.485]], [[0.456]], [[0.406]]],
|
| 142 |
-
requires_grad=False, device=
|
| 143 |
vgg_std = torch.tensor([[[0.229]], [[0.224]], [[0.225]]],
|
| 144 |
-
requires_grad=False, device=
|
| 145 |
def load_bisenet():
|
| 146 |
bisenet_model = BiSeNet(n_classes=19)
|
| 147 |
bisenet_model.load_state_dict(
|
|
|
|
| 138 |
return x, mask
|
| 139 |
|
| 140 |
|
| 141 |
+
if torch.cuda.is_available():
|
| 142 |
+
device = torch.device(0)
|
| 143 |
+
else:
|
| 144 |
+
device = torch.device('cpu')
|
| 145 |
vgg_mean = torch.tensor([[[0.485]], [[0.456]], [[0.406]]],
|
| 146 |
+
requires_grad=False, device=device)
|
| 147 |
vgg_std = torch.tensor([[[0.229]], [[0.224]], [[0.225]]],
|
| 148 |
+
requires_grad=False, device=device)
|
| 149 |
def load_bisenet():
|
| 150 |
bisenet_model = BiSeNet(n_classes=19)
|
| 151 |
bisenet_model.load_state_dict(
|