Spaces:
Runtime error
Runtime error
Update Self-Correction-Human-Parsing/simple_extractor.py
Browse files
Self-Correction-Human-Parsing/simple_extractor.py
CHANGED
@@ -10,7 +10,7 @@ import torchvision.transforms as transforms
|
|
10 |
|
11 |
import networks
|
12 |
from utils.transforms import transform_logits
|
13 |
-
from datasets.simple_extractor_dataset import
|
14 |
|
15 |
dataset_settings = {
|
16 |
'atr': {
|
@@ -27,7 +27,7 @@ def get_arguments():
|
|
27 |
parser.add_argument("--dataset", type=str, default='atr', choices=['atr'])
|
28 |
parser.add_argument("--model-restore", type=str, default='', help="Path to pretrained model.")
|
29 |
parser.add_argument("--gpu", type=str, default='0', help="GPU device.")
|
30 |
-
parser.add_argument("--input-
|
31 |
parser.add_argument("--output-dir", type=str, default='', help="Path of output image folder.")
|
32 |
|
33 |
return parser.parse_args()
|
@@ -57,8 +57,9 @@ def main():
|
|
57 |
transforms.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229])
|
58 |
])
|
59 |
|
60 |
-
|
61 |
-
|
|
|
62 |
|
63 |
if not os.path.exists(args.output_dir):
|
64 |
os.makedirs(args.output_dir)
|
|
|
10 |
|
11 |
import networks
|
12 |
from utils.transforms import transform_logits
|
13 |
+
from datasets.simple_extractor_dataset import SimpleFileDataset # Modify dataset class to use SimpleFileDataset
|
14 |
|
15 |
dataset_settings = {
|
16 |
'atr': {
|
|
|
27 |
parser.add_argument("--dataset", type=str, default='atr', choices=['atr'])
|
28 |
parser.add_argument("--model-restore", type=str, default='', help="Path to pretrained model.")
|
29 |
parser.add_argument("--gpu", type=str, default='0', help="GPU device.")
|
30 |
+
parser.add_argument("--input-path", type=str, default='', help="Path to a single input image.")
|
31 |
parser.add_argument("--output-dir", type=str, default='', help="Path of output image folder.")
|
32 |
|
33 |
return parser.parse_args()
|
|
|
57 |
transforms.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229])
|
58 |
])
|
59 |
|
60 |
+
# Use the SimpleFileDataset class instead of SimpleFolderDataset
|
61 |
+
dataset = SimpleFileDataset(img_path=args.input_path, input_size=input_size, transform=transform)
|
62 |
+
dataloader = DataLoader(dataset, batch_size=1) # Only one image, so batch_size=1
|
63 |
|
64 |
if not os.path.exists(args.output_dir):
|
65 |
os.makedirs(args.output_dir)
|