Spaces:
Runtime error
Runtime error
"""
This script plots the results of the human evaluation on Amazon Mechanical Turk, in
which human participants chose between an image from ClimateGAN and an image from a
different method.

For each competing method, it draws the rate at which the ClimateGAN image was chosen,
with a bootstrapped confidence interval, and saves the figure to the output directory.
"""
# Imports can be slow: flush so the progress message shows up before they run
print("Imports...", end="", flush=True)
from argparse import ArgumentParser
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import yaml

# Terminate the "Imports..." line so later output starts on a fresh line
print("Done.")
# -----------------------
# -----  Constants  -----
# -----------------------

# Maps the raw method names found in the CSV's `comparable` column to the names
# displayed as Y-tick labels on the plot
comparables_dict = {
    "munit_flooded": "MUNIT",
    "cyclegan": "CycleGAN",
    "instagan": "InstaGAN",
    "instagan_copypaste": "Mask-InstaGAN",
    "painted_ground": "Painted ground",
}

# Colors (picked from seaborn's colorblind-friendly palette, computed once)
palette_colorblind = sns.color_palette("colorblind")
color_climategan = palette_colorblind[9]
color_munit = palette_colorblind[1]
color_cyclegan = palette_colorblind[2]
color_instagan = palette_colorblind[3]
color_maskinstagan = palette_colorblind[6]
color_paintedground = palette_colorblind[8]

# One color per comparable, listed in the same order as the keys of
# `comparables_dict`
palette_comparables = [
    color_munit,
    color_cyclegan,
    color_instagan,
    color_maskinstagan,
    color_paintedground,
]
# Lighter shade of each comparable color: the middle entry of a 3-step light
# palette built from that color
palette_comparables_light = [
    sns.light_palette(color, n_colors=3)[1] for color in palette_comparables
]
def parsed_args(argv=None):
    """
    Parse and return command-line args

    Args:
        argv (list[str] | None): arguments to parse. When None (the default),
            argparse reads them from ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: the parsed arguments
    """
    parser = ArgumentParser()
    parser.add_argument(
        "--input_csv",
        default="amt_omni-vs-other.csv",
        type=str,
        help="CSV containing the results of the human evaluation, pre-processed",
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        help="Output directory",
    )
    parser.add_argument(
        "--dpi",
        default=200,
        type=int,
        help="DPI for the output images",
    )
    parser.add_argument(
        "--n_bs",
        # argparse only applies `type` to *string* defaults: a float default
        # such as 1e6 would reach the caller as a float instead of an int.
        default=1_000_000,
        type=int,
        help="Number of bootstrap samples",
    )
    parser.add_argument(
        "--bs_seed",
        default=17,
        type=int,
        help="Bootstrap random seed, for reproducibility",
    )
    return parser.parse_args(argv)
def get_output_dir(output_dir=None):
    """
    Resolve the directory where the outputs are written (it is not created here).

    Args:
        output_dir (str | None): user-provided directory. When None, fall back
            to ``$SLURM_TMPDIR`` and, if that variable is not set, to the
            current working directory.

    Returns:
        pathlib.Path: the output directory
    """
    if output_dir is not None:
        return Path(output_dir)
    # $SLURM_TMPDIR only exists inside a SLURM job: don't crash with a bare
    # KeyError when the script is run anywhere else.
    return Path(os.environ.get("SLURM_TMPDIR", "."))


def sort_comparables(df):
    """
    Order the comparable methods by decreasing ClimateGAN preference rate.

    The rate is the mean of the ``climategan`` column over the valid rows only,
    i.e. the same quantity the bar plot draws, so bars end up sorted by length.

    Args:
        df (pandas.DataFrame): evaluation results with at least the columns
            ``comparable``, ``climategan`` and a boolean ``is_valid``.

    Returns:
        numpy.ndarray: comparable names, highest ClimateGAN rate first.
    """
    valid = df.loc[df.is_valid]
    rates = valid.groupby("comparable")["climategan"].mean()
    return rates.sort_values(ascending=False, kind="mergesort").index.to_numpy()


if __name__ == "__main__":
    # -----------------------------
    # -----  Parse arguments  -----
    # -----------------------------
    args = parsed_args()
    print("Args:\n" + "\n".join([f" {k:20}: {v}" for k, v in vars(args).items()]))

    # Determine output dir and make sure it exists
    output_dir = get_output_dir(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Store args
    output_yml = output_dir / "args_human_evaluation.yml"
    with open(output_yml, "w") as f:
        yaml.dump(vars(args), f)

    # Read CSV
    df = pd.read_csv(args.input_csv)

    # Sort Y labels by the plotted rate (mean over valid rows), not by a raw
    # count that also includes invalid rows
    comparables = sort_comparables(df)

    # Tie each comparable to its own color. A plain list palette is consumed in
    # `order` order, so colors would follow each method's rank instead of the
    # method itself. `palette_comparables_light` lists the colors in the same
    # order as the keys of `comparables_dict`.
    palette_light = dict(zip(comparables_dict, palette_comparables_light))

    # Plot setup
    sns.set(style="whitegrid")
    plt.rcParams.update({"font.family": "serif"})
    plt.rcParams.update(
        {
            "font.serif": [
                "Computer Modern Roman",
                "Times New Roman",
                "Utopia",
                "New Century Schoolbook",
                "Century Schoolbook L",
                "ITC Bookman",
                "Bookman",
                "Times",
                "Palatino",
                "Charter",
                # Comma was missing here: implicit string concatenation turned
                # these two entries into the bogus "serifBitstream Vera Serif"
                "serif",
                "Bitstream Vera Serif",
                "DejaVu Serif",
            ]
        }
    )
    fontsize = "medium"

    # Initialize the matplotlib figure
    fig, ax = plt.subplots(figsize=(10.5, 3), dpi=args.dpi)

    # Plot the total (right): mean of `is_valid` over valid rows, i.e. full bars
    sns.barplot(
        data=df.loc[df.is_valid],
        x="is_valid",
        y="comparable",
        order=comparables,
        orient="h",
        label="comparable",
        palette=palette_light,
        ci=None,
    )

    # Plot the left: ClimateGAN rate with a 99% bootstrapped confidence interval
    sns.barplot(
        data=df.loc[df.is_valid],
        x="climategan",
        y="comparable",
        order=comparables,
        orient="h",
        label="climategan",
        color=color_climategan,
        ci=99,
        n_boot=args.n_bs,
        seed=args.bs_seed,
        errcolor="black",
        errwidth=1.5,
        capsize=0.1,
    )

    # Draw a dotted vertical line at 0.5 (no preference between the two images).
    # The y axis of a horizontal bar plot is inverted, so ylim is (bottom, top);
    # the line keeps a 0.1 margin at both ends.
    y_bottom, y_top = ax.get_ylim()
    ax.plot(
        [0.5, 0.5],
        [y_top + 0.1, y_bottom - 0.1],
        linestyle=":",
        linewidth=1.5,
        color="black",
    )

    # Change Y-Tick labels from raw method names to display names
    yticklabels = [comparables_dict[ytick.get_text()] for ytick in ax.get_yticklabels()]
    yticklabels_text = ax.set_yticklabels(
        yticklabels, fontsize=fontsize, horizontalalignment="right", x=0.96
    )
    # NOTE(review): re-adds the tick-label Text objects as axes artists; the
    # intent is not visible from here, kept as-is.
    for ytl in yticklabels_text:
        ax.add_artist(ytl)

    # Remove Y-label
    ax.set_ylabel(ylabel="")

    # Change X-Tick labels: 0.0, 0.1, ..., 1.0 (linspace avoids the float-step
    # ambiguity of np.arange about whether the end point is included)
    xticks = np.linspace(0.0, 1.0, 11)
    ax.set(xticks=xticks)
    plt.setp(ax.get_xticklabels(), fontsize=fontsize)

    # Remove X-label
    ax.set_xlabel("")

    # Change spines
    sns.despine(left=True, bottom=True)

    # Save figure
    output_fig = output_dir / "human_evaluation_rate_climategan.png"
    fig.savefig(output_fig, dpi=fig.dpi, bbox_inches="tight")