mirror of
https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00
399 lines
14 KiB
Python
399 lines
14 KiB
Python
|
#!/usr/bin/env python3
|
||
|
|
||
|
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
|
||
|
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||
|
|
||
|
"""Common functions for ASR."""
|
||
|
|
||
|
import argparse
|
||
|
import editdistance
|
||
|
import json
|
||
|
import logging
|
||
|
import numpy as np
|
||
|
import six
|
||
|
import sys
|
||
|
|
||
|
from itertools import groupby
|
||
|
|
||
|
|
||
|
def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
    """End detection.

    Described in Eq. (50) of S. Watanabe et al.,
    "Hybrid CTC/Attention Architecture for End-to-End Speech Recognition".

    Decoding may stop when, for every length ``i, i - 1, ..., i - M + 1``,
    there is an ended hypothesis of that length whose best score satisfies
    ``best_same_length_score - best_overall_score < D_end``.

    :param list ended_hyps: finished hypotheses; each is a dict read via its
        ``'yseq'`` (token sequence) and ``'score'`` keys
    :param int i: current output length / decoding step
    :param int M: number of consecutive lengths that must satisfy the test
    :param float D_end: score-gap threshold (default ``log(exp(-10)) == -10``)
    :return: True if decoding can be terminated
    :rtype: bool
    """
    if not ended_hyps:
        return False

    best_score = max(ended_hyps, key=lambda hyp: hyp['score'])['score']

    hits = 0
    # walk the lengths i, i - 1, ..., i - M + 1
    for length in range(i, i - M, -1):
        scores = [hyp['score'] for hyp in ended_hyps if len(hyp['yseq']) == length]
        if scores and max(scores) - best_score < D_end:
            hits += 1

    return hits == M
|
||
|
|
||
|
|
||
|
# TODO(takaaki-hori): add different smoothing methods
def label_smoothing_dist(odim, lsm_type, transcript=None, blank=0):
    """Obtain label distribution for loss smoothing.

    Only ``lsm_type == 'unigram'`` is implemented: token occurrences are
    counted over all utterances of ``transcript``, the last index
    (``odim - 1``) receives one count per utterance as ``<eos>``, zero counts
    are floored to 1, the blank count is set to 0, and the counts are
    normalized to sum to 1.

    :param int odim: output dimension; index ``odim - 1`` is counted as <eos>
    :param str lsm_type: label smoothing type; only ``'unigram'`` is supported
    :param str transcript: path to a JSON file with a top-level ``'utts'``
        mapping; each utterance provides ``v['output'][0]['tokenid']`` as a
        whitespace-separated string of integer token ids
    :param int blank: index of the blank label, whose probability is forced to 0
    :return: label distribution of shape ``(odim,)``
    :rtype: numpy.ndarray
    :raises AssertionError: if ``'unigram'`` is requested without a transcript
    :raises SystemExit: (exit status 1) for an unsupported ``lsm_type``
    """
    if transcript is not None:
        with open(transcript, 'rb') as f:
            trans_json = json.load(f)['utts']

    if lsm_type == 'unigram':
        assert transcript is not None, 'transcript is required for %s label smoothing' % lsm_type
        labelcount = np.zeros(odim)
        for v in trans_json.values():
            ids = np.array([int(n) for n in v['output'][0]['tokenid'].split()], dtype=np.int64)
            # to avoid an error when there is no text in an utterance
            if len(ids) > 0:
                # np.add.at accumulates repeated ids; `labelcount[ids] += 1`
                # would count each distinct id only once per utterance.
                np.add.at(labelcount, ids, 1)
        # one <eos> per utterance (previously this used len(transcript), i.e.
        # the length of the file-path string)
        labelcount[odim - 1] = len(trans_json)  # count <eos>
        labelcount[labelcount == 0] = 1  # flooring
        labelcount[blank] = 0  # remove counts for blank
        labeldist = labelcount.astype(np.float32) / np.sum(labelcount)
    else:
        logging.error(
            "Error: unexpected label smoothing type: %s" % lsm_type)
        # non-zero exit status: this is an error, not a normal termination
        sys.exit(1)

    return labeldist
|
||
|
|
||
|
|
||
|
def get_vgg2l_odim(idim, in_channel=3, out_channel=128, downsample=True):
    """Return the output size of the VGG frontend.

    The per-channel feature size ``idim / in_channel`` is halved (rounding
    up) twice when ``downsample`` is set, then truncated to an int and
    multiplied by ``out_channel``.

    :param int idim: input feature dimension (all channels combined)
    :param int in_channel: input channel size
    :param int out_channel: output channel size
    :param bool downsample: apply the two 2x max-pooling reductions
    :return: output size
    :rtype: int
    """
    feat_size = idim / in_channel
    if downsample:
        # two max-pooling stages, each halving the size with ceil rounding
        for _ in range(2):
            feat_size = np.ceil(np.array(feat_size, dtype=np.float32) / 2)
    # feature size times the number of channels
    return int(feat_size) * out_channel
|
||
|
|
||
|
|
||
|
class ErrorCalculator(object):
    """Calculate CER and WER for E2E_ASR and CTC models during training.

    :param list char_list: token inventory; index ``k`` decodes to ``char_list[k]``
    :param str sym_space: token used as the word separator
    :param str sym_blank: blank token; must be present in ``char_list``
    :param bool report_cer: compute CER in ``__call__``
    :param bool report_wer: compute WER in ``__call__``
    :param str trans_type: for CTC CER, ``"char"`` joins tokens with no
        separator; any other value joins them with a single space
    """

    def __init__(self, char_list, sym_space, sym_blank, report_cer=False, report_wer=False,
                 trans_type="char"):
        """Construct an ErrorCalculator object."""
        super(ErrorCalculator, self).__init__()

        self.report_cer = report_cer
        self.report_wer = report_wer
        self.trans_type = trans_type
        self.char_list = char_list
        self.space = sym_space
        self.blank = sym_blank
        # list.index raises ValueError if the blank symbol is missing
        self.idx_blank = self.char_list.index(self.blank)
        # the space symbol is optional; None never equals an int token index
        if self.space in self.char_list:
            self.idx_space = self.char_list.index(self.space)
        else:
            self.idx_space = None

    def __call__(self, ys_hat, ys_pad, is_ctc=False):
        """Calculate sentence-level WER/CER score.

        :param torch.Tensor ys_hat: prediction (batch, seqlen)
        :param torch.Tensor ys_pad: reference (batch, seqlen), padded with -1
        :param bool is_ctc: calculate CER score for CTC
        :return: if ``is_ctc``, only the CTC CER (float or None); otherwise the
            tuple ``(cer, wer)`` where each entry is None when not reported
        """
        cer, wer = None, None
        if is_ctc:
            return self.calculate_cer_ctc(ys_hat, ys_pad)
        elif not self.report_cer and not self.report_wer:
            return cer, wer

        seqs_hat, seqs_true = self.convert_to_char(ys_hat, ys_pad)
        if self.report_cer:
            cer = self.calculate_cer(seqs_hat, seqs_true)

        if self.report_wer:
            wer = self.calculate_wer(seqs_hat, seqs_true)
        return cer, wer

    def calculate_cer_ctc(self, ys_hat, ys_pad):
        """Calculate sentence-level CER score for CTC.

        Consecutive repeated predictions are collapsed, then padding (-1),
        blank and space indices are dropped from both sides. Utterances whose
        reference is empty after filtering are skipped.

        :param torch.Tensor ys_hat: prediction (batch, seqlen)
        :param torch.Tensor ys_pad: reference (batch, seqlen)
        :return: total edit distance divided by total reference length, or
            None when no utterance has a non-empty reference
        :rtype: float or None
        """
        cers, char_ref_lens = [], []
        # indices that never contribute a character
        skip = (-1, self.idx_blank, self.idx_space)
        sep = "" if self.trans_type == "char" else " "
        for i, y in enumerate(ys_hat):
            # collapse consecutive repeats of the same index
            y_hat = [x[0] for x in groupby(y)]
            y_true = ys_pad[i]
            seq_hat = [self.char_list[idx] for idx in map(int, y_hat) if idx not in skip]
            seq_true = [self.char_list[idx] for idx in map(int, y_true) if idx not in skip]
            hyp_chars = sep.join(seq_hat)
            ref_chars = sep.join(seq_true)

            if len(ref_chars) > 0:
                cers.append(editdistance.eval(hyp_chars, ref_chars))
                char_ref_lens.append(len(ref_chars))

        cer_ctc = float(sum(cers)) / sum(char_ref_lens) if cers else None
        return cer_ctc

    def convert_to_char(self, ys_hat, ys_pad):
        """Convert index to character.

        :param torch.Tensor ys_hat: prediction (batch, seqlen)
        :param torch.Tensor ys_pad: reference (batch, seqlen), padded with -1
        :return: (prediction texts, reference texts), one string per utterance
        :rtype: tuple(list, list)
        """
        seqs_hat, seqs_true = [], []
        for i, y_hat in enumerate(ys_hat):
            y_true = ys_pad[i]
            eos_true = np.where(y_true == -1)[0]
            eos_true = eos_true[0] if len(eos_true) > 0 else len(y_true)
            # To avoid wrong higher WER than the one obtained from the decoding
            # eos from y_true is used to mark the eos in y_hat
            # because of that y_hats has not padded outs with -1.
            seq_hat = [self.char_list[int(idx)] for idx in y_hat[:eos_true]]
            seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]
            # Tokens are joined with a single space: calculate_wer then treats
            # every token as a word, while calculate_cer strips the spaces again.
            seq_hat_text = " ".join(seq_hat).replace(self.space, ' ')
            seq_hat_text = seq_hat_text.replace(self.blank, '')
            seq_true_text = " ".join(seq_true).replace(self.space, ' ')
            seqs_hat.append(seq_hat_text)
            seqs_true.append(seq_true_text)
        return seqs_hat, seqs_true

    def calculate_cer(self, seqs_hat, seqs_true):
        """Calculate sentence-level CER score.

        :param list seqs_hat: prediction texts
        :param list seqs_true: reference texts
        :return: total character edit distance divided by total reference
            length (spaces removed), or None if all references are empty
        :rtype: float or None
        """
        pairs = [(hyp.replace(' ', ''), ref.replace(' ', ''))
                 for hyp, ref in zip(seqs_hat, seqs_true)]
        total_ref_len = sum(len(ref) for _, ref in pairs)
        # an all-empty reference set has no defined error rate
        if total_ref_len == 0:
            return None
        char_eds = sum(editdistance.eval(hyp, ref) for hyp, ref in pairs)
        return float(char_eds) / total_ref_len

    def calculate_wer(self, seqs_hat, seqs_true):
        """Calculate sentence-level WER score.

        :param list seqs_hat: prediction texts
        :param list seqs_true: reference texts
        :return: total word edit distance divided by total reference word
            count, or None if all references are empty
        :rtype: float or None
        """
        pairs = [(hyp.split(), ref.split()) for hyp, ref in zip(seqs_hat, seqs_true)]
        total_ref_len = sum(len(ref) for _, ref in pairs)
        # an all-empty reference set has no defined error rate
        if total_ref_len == 0:
            return None
        word_eds = sum(editdistance.eval(hyp, ref) for hyp, ref in pairs)
        return float(word_eds) / total_ref_len
|
||
|
|
||
|
|
||
|
class ErrorCalculatorTrans(object):
    """Calculate CER and WER for transducer models.

    Args:
        decoder (nn.Module): decoder module providing ``recognize`` and
            ``recognize_beam``
        args (Namespace): argument Namespace containing options (beam_size,
            nbest, sym_space, sym_blank, score_norm_transducer, char_list,
            report_cer, report_wer)
        report_cer (boolean): compute CER option (enabled if this or
            ``args.report_cer`` is true)
        report_wer (boolean): compute WER option (enabled if this or
            ``args.report_wer`` is true)

    """

    def __init__(self, decoder, args, report_cer=False, report_wer=False):
        """Construct an ErrorCalculator object for transducer model."""
        super(ErrorCalculatorTrans, self).__init__()

        self.dec = decoder

        recog_args = {'beam_size': args.beam_size,
                      'nbest': args.nbest,
                      'space': args.sym_space,
                      'score_norm_transducer': args.score_norm_transducer}

        self.recog_args = argparse.Namespace(**recog_args)

        self.char_list = args.char_list
        self.space = args.sym_space
        self.blank = args.sym_blank

        # Previously the keyword arguments were ignored. `False or x` returns x,
        # so calls relying on the defaults keep reading the values from args.
        self.report_cer = report_cer or getattr(args, 'report_cer', False)
        self.report_wer = report_wer or getattr(args, 'report_wer', False)

    def __call__(self, hs_pad, ys_pad):
        """Calculate sentence-level WER/CER score for transducer models.

        Args:
            hs_pad (torch.Tensor): batch of padded input sequence (batch, T, D)
            ys_pad (torch.Tensor): reference (batch, seqlen)

        Returns:
            (float): sentence-level CER score (None if not reported)
            (float): sentence-level WER score (None if not reported)

        """
        cer, wer = None, None

        if not self.report_cer and not self.report_wer:
            return cer, wer

        batchsize = int(hs_pad.size(0))
        batch_nbest = []

        for b in range(batchsize):
            # greedy search for beam size 1, beam search otherwise
            if self.recog_args.beam_size == 1:
                nbest_hyps = self.dec.recognize(hs_pad[b], self.recog_args)
            else:
                nbest_hyps = self.dec.recognize_beam(hs_pad[b], self.recog_args)
            batch_nbest.append(nbest_hyps)

        # best hypothesis per utterance, skipping its first token
        # (presumably the start/blank symbol emitted by the decoder -- confirm)
        ys_hat = [nbest_hyp[0]['yseq'][1:] for nbest_hyp in batch_nbest]

        seqs_hat, seqs_true = self.convert_to_char(ys_hat, ys_pad.cpu())

        if self.report_cer:
            cer = self.calculate_cer(seqs_hat, seqs_true)

        if self.report_wer:
            wer = self.calculate_wer(seqs_hat, seqs_true)

        return cer, wer

    def convert_to_char(self, ys_hat, ys_pad):
        """Convert index to character.

        The hypothesis is truncated at the reference's first -1 padding
        position, space tokens become ' ' and blank tokens are removed from
        the hypothesis.

        Args:
            ys_hat (list): predicted index sequences, one per utterance
            ys_pad (torch.Tensor): reference (batch, seqlen), padded with -1

        Returns:
            (list): text of each prediction
            (list): text of each reference

        """
        seqs_hat, seqs_true = [], []

        for i, y_hat in enumerate(ys_hat):
            y_true = ys_pad[i]

            eos_true = np.where(y_true == -1)[0]
            eos_true = eos_true[0] if len(eos_true) > 0 else len(y_true)

            seq_hat = [self.char_list[int(idx)] for idx in y_hat[:eos_true]]
            seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]

            seq_hat_text = "".join(seq_hat).replace(self.space, ' ')
            seq_hat_text = seq_hat_text.replace(self.blank, '')
            seq_true_text = "".join(seq_true).replace(self.space, ' ')

            seqs_hat.append(seq_hat_text)
            seqs_true.append(seq_true_text)

        return seqs_hat, seqs_true

    def calculate_cer(self, seqs_hat, seqs_true):
        """Calculate sentence-level CER score for transducer model.

        Args:
            seqs_hat (list): prediction texts
            seqs_true (list): reference texts

        Returns:
            (float): total character edit distance divided by total reference
                length (spaces removed), or None if all references are empty

        """
        pairs = [(hyp.replace(' ', ''), ref.replace(' ', ''))
                 for hyp, ref in zip(seqs_hat, seqs_true)]
        total_ref_len = sum(len(ref) for _, ref in pairs)
        # an all-empty reference set has no defined error rate
        if total_ref_len == 0:
            return None
        char_eds = sum(editdistance.eval(hyp, ref) for hyp, ref in pairs)
        return float(char_eds) / total_ref_len

    def calculate_wer(self, seqs_hat, seqs_true):
        """Calculate sentence-level WER score for transducer model.

        Args:
            seqs_hat (list): prediction texts
            seqs_true (list): reference texts

        Returns:
            (float): total word edit distance divided by total reference word
                count, or None if all references are empty

        """
        pairs = [(hyp.split(), ref.split()) for hyp, ref in zip(seqs_hat, seqs_true)]
        total_ref_len = sum(len(ref) for _, ref in pairs)
        # an all-empty reference set has no defined error rate
        if total_ref_len == 0:
            return None
        word_eds = sum(editdistance.eval(hyp, ref) for hyp, ref in pairs)
        return float(word_eds) / total_ref_len
|