diff --git a/wekws/bin/train.py b/wekws/bin/train.py index c88ea1f..025ebf9 100644 --- a/wekws/bin/train.py +++ b/wekws/bin/train.py @@ -141,15 +141,14 @@ def main(): configs['model']['cmvn'] = {} configs['model']['cmvn']['norm_var'] = args.norm_var configs['model']['cmvn']['cmvn_file'] = args.cmvn_file + # Init asr model from configs + model = init_model(configs['model']) if rank == 0: saved_config_path = os.path.join(args.model_dir, 'config.yaml') with open(saved_config_path, 'w') as fout: data = yaml.dump(configs) fout.write(data) - - # Init asr model from configs - model = init_model(configs['model']) - print(model) + print(model) num_params = count_parameters(model) print('the number of model params: {}'.format(num_params))