add tqdm
train.py CHANGED
@@ -71,35 +71,6 @@ def test():
     conrmodel.dist()
     infer(args, conrmodel, image_names_list)

-# def test():
-#     source_names_list = []
-#     for name in os.listdir(args.test_input_person_images):
-#         thissource = os.path.join(args.test_input_person_images, name)
-#         if os.path.isfile(thissource):
-#             source_names_list.append([thissource])
-#         if os.path.isdir(thissource):
-#             toadd = [os.path.join(thissource, this_file)
-#                      for this_file in os.listdir(thissource)]
-#             if (toadd != []):
-#                 source_names_list.append(toadd)
-#             else:
-#                 print("skipping empty folder :"+thissource)
-#     image_names_list = []
-#     for eachlist in source_names_list:
-#         for name in sorted(os.listdir(args.test_input_poses_images)):
-#             thistarget = os.path.join(args.test_input_poses_images, name)
-#             if os.path.isfile(thistarget):
-#                 image_names_list.append([thistarget, *eachlist])
-#             if os.path.isdir(thistarget):
-#                 print("skipping folder :"+thistarget)
-
-#     print(image_names_list)
-#     print("---building models...")
-#     conrmodel = CoNR(args)
-#     conrmodel.load_model(path=args.test_checkpoint_dir)
-#     conrmodel.dist()
-#     infer(args, conrmodel, image_names_list)
-

 def infer(args, humanflowmodel, image_names_list):
     print("---test images: ", len(image_names_list))
@@ -124,7 +95,9 @@ def infer(args, humanflowmodel, image_names_list):
     time_stamp = time.time()
     prev_frame_rgb = []
     prev_frame_a = []
-
+
+    pbar = tqdm(range(train_num), ncols=100)
+    for i, data in enumerate(train_data):
         data_time_interval = time.time() - time_stamp
         time_stamp = time.time()
         with torch.no_grad():
@@ -138,11 +111,10 @@ def infer(args, humanflowmodel, image_names_list):

         train_time_interval = time.time() - time_stamp
         time_stamp = time.time()
-
-
-
-
-            # ))
+        if args.local_rank == 0:
+            pbar.set_description(f"Epoch {i}/{train_num}")
+            pbar.set_postfix({"data_time": data_time_interval, "train_time":train_time_interval})
+
         with torch.no_grad():

             if args.test_output_video: