Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -93,8 +93,8 @@ size = 256
|
|
| 93 |
means = [0.485, 0.456, 0.406]
|
| 94 |
stds = [0.229, 0.224, 0.225]
|
| 95 |
|
| 96 |
-
t_stds = torch.tensor(stds).cpu()[:,None,None]
|
| 97 |
-
t_means = torch.tensor(means).cpu()[:,None,None]
|
| 98 |
|
| 99 |
def makeEven(_x):
|
| 100 |
return int(_x) if (_x % 2 == 0) else int(_x+1)
|
|
@@ -107,7 +107,7 @@ def tensor2im(var):
|
|
| 107 |
return var.mul(t_stds).add(t_means).mul(255.).clamp(0,255).permute(1,2,0)
|
| 108 |
|
| 109 |
def proc_pil_img(input_image, model):
|
| 110 |
-
transformed_image = img_transforms(input_image)[None,...].cpu()()
|
| 111 |
|
| 112 |
with torch.no_grad():
|
| 113 |
result_image = model(transformed_image)[0]
|
|
@@ -118,9 +118,9 @@ def proc_pil_img(input_image, model):
|
|
| 118 |
|
| 119 |
|
| 120 |
|
| 121 |
-
modelv4 = torch.jit.load(modelarcanev4).eval().cpu()
|
| 122 |
-
modelv3 = torch.jit.load(modelarcanev3).eval().cpu()
|
| 123 |
-
modelv2 = torch.jit.load(modelarcanev2).eval().cpu()
|
| 124 |
|
| 125 |
def process(im, version):
|
| 126 |
if version == 'version 0.4':
|
|
|
|
| 93 |
means = [0.485, 0.456, 0.406]
|
| 94 |
stds = [0.229, 0.224, 0.225]
|
| 95 |
|
| 96 |
+
t_stds = torch.tensor(stds).cpu().half().float()[:,None,None]
|
| 97 |
+
t_means = torch.tensor(means).cpu().half().float()[:,None,None]
|
| 98 |
|
| 99 |
def makeEven(_x):
|
| 100 |
return int(_x) if (_x % 2 == 0) else int(_x+1)
|
|
|
|
| 107 |
return var.mul(t_stds).add(t_means).mul(255.).clamp(0,255).permute(1,2,0)
|
| 108 |
|
| 109 |
def proc_pil_img(input_image, model):
|
| 110 |
+
transformed_image = img_transforms(input_image)[None,...].cpu().half().float()
|
| 111 |
|
| 112 |
with torch.no_grad():
|
| 113 |
result_image = model(transformed_image)[0]
|
|
|
|
| 118 |
|
| 119 |
|
| 120 |
|
| 121 |
+
modelv4 = torch.jit.load(modelarcanev4,map_location='cpu').eval().cpu().half().float()
|
| 122 |
+
modelv3 = torch.jit.load(modelarcanev3,map_location='cpu').eval().cpu().half().float()
|
| 123 |
+
modelv2 = torch.jit.load(modelarcanev2,map_location='cpu').eval().cpu().half().float()
|
| 124 |
|
| 125 |
def process(im, version):
|
| 126 |
if version == 'version 0.4':
|