Spaces:
Runtime error
Runtime error
Update annotator/hed/__init__.py
Browse files
annotator/hed/__init__.py
CHANGED
|
@@ -100,13 +100,20 @@ class HEDdetector:
|
|
| 100 |
if not os.path.exists(modelpath):
|
| 101 |
from basicsr.utils.download_util import load_file_from_url
|
| 102 |
load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
|
| 103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
def __call__(self, input_image):
|
| 106 |
assert input_image.ndim == 3
|
| 107 |
input_image = input_image[:, :, ::-1].copy()
|
| 108 |
with torch.no_grad():
|
| 109 |
-
|
|
|
|
|
|
|
|
|
|
| 110 |
image_hed = image_hed / 255.0
|
| 111 |
image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
|
| 112 |
edge = self.netNetwork(image_hed)[0]
|
|
|
|
| 100 |
if not os.path.exists(modelpath):
|
| 101 |
from basicsr.utils.download_util import load_file_from_url
|
| 102 |
load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
|
| 103 |
+
if torch.cuda.is_available():
|
| 104 |
+
self.netNetwork = Network(modelpath).cuda().eval()
|
| 105 |
+
else:
|
| 106 |
+
self.netNetwork = Network(modelpath).eval()
|
| 107 |
+
|
| 108 |
|
| 109 |
def __call__(self, input_image):
|
| 110 |
assert input_image.ndim == 3
|
| 111 |
input_image = input_image[:, :, ::-1].copy()
|
| 112 |
with torch.no_grad():
|
| 113 |
+
if torch.cuda.is_available():
|
| 114 |
+
image_hed = torch.from_numpy(input_image).float().cuda()
|
| 115 |
+
else:
|
| 116 |
+
image_hed = torch.from_numpy(input_image).float()
|
| 117 |
image_hed = image_hed / 255.0
|
| 118 |
image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
|
| 119 |
edge = self.netNetwork(image_hed)[0]
|