|
import argparse
import os
import pathlib
import re
import shutil

import imageio
import matplotlib.pyplot as plt
import mlflow
import numpy as np
from mlflow.entities import ViewType
from mlflow.tracking import MlflowClient
|
|
|
|
|
def _str_to_bool(value):
    """Parse a command-line boolean such as "true"/"false", "yes"/"no" or "1"/"0".

    ``type=bool`` does not work with argparse: it calls ``bool(value)`` and every
    non-empty string (including "False") is truthy, so a flag could never be
    switched off from the command line.
    """
    normalized = value.strip().lower()
    if normalized in {"true", "t", "yes", "y", "1"}:
        return True
    if normalized in {"false", "f", "no", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean (true/false, yes/no, 1/0), got {!r}".format(value)
    )


parser = argparse.ArgumentParser(description="results_analysis")

parser.add_argument("--tracking_uri", type=str,
                    default="http://deplo-mlflo-1ssxo94f973sj-890390d809901dbf.elb.eu-central-1.amazonaws.com",
                    help='URI of the mlflow server on AWS')
# Experiment and run are needed by every output mode; without them the
# server lookup fails later with an opaque error, so make argparse enforce them.
parser.add_argument("--experiment_name", type=str, required=True,
                    help='Name of the experiment on the mlflow server, e.g. "processing_comparison"')
parser.add_argument("--run_name", type=str, required=True,
                    help='Name of the run on the mlflow server, e.g. "proc_nn"')
parser.add_argument("--representation", type=str, default=None,
                    choices=["processing", "gradients"],
                    help='The representation form you want retrieve("processing" or "gradients")')
parser.add_argument("--step", type=str, default=None,
                    choices=["pre_debayer", "demosaic", "color_correct", "sharpening", "gaussian",
                             "clipped", "gamma_correct", "rgb"],
                    help='The processing step you want to track, e.g. "pre_debayer" or "rgb"')
parser.add_argument("--gif_name", type=str, default=None,
                    help='Name of the gif that will be saved. Note: .gif will be added later by script')
parser.add_argument("--local_dir", type=str, default=None,
                    help='Name of the local dir to be created to store mlflow data')
parser.add_argument("--cleanup", type=_str_to_bool, default=True,
                    help='Whether to delete the local dir again after the script was run (true/false)')
parser.add_argument("--output", type=str, default=None,
                    choices=["gif", "train_vs_val_loss"],
                    help='Which output to generate')

args = parser.parse_args()

# The gif output reads these three options; report all missing ones up front
# instead of failing later on `None + ".gif"` / `None in name`.
if args.output == "gif":
    _missing_gif_flags = [
        flag
        for flag, value in (
            ("--gif_name", args.gif_name),
            ("--representation", args.representation),
            ("--step", args.step),
        )
        if value is None
    ]
    if _missing_gif_flags:
        parser.error("--output gif also requires " + ", ".join(_missing_gif_flags))
|
|
|
|
|
# Point the fluent mlflow API at the chosen server. MlflowClient() below is
# created without a tracking_uri, so it also uses the URI set here.
mlflow.set_tracking_uri(args.tracking_uri)

# get_experiment_by_name returns None when no experiment has that name; fail
# with a clear message instead of an AttributeError on `.experiment_id`.
experiment = mlflow.get_experiment_by_name(args.experiment_name)
if experiment is None:
    raise SystemExit(
        "No experiment named {!r} found at {}".format(args.experiment_name, args.tracking_uri)
    )

# Select runs whose mlflow.runName tag equals --run_name.
# NOTE(review): the name is interpolated unescaped, so a run name containing a
# single quote will break the filter expression.
filter_string = "tags.mlflow.runName = '{}'".format(args.run_name)

# search_runs takes a list of experiment ids and (by default) returns a pandas
# DataFrame with one row per run; later code uses the first row's run_id.
# NOTE(review): MLflow orders by start_time descending by default, so with
# duplicate run names this is presumably the most recent run - confirm.
runs = mlflow.search_runs([experiment.experiment_id], filter_string=filter_string)
if runs.empty:
    raise SystemExit(
        "No run named {!r} found in experiment {!r}".format(args.run_name, args.experiment_name)
    )

client = MlflowClient()
|
|
|
def _natural_sort_key(path):
    """Return a case-insensitive sort key that compares digit runs as integers.

    A plain string sort puts "epoch_10" before "epoch_2"; comparing the digit
    runs numerically keeps GIF frames in numeric order.
    """
    # re.split with a capturing group alternates non-digit text (even indices)
    # and digit runs (odd indices), so equal positions always hold equal types.
    parts = re.split(r"(\d+)", path.lower())
    return [int(part) if index % 2 else part for index, part in enumerate(parts)]


def _collect_gif_frames(root, representation, step):
    """Return paths of PNG files below *root* whose names contain *representation* and *step*.

    Files directly inside *root* are skipped (as before); every sub-directory -
    presumably one per epoch, confirm against the logging code - contributes
    its own matching files. Directories and file names are visited in natural
    sort order so the frame order is deterministic.
    """
    per_dir = [(dirpath, filenames) for dirpath, _, filenames in os.walk(root) if dirpath != root]
    frames = []
    for dirpath, filenames in sorted(per_dir, key=lambda item: _natural_sort_key(item[0])):
        for name in sorted(filenames, key=_natural_sort_key):
            if representation in name and step in name and name.lower().endswith(".png"):
                # Join with the directory the file actually lives in.
                frames.append(os.path.join(dirpath, name))
    return frames


if args.output == "gif":
    # These options default to None; check here so the branch fails with a
    # readable message rather than a TypeError further down.
    _missing = [
        flag
        for flag, value in (
            ("--gif_name", args.gif_name),
            ("--representation", args.representation),
            ("--step", args.step),
        )
        if value is None
    ]
    if _missing:
        raise SystemExit("--output gif requires " + ", ".join(_missing))

    base_dir = args.local_dir if args.local_dir else str(pathlib.Path().resolve())
    local_dir = os.path.join(base_dir, "artifacts")
    # Create the directory that is actually used (a missing --local_dir parent
    # included); an already existing directory is reused.
    os.makedirs(local_dir, exist_ok=True)

    try:
        results_dir = client.download_artifacts(runs["run_id"].iloc[0], "results", local_dir)
        frames = _collect_gif_frames(results_dir, args.representation, args.step)
        if not frames:
            raise SystemExit(
                "No PNG files matching representation={!r} and step={!r} under {}".format(
                    args.representation, args.step, results_dir
                )
            )
        with imageio.get_writer(args.gif_name + ".gif", mode="I") as writer:
            for frame_path in frames:
                writer.append_data(imageio.imread(frame_path))
    finally:
        # Runs even if downloading or writing fails. Note this removes the whole
        # artifacts folder, including anything that was already in it.
        if args.cleanup:
            shutil.rmtree(local_dir)

elif args.output == "train_vs_val_loss":
    run_id = runs["run_id"].iloc[0]
    # Key each history by step so train/val points are paired by the step they
    # were logged at; zipping two sorted lists mis-pairs them as soon as one
    # metric has steps the other lacks. If a step appears more than once, the
    # last entry in the returned history wins.
    train_by_step = {m.step: m.value for m in client.get_metric_history(run_id, "train_loss")}
    val_by_step = {m.step: m.value for m in client.get_metric_history(run_id, "val_loss")}

    plt.figure()
    for step in sorted(train_by_step.keys() & val_by_step.keys()):
        # Earlier steps are drawn more opaque: alpha = 1 / (step + 1).
        plt.scatter(train_by_step[step], val_by_step[step],
                    alpha=1 / (max(step, 0) + 1), color="blue")
    plt.xlabel("train_loss")
    plt.ylabel("val_loss")
    plt.savefig("scatter.png")
    plt.close()

else:
    raise SystemExit("No --output selected; choose 'gif' or 'train_vs_val_loss'.")
|
|