Spaces:
Running
on
T4
Running
on
T4
"""Run DiffusionSfM evaluation jobs on CO3D categories and print aggregated metrics.

Example:
    python -m diffusionsfm.eval.eval_jobs --eval_path output/multi_diffusionsfm_dense --use_submitit
"""
import os | |
import json | |
import submitit | |
import argparse | |
import itertools | |
from glob import glob | |
import numpy as np | |
from tqdm.auto import tqdm | |
from diffusionsfm.dataset.co3d_v2 import TEST_CATEGORIES, TRAINING_CATEGORIES | |
from diffusionsfm.eval.eval_category import save_results | |
from diffusionsfm.utils.slurm import submitit_job_watcher | |
def evaluate_diffusionsfm(eval_path, use_submitit, mode, checkpoint=800_000):
    """Run ``save_results`` for every evaluation configuration of one category split.

    The sweep is the Cartesian product of the number of input images (2-8),
    the sample index (0-4) and the categories selected by ``mode``:

    - ``"test"``: ``TEST_CATEGORIES``
    - ``"train1"``: first half of ``TRAINING_CATEGORIES``
    - ``"train2"``: second half of ``TRAINING_CATEGORIES``

    Args:
        eval_path: Passed to ``save_results`` as ``output_dir``.
        use_submitit: If True, submit one Slurm job per configuration through
            submitit and hand the jobs to ``submitit_job_watcher``; otherwise
            run every configuration sequentially in this process.
        mode: One of ``"test"``, ``"train1"``, ``"train2"``.
        checkpoint: Checkpoint step passed to ``save_results``.

    Raises:
        ValueError: If ``mode`` is not one of the values above.
    """
    # Resolve the category split first. An unknown mode used to fall back
    # silently to the test split, which hid typos on the command line.
    if mode == "test":
        categories = TEST_CATEGORIES
    elif mode == "train1":
        categories = TRAINING_CATEGORIES[: len(TRAINING_CATEGORIES) // 2]
    elif mode == "train2":
        categories = TRAINING_CATEGORIES[len(TRAINING_CATEGORIES) // 2 :]
    else:
        raise ValueError(
            f"Unknown mode {mode!r}; expected one of 'test', 'train1', 'train2'"
        )

    job_params = {
        "output_dir": [eval_path],
        "checkpoint": [checkpoint],
        "num_images": [2, 3, 4, 5, 6, 7, 8],
        "sample_num": [0, 1, 2, 3, 4],
        "category": categories,
        "calculate_additional_timesteps": [True],
    }
    keys, values = zip(*job_params.items())
    job_configs = [dict(zip(keys, combo)) for combo in itertools.product(*values)]

    if not use_submitit:
        # Sequential fallback; much slower than the Slurm path.
        for job_config in tqdm(job_configs):
            save_results(**job_config)
        return

    executor = submitit.AutoExecutor(
        cluster=None, folder="./slurm_logs", slurm_max_num_timeout=10
    )
    # Cluster-specific Slurm settings; use your own parameters.
    executor.update_parameters(
        slurm_additional_parameters={
            "nodes": 1,
            "cpus-per-task": 5,
            "gpus": 1,
            "time": "6:00:00",
            "partition": "all",
            "exclude": "grogu-1-9, grogu-1-14,",
        }
    )

    submitted = []
    # executor.batch() submits all jobs at once when the context exits, so the
    # job-id-keyed dict is only built after the `with` block.
    with executor.batch():
        for params in job_configs:
            job = executor.submit(save_results, **params)
            job_param = f"{params['category']}_N{params['num_images']}_{params['sample_num']}"
            submitted.append((job_param, job))
    jobs = {f"{job_param}_{job.job_id}": job for job_param, job in submitted}
    submitit_job_watcher(jobs)
def _find_eval_file(eval_path, category, num_images, sample_num, checkpoint):
    """Return the path of the eval JSON for one (category, N, sample, checkpoint) run.

    Raises:
        FileNotFoundError: If no file matches, so a missing evaluation run is
            reported by its pattern instead of as a bare ``IndexError``.
    """
    pattern = os.path.join(
        eval_path,
        "eval",
        f"{category}_{num_images}_{sample_num}_ckpt{checkpoint}*.json",
    )
    # glob() returns matches in arbitrary order; sort so the choice is
    # deterministic if several files match the wildcard.
    matches = sorted(glob(pattern))
    if not matches:
        raise FileNotFoundError(f"No evaluation results match {pattern!r}")
    return matches[0]


def _category_metrics(data_path, pred_index, threshold_R, threshold_CC):
    """Reduce one eval JSON file to a single value per metric.

    Every value in the JSON object is indexed with ``pred_index``. The selected
    entry must provide ``R_error``, ``CC_error``, ``CD`` and ``CD_Object``.

    Returns:
        dict with keys:
        - ``R_error`` / ``CC_error``: mean over predictions of the fraction of
          errors strictly below ``threshold_R`` / ``threshold_CC``.
        - ``CD`` / ``CD_Object``: median over predictions, to limit the effect
          of outliers.
    """
    with open(data_path) as f:
        eval_data = json.load(f)

    acc_R, acc_CC, cd, cd_object = [], [], [], []
    for preds in eval_data.values():
        pred = preds[pred_index]
        acc_R.append(np.mean(np.array(pred["R_error"]) < threshold_R))
        acc_CC.append(np.mean(np.array(pred["CC_error"]) < threshold_CC))
        cd.append(pred["CD"])
        cd_object.append(pred["CD_Object"])

    return {
        "R_error": np.mean(acc_R),
        "CC_error": np.mean(acc_CC),
        "CD": np.median(cd),
        "CD_Object": np.median(cd_object),
    }


def process_predictions(
    eval_path,
    pred_index,
    checkpoint=800_000,
    threshold_R=15,
    threshold_CC=0.1,
    num_images_range=range(2, 9),
    num_samples=5,
):
    """Aggregate saved evaluation results and print a seen/unseen summary table.

    Results are read from ``<eval_path>/eval`` for every category in
    ``TRAINING_CATEGORIES`` ("Seen") and ``TEST_CATEGORIES`` ("Unseen"), for
    each number of images in ``num_images_range`` and each sample index in
    ``range(num_samples)``.

    Aggregation per metric:
    - ``R_error`` / ``CC_error``: threshold accuracy, averaged (mean) over
      predictions, then over categories, then over samples.
    - ``CD`` / ``CD_Object``: median over predictions, median over categories,
      then mean over samples.

    Args:
        eval_path: Root directory containing the ``eval`` results folder.
        pred_index: Index into each per-sequence prediction list. The original
            note says this should be 1 (corresponding to T=90).
        checkpoint: Checkpoint step encoded in the result filenames.
        threshold_R: Threshold applied to ``R_error`` values.
        threshold_CC: Threshold applied to ``CC_error`` values.
        num_images_range: Numbers of input images to report (table columns).
        num_samples: Number of sample indices averaged per configuration.

    Returns:
        ``{"Seen": {metric: [value per N]}, "Unseen": {...}}``.

    Raises:
        FileNotFoundError: If a required result file is missing.
    """
    metric_keys = ("R_error", "CC_error", "CD", "CD_Object")
    accuracy_keys = {"R_error", "CC_error"}
    row_names = {"R_error": "R", "CC_error": "CC", "CD": "CD", "CD_Object": "CD_Obj"}
    splits = {"Seen": TRAINING_CATEGORIES, "Unseen": TEST_CATEGORIES}
    num_images_list = list(num_images_range)

    results = {split: {key: [] for key in metric_keys} for split in splits}
    for num_images in num_images_list:
        for split, categories in splits.items():
            per_sample = {key: [] for key in metric_keys}
            for sample_num in range(num_samples):
                per_category = {key: [] for key in metric_keys}
                for category in tqdm(
                    categories, desc=f"{split}, sample {sample_num}, N={num_images}"
                ):
                    data_path = _find_eval_file(
                        eval_path, category, num_images, sample_num, checkpoint
                    )
                    # Each file is parsed once and yields all four metrics.
                    metrics = _category_metrics(
                        data_path, pred_index, threshold_R, threshold_CC
                    )
                    for key in metric_keys:
                        per_category[key].append(metrics[key])
                for key in metric_keys:
                    # Median across categories for Chamfer distances to
                    # filter outliers; mean for accuracies.
                    reduce = np.mean if key in accuracy_keys else np.median
                    per_sample[key].append(reduce(per_category[key]))
            for key in metric_keys:
                results[split][key].append(np.mean(per_sample[key]))

    # Pad row labels to a common width so the table columns line up.
    labels = {
        (split, key): f"{split} {row_names[key]}" for split in splits for key in metric_keys
    }
    width = max(map(len, ["N=", *labels.values()]))
    print(f"{'N=':<{width}}", " ".join(f"{n: 5}" for n in num_images_list))
    for split in splits:
        for key in metric_keys:
            row = " ".join(f"{x:0.3f}" for x in results[split][key])
            print(f"{labels[(split, key)]:<{width}}", row)
    return results
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run DiffusionSfM evaluation jobs and print aggregated metrics."
    )
    parser.add_argument(
        "--eval_path",
        type=str,
        default="output/multi_diffusionsfm_dense",
        help="Output directory for save_results; results are aggregated from <eval_path>/eval.",
    )
    parser.add_argument(
        "--use_submitit",
        action="store_true",
        help="Submit one Slurm job per configuration instead of running sequentially.",
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="test",
        # Reject typos up front instead of silently evaluating the test split.
        choices=["test", "train1", "train2"],
        help="Category split to evaluate.",
    )
    args = parser.parse_args()

    evaluate_diffusionsfm(args.eval_path, args.use_submitit, args.mode)
    # process_predictions reads results for both TRAINING_CATEGORIES and
    # TEST_CATEGORIES, so it expects results from all three modes to exist
    # under eval_path. NOTE(review): with --use_submitit this also assumes
    # submitit_job_watcher blocks until the jobs finish; confirm.
    process_predictions(args.eval_path, 1)