qitaoz's picture
Upload 57 files
4562a06 verified
"""
python -m diffusionsfm.eval.eval_jobs --eval_path output/multi_diffusionsfm_dense --use_submitit
"""
import os
import json
import submitit
import argparse
import itertools
from glob import glob
import numpy as np
from tqdm.auto import tqdm
from diffusionsfm.dataset.co3d_v2 import TEST_CATEGORIES, TRAINING_CATEGORIES
from diffusionsfm.eval.eval_category import save_results
from diffusionsfm.utils.slurm import submitit_job_watcher
def evaluate_diffusionsfm(eval_path, use_submitit, mode):
    """Run ``save_results`` for every evaluation job configuration.

    One job is created per element of the Cartesian product of checkpoint,
    number of images (2-8), sample index (0-4) and category.

    Args:
        eval_path: Directory passed to ``save_results`` as ``output_dir``.
        use_submitit: If True, submit all jobs as a single submitit batch and
            wait on them with ``submitit_job_watcher``; otherwise run them
            sequentially in this process (much slower).
        mode: Which categories to evaluate: ``"test"`` (TEST_CATEGORIES),
            ``"train1"`` (first half of TRAINING_CATEGORIES) or ``"train2"``
            (second half of TRAINING_CATEGORIES).

    Raises:
        ValueError: If ``mode`` is not one of the values listed above.
    """
    valid_modes = ("test", "train1", "train2")
    if mode not in valid_modes:
        # Previously an unknown mode silently fell back to TEST_CATEGORIES.
        raise ValueError(f"mode must be one of {valid_modes}, got {mode!r}")

    if mode == "test":
        categories = TEST_CATEGORIES
    else:
        half = len(TRAINING_CATEGORIES) // 2
        categories = (
            TRAINING_CATEGORIES[:half] if mode == "train1" else TRAINING_CATEGORIES[half:]
        )

    job_params = {
        "output_dir": [eval_path],
        "checkpoint": [800_000],
        "num_images": [2, 3, 4, 5, 6, 7, 8],
        "sample_num": [0, 1, 2, 3, 4],
        "category": categories,
        "calculate_additional_timesteps": [True],
    }
    keys, values = zip(*job_params.items())
    job_configs = [dict(zip(keys, combo)) for combo in itertools.product(*values)]

    if not use_submitit:
        # Sequential fallback: much slower than submitting to the cluster.
        for job_config in tqdm(job_configs):
            save_results(**job_config)
        return

    executor = submitit.AutoExecutor(
        cluster=None, folder="./slurm_logs", slurm_max_num_timeout=10
    )
    # Cluster-specific SLURM settings: use your own parameters.
    # NOTE(review): the "exclude" value contains a space and a trailing comma;
    # confirm SLURM parses this node list as intended on your cluster.
    executor.update_parameters(
        slurm_additional_parameters={
            "nodes": 1,
            "cpus-per-task": 5,
            "gpus": 1,
            "time": "6:00:00",
            "partition": "all",
            "exclude": "grogu-1-9, grogu-1-14,",
        }
    )

    submitted = []
    with executor.batch():
        # This context manager submits all jobs at once when it exits.
        for params in job_configs:
            job = executor.submit(save_results, **params)
            label = f"{params['category']}_N{params['num_images']}_{params['sample_num']}"
            submitted.append((label, job))

    # job_id is read only after the batch context has exited (i.e. after submission).
    jobs = {f"{label}_{job.job_id}": job for label, job in submitted}
    submitit_job_watcher(jobs)
def process_predictions(
    eval_path,
    pred_index,
    checkpoint=800_000,
    threshold_R=15,
    threshold_CC=0.1,
    num_images_list=range(2, 9),
    num_samples=5,
):
    """Aggregate saved evaluation JSON files and print a summary table.

    For every number of images N in ``num_images_list`` and for both seen
    (TRAINING_CATEGORIES) and unseen (TEST_CATEGORIES) categories, four
    metrics are reported:

    - ``R_error``: fraction of values below ``threshold_R``.
    - ``CC_error``: fraction of values below ``threshold_CC``.
    - ``CD`` and ``CD_Object``: median values (medians reduce outlier effects).

    Aggregation order:
    1. Per file, over all predictions (mean for the thresholded metrics,
       median for CD / CD_Object).
    2. Across categories (same statistic).
    3. Mean over sample indices ``0 .. num_samples - 1``.

    Args:
        eval_path: Directory whose ``eval`` subfolder holds files named
            ``{category}_{N}_{sample_num}_ckpt{checkpoint}*.json``.
        pred_index: Index into each per-item prediction list. It should be 1
            (corresponding to T=90).
        checkpoint: Checkpoint number embedded in the file names.
        threshold_R: Threshold used to turn R_error values into an accuracy.
        threshold_CC: Threshold used to turn CC_error values into an accuracy.
        num_images_list: Values of N to report (default 2..8).
        num_samples: Number of sample indices averaged per N (default 5).

    Raises:
        FileNotFoundError: If no evaluation file matches a required
            (category, N, sample) combination.
    """
    metric_keys = ("R_error", "CC_error", "CD", "CD_Object")
    # Metrics scored as "fraction below threshold"; all others use the median.
    thresholds = {"R_error": threshold_R, "CC_error": threshold_CC}
    short_names = {"R_error": "R", "CC_error": "CC", "CD": "CD", "CD_Object": "CD_Obj"}

    def reduce(metric_key, values):
        """Mean for thresholded accuracies, median for distance metrics."""
        return np.mean(values) if metric_key in thresholds else np.median(values)

    def find_eval_file(category, num_images, sample_num):
        """Return the (deterministically chosen) JSON file for one combination."""
        pattern = os.path.join(
            eval_path, "eval", f"{category}_{num_images}_{sample_num}_ckpt{checkpoint}*.json"
        )
        # glob order is filesystem-dependent; sort so the choice is reproducible.
        matches = sorted(glob(pattern))
        if not matches:
            raise FileNotFoundError(
                f"No evaluation file matches {pattern!r}. Seen and unseen categories "
                "are both aggregated, so every evaluation mode must have been run."
            )
        return matches[0]

    def summarize_file(path):
        """Read one JSON file once and reduce every metric over its predictions."""
        with open(path) as f:
            eval_data = json.load(f)
        per_pred = {key: [] for key in metric_keys}
        for preds in eval_data.values():
            pred = preds[pred_index]
            for key in metric_keys:
                if key in thresholds:
                    per_pred[key].append(np.mean(np.array(pred[key]) < thresholds[key]))
                else:
                    per_pred[key].append(pred[key])
        return {key: reduce(key, vals) for key, vals in per_pred.items()}

    def summarize_group(categories, num_images):
        """Return {metric: value} for one category group at one N."""
        per_sample = {key: [] for key in metric_keys}
        for sample_num in range(num_samples):
            per_category = {key: [] for key in metric_keys}
            for category in tqdm(categories, desc=f"Sample {sample_num}, N={num_images}"):
                path = find_eval_file(category, num_images, sample_num)
                for key, value in summarize_file(path).items():
                    per_category[key].append(value)
            for key, vals in per_category.items():
                per_sample[key].append(reduce(key, vals))
        return {key: np.mean(vals) for key, vals in per_sample.items()}

    groups = (("Seen", TRAINING_CATEGORIES), ("Unseen", TEST_CATEGORIES))
    num_images_list = list(num_images_list)
    results = {name: {key: [] for key in metric_keys} for name, _ in groups}
    for num_images in num_images_list:
        for name, categories in groups:
            summary = summarize_group(categories, num_images)
            for key in metric_keys:
                results[name][key].append(summary[key])

    # Print the results in formatted rows; pad labels so the columns line up.
    labels = [f"{name} {short_names[key]}" for name, _ in groups for key in metric_keys]
    width = max(len(label) for label in labels)
    print(f"{'N=':<{width}}", " ".join(f"{n: 5}" for n in num_images_list))
    for name, _ in groups:
        for key in metric_keys:
            label = f"{name} {short_names[key]}"
            print(f"{label:<{width}}", " ".join(f"{x:0.3f}" for x in results[name][key]))
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run DiffusionSfM evaluation jobs, then print aggregated metrics."
    )
    parser.add_argument(
        "--eval_path",
        type=str,
        default="output/multi_diffusionsfm_dense",
        help="Directory passed to save_results as output_dir and read back during aggregation.",
    )
    parser.add_argument(
        "--use_submitit",
        action="store_true",
        help="Submit jobs through submitit/SLURM instead of running them sequentially.",
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="test",
        choices=("test", "train1", "train2"),
        help="Category split to evaluate: test categories, or one half of the training categories.",
    )
    args = parser.parse_args()

    evaluate_diffusionsfm(args.eval_path, args.use_submitit, args.mode)
    # NOTE: aggregation covers both seen (training) and unseen (test) categories,
    # so it needs results from all three modes to be present under eval_path.
    # pred_index=1 corresponds to T=90.
    process_predictions(args.eval_path, 1)