52Hz commited on
Commit
306a5fa
·
1 Parent(s): 433e0fa

Update main_test_CMFNet.py

Browse files
Files changed (1) hide show
  1. main_test_CMFNet.py +7 -6
main_test_CMFNet.py CHANGED
@@ -55,14 +55,15 @@ def main():
55
  input_ = F.pad(input_, (0, padw, 0, padh), 'reflect')
56
  with torch.no_grad():
57
  restored = model(input_)
58
- restored = restored[0]
59
- restored = torch.clamp(restored, 0, 1)
60
- restored = restored[:, :, :h, :w]
61
- restored = restored.permute(0, 2, 3, 1).cpu().detach().numpy()
62
- restored = img_as_ubyte(restored[0])
 
63
 
64
  f = os.path.splitext(os.path.split(file_)[-1])[0]
65
- save_img((os.path.join(out_dir, f + '.png')), restored)
66
 
67
 
68
  def save_img(filepath, img):
 
55
  input_ = F.pad(input_, (0, padw, 0, padh), 'reflect')
56
  with torch.no_grad():
57
  restored = model(input_)
58
+
59
+ restored_ = restored[0]
60
+ restored_= torch.clamp(restored_, 0, 1)
61
+ restored_= restored_[:, :, :h, :w]
62
+ restored_= restored_.permute(0, 2, 3, 1).cpu().detach().numpy()
63
+ restored_= img_as_ubyte(restored_[0])
64
 
65
  f = os.path.splitext(os.path.split(file_)[-1])[0]
66
+ save_img((os.path.join(out_dir, f + '.png')), restored_)
67
 
68
 
69
  def save_img(filepath, img):