Keshabwi66 commited on
Commit
3a5cd6f
·
verified ·
1 Parent(s): c594e13

Update Self-Correction-Human-Parsing/simple_extractor.py

Browse files
Self-Correction-Human-Parsing/simple_extractor.py CHANGED
@@ -10,7 +10,7 @@ import torchvision.transforms as transforms
10
 
11
  import networks
12
  from utils.transforms import transform_logits
13
- from datasets.simple_extractor_dataset import SimpleFolderDataset
14
 
15
  dataset_settings = {
16
  'atr': {
@@ -27,7 +27,7 @@ def get_arguments():
27
  parser.add_argument("--dataset", type=str, default='atr', choices=['atr'])
28
  parser.add_argument("--model-restore", type=str, default='', help="Path to pretrained model.")
29
  parser.add_argument("--gpu", type=str, default='0', help="GPU device.")
30
- parser.add_argument("--input-dir", type=str, default='', help="Path of input image folder.")
31
  parser.add_argument("--output-dir", type=str, default='', help="Path of output image folder.")
32
 
33
  return parser.parse_args()
@@ -57,8 +57,9 @@ def main():
57
  transforms.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229])
58
  ])
59
 
60
- dataset = SimpleFolderDataset(root=args.input_dir, input_size=input_size, transform=transform)
61
- dataloader = DataLoader(dataset)
 
62
 
63
  if not os.path.exists(args.output_dir):
64
  os.makedirs(args.output_dir)
 
10
 
11
  import networks
12
  from utils.transforms import transform_logits
13
+ from datasets.simple_extractor_dataset import SimpleFileDataset # Modify dataset class to use SimpleFileDataset
14
 
15
  dataset_settings = {
16
  'atr': {
 
27
  parser.add_argument("--dataset", type=str, default='atr', choices=['atr'])
28
  parser.add_argument("--model-restore", type=str, default='', help="Path to pretrained model.")
29
  parser.add_argument("--gpu", type=str, default='0', help="GPU device.")
30
+ parser.add_argument("--input-path", type=str, default='', help="Path to a single input image.")
31
  parser.add_argument("--output-dir", type=str, default='', help="Path of output image folder.")
32
 
33
  return parser.parse_args()
 
57
  transforms.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229])
58
  ])
59
 
60
+ # Use the SimpleFileDataset class instead of SimpleFolderDataset
61
+ dataset = SimpleFileDataset(img_path=args.input_path, input_size=input_size, transform=transform)
62
+ dataloader = DataLoader(dataset, batch_size=1) # Only one image, so batch_size=1
63
 
64
  if not os.path.exists(args.output_dir):
65
  os.makedirs(args.output_dir)