Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add Descriptive Statistics and Charts #37

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 106 additions & 8 deletions nisqa/NISQA_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import yaml
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torch import optim
from torch.utils.data import DataLoader
from . import NISQA_lib as NL
Expand Down Expand Up @@ -50,7 +51,81 @@ def evaluate(self, mapping='first_order', do_print=True, do_plot=False):
self._evaluate_dim(mapping=mapping, do_print=do_print, do_plot=do_plot)
else:
self._evaluate_mos(mapping=mapping, do_print=do_print, do_plot=do_plot)


def _draw_barchart(self):
# settings of barchart
bar_width = 0.15
index = np.arange(len(self.ds_val.df['deg']))
colors = ['r', 'g', 'b', 'c', 'm']
fig, ax = plt.subplots(figsize=(12, 8))
for i, col in enumerate(['mos_pred', 'noi_pred', 'dis_pred', 'col_pred', 'loud_pred']):
ax.bar(index + (i * bar_width), self.ds_val.df[col], bar_width, color=colors[i], label=col)
for j, val in enumerate(self.ds_val.df[col]):
ax.annotate(str(f'{val:.2f}'), xy=(index[j] + (i * bar_width), val), xytext=(0, 3),
textcoords="offset points", ha='center', va='bottom')
# set labels and title
ax.set_xlabel('WavName')
ax.set_ylabel('Scores')
title = 'BarChart for '+str(self.args['deg']) if self.args['mode'] == 'predict_file' else 'Barchart for Wavs under: '+str(self.args['data_dir'])
ax.set_title(title)
ax.set_xticks(index + (2 * bar_width))
ax.set_xticklabels(self.ds_val.df['deg'], rotation=90)
ax.set_ylim(0, 5)
# add legend
ax.legend()
# save plot
if 'plot_name' in self.args and self.args['plot_name'] != 'None':
save_path = self.args['output_dir'] if self.args['output_dir'][-1] == '/' else self.args['output_dir']+'/'
plt.savefig(save_path+'BarChart_' + self.args['plot_name'])
# display the plot
plt.tight_layout()
plt.show()


def _draw_linechart(self):
# settings of line chart
fig, ax = plt.subplots(figsize=(12, 8))
index = np.arange(len(self.ds_val.df['deg']))
colors = ['r', 'g', 'b', 'c', 'm']
for i, col in enumerate(['mos_pred', 'noi_pred', 'dis_pred', 'col_pred', 'loud_pred']):
ax.plot(index, self.ds_val.df[col], marker='o', linestyle='-', color=colors[i], label=col)
for x, y in zip(index, self.ds_val.df[col]):
ax.text(x, y, f'{y:.2f}', ha='left', va='top')
# set labels and title
ax.set_xlabel('WavName')
ax.set_ylabel('Scores')
title = 'LineChart for ' + str(self.args['deg']) if self.args['mode'] == 'predict_file' else 'Line Chart for Wavs under: ' + str(self.args['data_dir'])
ax.set_title(title)
ax.set_xticks(index)
ax.set_xticklabels(self.ds_val.df['deg'], rotation=90)
ax.set_ylim(0, 5)
# add legend
ax.legend()
# save plot
if 'plot_name' in self.args and self.args['plot_name'] != 'None':
save_path = self.args['output_dir'] if self.args['output_dir'][-1] == '/' else self.args['output_dir'] + '/'
plt.savefig(save_path + 'LineChart_' + self.args['plot_name'])
# display the plot
plt.tight_layout()
plt.show()

# calculate mean and standard deviation of the results and concatenate with the initial results.
def _compute_mean_stdDev(self):
mean = self.ds_val.df.drop(['deg', 'model'], axis=1).mean()
std = self.ds_val.df.drop(['deg', 'model'], axis=1).std()
# alter NAN to 0
std.fillna(0, inplace=True)

stat = pd.DataFrame({'deg': ['**Mean**', '**Standard Deviation**'],
'mos_pred': [mean['mos_pred'], std['mos_pred']],
'noi_pred': [mean['noi_pred'], std['noi_pred']],
'dis_pred': [mean['dis_pred'], std['dis_pred']],
'col_pred': [mean['col_pred'], std['col_pred']],
'loud_pred': [mean['loud_pred'], std['loud_pred']],
'model': self.ds_val.df['model'][0]})
stat = pd.concat([self.ds_val.df, stat])
return stat

def predict(self):
print('---> Predicting ...')
if self.args['tr_parallel']:
Expand All @@ -69,15 +144,38 @@ def predict(self):
self.ds_val,
self.args['tr_bs_val'],
self.dev,
num_workers=self.args['tr_num_workers'])
num_workers=self.args['tr_num_workers'])

if self.args['output_dir']:
self.ds_val.df['model'] = self.args['name']
self.ds_val.df.to_csv(
os.path.join(self.args['output_dir'], 'NISQA_results.csv'),
index=False)

print(self.ds_val.df.to_string(index=False))
csv_name = 'NISQA_results.csv'
if 'output_name' in self.args and self.args['output_name']:
csv_name = str(self.args['output_name']) + '.csv'
# whether print mean and standard deviation or not
if 'compute_stats' in self.args and self.args['compute_stats']:
# generate results with mean and deviation to self.statistics.
self.statistics = self._compute_mean_stdDev()
self.statistics.to_csv(
os.path.join(self.args['output_dir'], csv_name),
index=False)
else:
self.ds_val.df.to_csv(
os.path.join(self.args['output_dir'], csv_name),
index=False)

# print either statistics or ds_val based on 'compute_stats' parameter
if 'compute_stats' in self.args and self.args['compute_stats']:
print(self.statistics.to_string(index=False))
else:
print(self.ds_val.df.to_string(index=False))

# Visualization of the results
if 'plot_type' in self.args and self.args['plot_type'] == 'barchart':
self._draw_barchart()
elif 'plot_type' in self.args and self.args['plot_type'] == 'linechart':
self._draw_linechart()

# returned DataFrame is not changed.
return self.ds_val.df

def _train_mos(self):
Expand Down
17 changes: 17 additions & 0 deletions run_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,17 @@
parser.add_argument('--num_workers', type=int, default=0, help='number of workers for pytorchs dataloader')
parser.add_argument('--bs', type=int, default=1, help='batch size for predicting')
parser.add_argument('--ms_channel', type=int, help='audio channel in case of stereo file')
### Add
# 1) the name of result csv,
# 2) statistics(mean and standard deviation),
# 3) type of visualizations
# 4) whether to save the plot
### to the parameters
parser.add_argument('--output_name', type=str, help='name of the csv result file')
parser.add_argument('--compute_stats', action='store_true', help='whether to calculate the mean and the standard deviation of the results')
parser.add_argument('--plot_type', type=str, default='None', help='Visualization of the results. Either barchart, linechart or None')
parser.add_argument('--plot_name', type=str, default='None', help='name of the plot file if saving plot is needed')
###

args = parser.parse_args()
args = vars(args)
Expand All @@ -35,6 +46,12 @@
args['data_dir'] = ''
else:
raise NotImplementedError('--mode given not available')

# `plot_name` can only be set when `plot_type` is not `None``
if 'plot_name' in args and args['plot_name'] != 'None':
if 'plot_type' in args and args['plot_type'] == 'None':
raise ValueError('--plot_name argument can only be set when `plot_type` is either `barchart` or `linechart`')

args['tr_bs_val'] = args['bs']
args['tr_num_workers'] = args['num_workers']

Expand Down