ReLaX-VQA / src /utils /util.py
Xinyi Wang
first commit
211b431
import numpy as np
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt
import seaborn as sns
def logistic_func(X, bayta1, bayta2, bayta3, bayta4):
    """Evaluate a four-parameter logistic curve at ``X``.

    Computes ``bayta2 + (bayta1 - bayta2) / (1 + exp(-(X - bayta3) / |bayta4|))``.

    Args:
        X: Point(s) at which the curve is evaluated.
        bayta1: Value the curve approaches as ``X`` grows large.
        bayta2: Value the curve approaches as ``X`` grows very negative.
        bayta3: Location of the curve's midpoint along ``X``.
        bayta4: Scale of the transition; only its absolute value is used.

    Returns:
        The curve value(s) at ``X``.
    """
    # Only the magnitude of the scale matters, so its sign never flips the curve.
    scaled_offset = (X - bayta3) / np.abs(bayta4)
    denominator = 1 + np.exp(-scaled_offset)
    return bayta2 + (bayta1 - bayta2) / denominator
def curve_bounds(x, params, sigma):
    """Evaluate the logistic curve with every parameter shifted by +/- 2 sigma.

    Args:
        x: Point(s) at which both curves are evaluated; passed straight
            through to ``logistic_func``.
        params: Indexable holding the four logistic parameters
            (``bayta1`` .. ``bayta4``) in positions 0-3.
        sigma: Indexable holding one deviation per parameter in positions
            0-3 (presumably ``sqrt(diag(pcov))`` from ``curve_fit`` --
            verify against the caller).

    Returns:
        Tuple ``(upper_bound, lower_bound)``: the curve evaluated with all
        parameters at ``params + 2 * sigma`` and at ``params - 2 * sigma``
        respectively.
    """
    raised = [params[i] + 2 * sigma[i] for i in range(4)]
    # All four parameters are lowered symmetrically, including the fourth
    # (scale) one; logistic_func takes its absolute value, so a negative
    # result here is still a valid scale.
    lowered = [params[i] - 2 * sigma[i] for i in range(4)]

    upper_bound = logistic_func(x, *raised)
    lower_bound = logistic_func(x, *lowered)
    return upper_bound, lower_bound
# plot one
def plot_results(y_test, y_test_pred_logistic, df_pred_score, network_name, model_name, data_name, layer_name, select_criteria):
    """Show a scatter plot of predicted score vs. MOS with a fitted logistic curve.

    A four-parameter logistic mapping from the predicted scores to the MOS
    values is fitted with ``scipy.optimize.curve_fit`` and drawn over a
    scatter plot of ``df_pred_score``. The figure is displayed with
    ``plt.show()`` and then closed; nothing is returned.

    Args:
        y_test: MOS values; used as the dependent data of the fit.
        y_test_pred_logistic: Predicted scores; used as the independent
            data of the fit and to span the x-range of the drawn curve.
        df_pred_score: Data for the scatter plot, read through its
            ``"y_test_pred_logistic"`` and ``"MOS"`` columns.
        network_name: Legend label of the scatter points; also in the title.
        model_name: Shown in the title.
        data_name: Shown in the title.
        layer_name: Not used by the active code (only by the commented-out
            save step below).
        select_criteria: Shown in the title.

    Raises:
        Exception: If preparing or running the logistic fit fails; the
            underlying error is attached as ``__cause__``.
    """
    # nonlinear logistic fitted curve / logistic regression
    mos = y_test
    y = y_test_pred_logistic
    try:
        beta = [np.max(mos), np.min(mos), np.mean(y), 0.5]
        popt, _ = curve_fit(logistic_func, y, mos, p0=beta, maxfev=100000000)
    except Exception as err:
        # Only ordinary errors are translated; KeyboardInterrupt/SystemExit
        # propagate untouched instead of being reported as a fit failure.
        raise Exception('Fitting logistic function time-out!!') from err

    x_values = np.linspace(np.min(y), np.max(y), len(y))

    plt.rcParams.update({'font.size': 24})
    plt.figure(figsize=(10, 8))
    plt.plot(x_values, logistic_func(x_values, *popt), '-', color='#c72e29', label='Fitted f(x)')
    # `marker` (singular) is the matplotlib marker keyword; `markers` is
    # seaborn's style-mapping argument and has no effect without `style`.
    fig1 = sns.scatterplot(x="y_test_pred_logistic", y="MOS", data=df_pred_score, marker='o', color='steelblue', label=network_name, s=100)

    # set the legend to a location outside the plot and specify the bbox_to_anchor
    plt.legend(loc='lower right', fontsize=24, bbox_to_anchor=(1.0, 0.0))
    plt.ylim(1, 5)
    plt.xlim(1, 5)
    title_name = f"Algorithm {network_name} with {model_name} on dataset {data_name}: {select_criteria}"
    plt.title(title_name, fontsize=20)
    plt.xlabel('Predicted Score', fontsize=24)
    plt.ylabel('MOS', fontsize=24)

    # Figure handle kept for the optional save step below.
    reg_fig1 = fig1.get_figure()
    # save the file
    # fig_path = f'../../figs/{data_name}/'
    # if not os.path.exists(fig_path):
    #     os.makedirs(fig_path)
    # fig_name = f"{network_name}_{layer_name}_{model_name}_{data_name}_by{select_criteria}.png"
    # reg_fig1.savefig(f'{fig_path}{fig_name}', dpi=300, bbox_inches='tight')

    plt.show()
    plt.clf()
    plt.close()
# plot comparison
def plot_comparison(df1, df2, network_name, model_name, data_name, layer_name, compare1, compare2):
    """Show two sets of predicted-score vs. MOS points on one scatter plot.

    Both data sets are drawn on the same axes with different markers and
    colours from seaborn's ``'colorblind'`` palette. The figure is displayed
    with ``plt.show()`` and then closed; nothing is returned.

    Args:
        df1: First data set, read through its ``"y_test_pred_logistic"``
            and ``"MOS"`` columns; drawn with ``'o'`` markers.
        df2: Second data set with the same columns; drawn with ``'x'``
            markers.
        network_name: Shown in the title.
        model_name: Shown in the title.
        data_name: Shown in the title.
        layer_name: Not used by the active code (only by the commented-out
            save step below).
        compare1: Legend label for ``df1``.
        compare2: Legend label for ``df2``.
    """
    # Set the global font size before the axes are created so that it is
    # already in effect for everything built below, on the first call too.
    plt.rcParams.update({'font.size': 24})
    fig, ax = plt.subplots(figsize=(10, 8))

    # seaborn only applies `palette` together with a `hue` variable, so the
    # two colours are taken from the palette and passed explicitly.
    first_color, second_color = sns.color_palette('colorblind', 2)
    sns.scatterplot(x="y_test_pred_logistic", y="MOS", data=df1, ax=ax, color=first_color, marker='o', s=100, label=compare1)
    sns.scatterplot(x="y_test_pred_logistic", y="MOS", data=df2, ax=ax, color=second_color, marker='x', s=100, label=compare2)

    # adjust marker edge width for each scatter plot
    for scatter_plot in ax.collections:
        scatter_plot.set_linewidth(1.5)

    plt.legend(loc='lower right', title="Comparison", fontsize=24, bbox_to_anchor=(1.0, 0.0))
    plt.ylim(1, 5)
    plt.xlim(1, 5)
    title_name = f"Algorithm {network_name} with {model_name} on dataset {data_name}"
    plt.title(title_name, fontsize=24)
    plt.xlabel('Predicted Score', fontsize=24)
    plt.ylabel('MOS', fontsize=24)

    # Figure handle kept for the optional save step below.
    reg_fig = ax.get_figure()
    # save the file
    # fig_path = f'../../figs/{data_name}/'
    # fig_name = f"{network_name}_{layer_name}_{model_name}_{data_name}.png"
    # reg_fig.savefig(f'{fig_path}{fig_name}', dpi=300, bbox_inches='tight')

    plt.show()
    plt.clf()
    plt.close()