Skip to content

Commit

Permalink
feat: Add descriptive statistics and charts
Browse files Browse the repository at this point in the history
Added descriptive tools for Speech Quality Prediction:
- Mean and Standard Deviation of the results
- BarChart and LineChart of the results
These tools make Speech Quality Prediction results easier to interpret, especially when the results of
repeated measurements of the same test sample are averaged in practical audio testing scenarios.
Also, the name of the output CSV file can now be set, just like in gabrielmittag#30.
  • Loading branch information
Hadley-Zhang committed Sep 8, 2023
1 parent ac83137 commit 21b7c4f
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 8 deletions.
114 changes: 106 additions & 8 deletions nisqa/NISQA_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import yaml
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torch import optim
from torch.utils.data import DataLoader
from . import NISQA_lib as NL
Expand Down Expand Up @@ -50,7 +51,81 @@ def evaluate(self, mapping='first_order', do_print=True, do_plot=False):
self._evaluate_dim(mapping=mapping, do_print=do_print, do_plot=do_plot)
else:
self._evaluate_mos(mapping=mapping, do_print=do_print, do_plot=do_plot)


def _draw_barchart(self):
    """Draw a grouped bar chart of the predicted scores, one group per file.

    Reads ``self.ds_val.df`` (needs a ``deg`` column plus any of the
    ``*_pred`` score columns) and ``self.args`` (``mode``, ``deg`` or
    ``data_dir`` for the title, and optionally ``plot_name`` /
    ``output_dir`` for saving). The figure is saved when ``plot_name`` is
    set to something other than ``'None'`` and is then displayed.
    """
    df = self.ds_val.df
    # Fixed column -> colour pairing so a column keeps its colour even when
    # some score columns are absent from the results.
    palette = (('mos_pred', 'r'), ('noi_pred', 'g'), ('dis_pred', 'b'),
               ('col_pred', 'c'), ('loud_pred', 'm'))
    series = [(col, color) for col, color in palette if col in df.columns]

    # settings of barchart
    bar_width = 0.15
    index = np.arange(len(df['deg']))
    fig, ax = plt.subplots(figsize=(12, 8))
    for i, (col, color) in enumerate(series):
        positions = index + (i * bar_width)
        ax.bar(positions, df[col], bar_width, color=color, label=col)
        # write the numeric value just above every bar
        for x, val in zip(positions, df[col]):
            ax.annotate(f'{val:.2f}', xy=(x, val), xytext=(0, 3),
                        textcoords="offset points", ha='center', va='bottom')

    # set labels and title
    ax.set_xlabel('WavName')
    ax.set_ylabel('Scores')
    if self.args['mode'] == 'predict_file':
        title = 'BarChart for ' + str(self.args['deg'])
    else:
        title = 'Barchart for Wavs under: ' + str(self.args['data_dir'])
    ax.set_title(title)
    # centre the tick under each group of bars (2 * bar_width for 5 series)
    ax.set_xticks(index + ((len(series) - 1) / 2 * bar_width))
    ax.set_xticklabels(df['deg'], rotation=90)
    ax.set_ylim(0, 5)
    # add legend
    ax.legend()

    # Fix the layout BEFORE saving so the saved image matches what is shown
    # (otherwise the rotated tick labels may be cut off in the file).
    fig.tight_layout()

    # save plot
    plot_name = self.args.get('plot_name')
    if plot_name and plot_name != 'None':
        # fall back to the current directory when no output_dir was given
        out_dir = self.args.get('output_dir') or '.'
        fig.savefig(os.path.join(out_dir, 'BarChart_' + plot_name))

    # display the plot
    plt.show()


def _draw_linechart(self):
    """Draw a line chart of the predicted scores, one line per score column.

    Reads ``self.ds_val.df`` (needs a ``deg`` column plus any of the
    ``*_pred`` score columns) and ``self.args`` (``mode``, ``deg`` or
    ``data_dir`` for the title, and optionally ``plot_name`` /
    ``output_dir`` for saving). The figure is saved when ``plot_name`` is
    set to something other than ``'None'`` and is then displayed.
    """
    df = self.ds_val.df
    # Fixed column -> colour pairing so a column keeps its colour even when
    # some score columns are absent from the results.
    palette = (('mos_pred', 'r'), ('noi_pred', 'g'), ('dis_pred', 'b'),
               ('col_pred', 'c'), ('loud_pred', 'm'))
    series = [(col, color) for col, color in palette if col in df.columns]

    # settings of line chart
    fig, ax = plt.subplots(figsize=(12, 8))
    index = np.arange(len(df['deg']))
    for col, color in series:
        ax.plot(index, df[col], marker='o', linestyle='-', color=color, label=col)
        # label every data point with its value
        for x, y in zip(index, df[col]):
            ax.text(x, y, f'{y:.2f}', ha='left', va='top')

    # set labels and title
    ax.set_xlabel('WavName')
    ax.set_ylabel('Scores')
    if self.args['mode'] == 'predict_file':
        title = 'LineChart for ' + str(self.args['deg'])
    else:
        title = 'Line Chart for Wavs under: ' + str(self.args['data_dir'])
    ax.set_title(title)
    ax.set_xticks(index)
    ax.set_xticklabels(df['deg'], rotation=90)
    ax.set_ylim(0, 5)
    # add legend
    ax.legend()

    # Fix the layout BEFORE saving so the saved image matches what is shown
    # (otherwise the rotated tick labels may be cut off in the file).
    fig.tight_layout()

    # save plot
    plot_name = self.args.get('plot_name')
    if plot_name and plot_name != 'None':
        # fall back to the current directory when no output_dir was given
        out_dir = self.args.get('output_dir') or '.'
        fig.savefig(os.path.join(out_dir, 'LineChart_' + plot_name))

    # display the plot
    plt.show()

# calculate mean and standard deviation of the results and concatenate with the initial results.
def _compute_mean_stdDev(self):
mean = self.ds_val.df.drop(['deg', 'model'], axis=1).mean()
std = self.ds_val.df.drop(['deg', 'model'], axis=1).std()
# alter NAN to 0
std.fillna(0, inplace=True)

stat = pd.DataFrame({'deg': ['**Mean**', '**Standard Deviation**'],
'mos_pred': [mean['mos_pred'], std['mos_pred']],
'noi_pred': [mean['noi_pred'], std['noi_pred']],
'dis_pred': [mean['dis_pred'], std['dis_pred']],
'col_pred': [mean['col_pred'], std['col_pred']],
'loud_pred': [mean['loud_pred'], std['loud_pred']],
'model': self.ds_val.df['model'][0]})
stat = pd.concat([self.ds_val.df, stat])
return stat

def predict(self):
print('---> Predicting ...')
if self.args['tr_parallel']:
Expand All @@ -69,15 +144,38 @@ def predict(self):
self.ds_val,
self.args['tr_bs_val'],
self.dev,
num_workers=self.args['tr_num_workers'])
num_workers=self.args['tr_num_workers'])

if self.args['output_dir']:
self.ds_val.df['model'] = self.args['name']
self.ds_val.df.to_csv(
os.path.join(self.args['output_dir'], 'NISQA_results.csv'),
index=False)

print(self.ds_val.df.to_string(index=False))
csv_name = 'NISQA_results.csv'
if 'output_name' in self.args:
csv_name = str(self.args['output_name']) + '.csv'
# whether print mean and standard deviation or not
if 'compute_stats' in self.args and self.args['compute_stats']:
# generate results with mean and deviation to self.statistics.
self.statistics = self._compute_mean_stdDev()
self.statistics.to_csv(
os.path.join(self.args['output_dir'], csv_name),
index=False)
else:
self.ds_val.df.to_csv(
os.path.join(self.args['output_dir'], csv_name),
index=False)

# print either statistics or ds_val based on 'compute_stats' parameter
if 'compute_stats' in self.args and self.args['compute_stats']:
print(self.statistics.to_string(index=False))
else:
print(self.ds_val.df.to_string(index=False))

# Visualization of the results
if 'plot_type' in self.args and self.args['plot_type'] == 'barchart':
self._draw_barchart()
elif 'plot_type' in self.args and self.args['plot_type'] == 'linechart':
self._draw_linechart()

# returned DataFrame is not changed.
return self.ds_val.df

def _train_mos(self):
Expand Down
17 changes: 17 additions & 0 deletions run_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,17 @@
parser.add_argument('--num_workers', type=int, default=0, help='number of workers for pytorchs dataloader')
parser.add_argument('--bs', type=int, default=1, help='batch size for predicting')
parser.add_argument('--ms_channel', type=int, help='audio channel in case of stereo file')
### Add
# 1) the name of result csv,
# 2) statistics(mean and standard deviation),
# 3) type of visualizations
# 4) whether to save the plot
### to the parameters
parser.add_argument('--output_name', type=str, help='name of the csv result file')
parser.add_argument('--compute_stats', action='store_true', help='whether to calculate the mean and the standard deviation of the results')
# Restrict --plot_type to the values that are actually handled, so that a typo
# is rejected at parse time instead of silently producing no plot.
parser.add_argument('--plot_type', type=str, default='None', choices=['barchart', 'linechart', 'None'], help='Visualization of the results. Either barchart, linechart or None')
parser.add_argument('--plot_name', type=str, default='None', help='name of the plot file if saving plot is needed')
###

args = parser.parse_args()
args = vars(args)
Expand All @@ -35,6 +46,12 @@
args['data_dir'] = ''
else:
raise NotImplementedError('--mode given not available')

# A plot file name only makes sense when a plot is actually drawn, i.e. when
# `plot_type` is something other than 'None'.
plot_name_given = 'plot_name' in args and args['plot_name'] != 'None'
plot_type_unset = 'plot_type' in args and args['plot_type'] == 'None'
if plot_name_given and plot_type_unset:
    raise ValueError('--plot_name argument can only be set when `plot_type` is either `barchart` or `linechart`')

args['tr_bs_val'] = args['bs']
args['tr_num_workers'] = args['num_workers']

Expand Down

0 comments on commit 21b7c4f

Please sign in to comment.