diff --git a/.gitignore b/.gitignore
index 70a9b93..7df88c7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -15,9 +15,8 @@
 *.toc
 *.wav
 *.sh
-synthesizer/saved_models/*
-vocoder/saved_models/*
-encoder/saved_models/*
-cp_hifigan/*
-!vocoder/saved_models/pretrained/*
-!encoder/saved_models/pretrained.pt
\ No newline at end of file
+*/saved_models
+!vocoder/saved_models/pretrained/**
+!encoder/saved_models/pretrained.pt
+wavs
+log
\ No newline at end of file
diff --git a/.vscode/launch.json b/.vscode/launch.json
index 3b264f6..23e5203 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -35,6 +35,14 @@
             "console": "integratedTerminal",
             "args": ["-d","..\\audiodata"]
         },
+        {
+            "name": "Python: Demo Box VC",
+            "type": "python",
+            "request": "launch",
+            "program": "demo_toolbox.py",
+            "console": "integratedTerminal",
+            "args": ["-d","..\\audiodata","-vc"]
+        },
         {
             "name": "Python: Synth Train",
             "type": "python",
@@ -43,5 +51,15 @@
             "console": "integratedTerminal",
             "args": ["my_run", "..\\"]
         },
+        {
+            "name": "Python: PPG Convert",
+            "type": "python",
+            "request": "launch",
+            "program": "run.py",
+            "console": "integratedTerminal",
+            "args": ["-c", ".\\ppg2mel\\saved_models\\seq2seq_mol_ppg2mel_vctk_libri_oneshotvc_r4_normMel_v2.yaml",
+            "-m", ".\\ppg2mel\\saved_models\\best_loss_step_304000.pth", "--wav_dir", ".\\wavs\\input", "--ref_wav_path", ".\\wavs\\pkq.mp3", "-o", ".\\wavs\\output\\"
+        ]
+        },
     ]
 }
diff --git a/demo_toolbox.py b/demo_toolbox.py
index d938031..7030bd5 100644
--- a/demo_toolbox.py
+++ b/demo_toolbox.py
@@ -15,12 +15,18 @@ if __name__ == '__main__':
     parser.add_argument("-d", "--datasets_root", type=Path, help= \
         "Path to the directory containing your datasets. See toolbox/__init__.py for a list of "
         "supported datasets.", default=None)
+    parser.add_argument("-vc", "--vc_mode", action="store_true",
+                        help="Voice Conversion Mode (PPG-based)")
     parser.add_argument("-e", "--enc_models_dir", type=Path, default="encoder/saved_models",
                         help="Directory containing saved encoder models")
     parser.add_argument("-s", "--syn_models_dir", type=Path, default="synthesizer/saved_models",
                         help="Directory containing saved synthesizer models")
     parser.add_argument("-v", "--voc_models_dir", type=Path, default="vocoder/saved_models",
                         help="Directory containing saved vocoder models")
+    parser.add_argument("-ex", "--extractor_models_dir", type=Path, default="ppg_extractor/saved_models",
+                        help="Directory containing saved extractor models")
+    parser.add_argument("-cv", "--convertor_models_dir", type=Path, default="ppg2mel/saved_models",
+                        help="Directory containing saved converter models")
     parser.add_argument("--cpu", action="store_true", help=\
        "If True, processing is done on CPU, even when a GPU is available.")
     parser.add_argument("--seed", type=int, default=None, help=\
diff --git a/encoder/inference.py b/encoder/inference.py
index 4ca417b..af9a529 100644
--- a/encoder/inference.py
+++ b/encoder/inference.py
@@ -34,8 +34,16 @@ def load_model(weights_fpath: Path, device=None):
     _model.load_state_dict(checkpoint["model_state"])
     _model.eval()
     print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"]))
+    return _model
-
+def set_model(model, device=None):
+    global _model, _device
+    _model = model
+    if device is None:
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    _device = device
+    _model.to(device)
+
 
 def is_loaded():
     return _model is not None
 
@@ -57,7 +65,7 @@ def embed_frames_batch(frames_batch):
 def compute_partial_slices(n_samples,
partial_utterance_n_frames=partials_n_frames, - min_pad_coverage=0.75, overlap=0.5): + min_pad_coverage=0.75, overlap=0.5, rate=None): """ Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain partial utterances of each. Both the waveform and the mel @@ -85,9 +93,18 @@ def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_fram assert 0 <= overlap < 1 assert 0 < min_pad_coverage <= 1 - samples_per_frame = int((sampling_rate * mel_window_step / 1000)) - n_frames = int(np.ceil((n_samples + 1) / samples_per_frame)) - frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1) + if rate != None: + samples_per_frame = int((sampling_rate * mel_window_step / 1000)) + n_frames = int(np.ceil((n_samples + 1) / samples_per_frame)) + frame_step = int(np.round((sampling_rate / rate) / samples_per_frame)) + else: + samples_per_frame = int((sampling_rate * mel_window_step / 1000)) + n_frames = int(np.ceil((n_samples + 1) / samples_per_frame)) + frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1) + + assert 0 < frame_step, "The rate is too high" + assert frame_step <= partials_n_frames, "The rate is too low, it should be %f at least" % \ + (sampling_rate / (samples_per_frame * partials_n_frames)) # Compute the slices wav_slices, mel_slices = [], [] diff --git a/ppg2mel/__init__.py b/ppg2mel/__init__.py new file mode 100644 index 0000000..53ee3b2 --- /dev/null +++ b/ppg2mel/__init__.py @@ -0,0 +1,206 @@ +#!/usr/bin/env python3 + +# Copyright 2020 Songxiang Liu +# Apache 2.0 + +from typing import List + +import torch +import torch.nn.functional as F + +import numpy as np + +from .utils.abs_model import AbsMelDecoder +from .rnn_decoder_mol import Decoder +from .utils.cnn_postnet import Postnet +from .utils.vc_utils import get_mask_from_lengths + +from utils.load_yaml import HpsYaml + +class MelDecoderMOLv2(AbsMelDecoder): + """Use an encoder to preprocess ppg.""" + def __init__( + self, + num_speakers: int, + spk_embed_dim: int, + bottle_neck_feature_dim: int, + encoder_dim: int = 256, + encoder_downsample_rates: List = [2, 2], + attention_rnn_dim: int = 512, + decoder_rnn_dim: int = 512, + num_decoder_rnn_layer: int = 1, + concat_context_to_last: bool = True, + prenet_dims: List = [256, 128], + num_mixtures: int = 5, + frames_per_step: int = 2, + mask_padding: bool = True, + ): + super().__init__() + + self.mask_padding = mask_padding + self.bottle_neck_feature_dim = bottle_neck_feature_dim + self.num_mels = 80 + self.encoder_down_factor=np.cumprod(encoder_downsample_rates)[-1] + self.frames_per_step = frames_per_step + self.use_spk_dvec = True + + input_dim = bottle_neck_feature_dim + + # Downsampling convolution + self.bnf_prenet = torch.nn.Sequential( + torch.nn.Conv1d(input_dim, encoder_dim, kernel_size=1, bias=False), + torch.nn.LeakyReLU(0.1), + + torch.nn.InstanceNorm1d(encoder_dim, affine=False), + torch.nn.Conv1d( + encoder_dim, encoder_dim, + kernel_size=2*encoder_downsample_rates[0], + stride=encoder_downsample_rates[0], + padding=encoder_downsample_rates[0]//2, + ), + torch.nn.LeakyReLU(0.1), + + torch.nn.InstanceNorm1d(encoder_dim, affine=False), + torch.nn.Conv1d( + encoder_dim, encoder_dim, + kernel_size=2*encoder_downsample_rates[1], + stride=encoder_downsample_rates[1], + padding=encoder_downsample_rates[1]//2, + ), + torch.nn.LeakyReLU(0.1), + + torch.nn.InstanceNorm1d(encoder_dim, affine=False), + ) + decoder_enc_dim = encoder_dim + self.pitch_convs = torch.nn.Sequential( + 
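+            # NOTE: the 2-channel input here is presumably [log-F0, voiced/unvoiced flag]
+            # (the logf0_uv tensor consumed in forward/inference); this stack mirrors
+            # bnf_prenet so the pitch stream is downsampled by the same factors and can
+            # be summed element-wise with the PPG stream.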
torch.nn.Conv1d(2, encoder_dim, kernel_size=1, bias=False), + torch.nn.LeakyReLU(0.1), + + torch.nn.InstanceNorm1d(encoder_dim, affine=False), + torch.nn.Conv1d( + encoder_dim, encoder_dim, + kernel_size=2*encoder_downsample_rates[0], + stride=encoder_downsample_rates[0], + padding=encoder_downsample_rates[0]//2, + ), + torch.nn.LeakyReLU(0.1), + + torch.nn.InstanceNorm1d(encoder_dim, affine=False), + torch.nn.Conv1d( + encoder_dim, encoder_dim, + kernel_size=2*encoder_downsample_rates[1], + stride=encoder_downsample_rates[1], + padding=encoder_downsample_rates[1]//2, + ), + torch.nn.LeakyReLU(0.1), + + torch.nn.InstanceNorm1d(encoder_dim, affine=False), + ) + + self.reduce_proj = torch.nn.Linear(encoder_dim + spk_embed_dim, encoder_dim) + + # Decoder + self.decoder = Decoder( + enc_dim=decoder_enc_dim, + num_mels=self.num_mels, + frames_per_step=frames_per_step, + attention_rnn_dim=attention_rnn_dim, + decoder_rnn_dim=decoder_rnn_dim, + num_decoder_rnn_layer=num_decoder_rnn_layer, + prenet_dims=prenet_dims, + num_mixtures=num_mixtures, + use_stop_tokens=True, + concat_context_to_last=concat_context_to_last, + encoder_down_factor=self.encoder_down_factor, + ) + + # Mel-Spec Postnet: some residual CNN layers + self.postnet = Postnet() + + def parse_output(self, outputs, output_lengths=None): + if self.mask_padding and output_lengths is not None: + mask = ~get_mask_from_lengths(output_lengths, outputs[0].size(1)) + mask = mask.unsqueeze(2).expand(mask.size(0), mask.size(1), self.num_mels) + outputs[0].data.masked_fill_(mask, 0.0) + outputs[1].data.masked_fill_(mask, 0.0) + return outputs + + def forward( + self, + bottle_neck_features: torch.Tensor, + feature_lengths: torch.Tensor, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + logf0_uv: torch.Tensor = None, + spembs: torch.Tensor = None, + output_att_ws: bool = False, + ): + decoder_inputs = self.bnf_prenet( + bottle_neck_features.transpose(1, 2) + ).transpose(1, 2) + logf0_uv = self.pitch_convs(logf0_uv.transpose(1, 2)).transpose(1, 2) + decoder_inputs = decoder_inputs + logf0_uv + + assert spembs is not None + spk_embeds = F.normalize( + spembs).unsqueeze(1).expand(-1, decoder_inputs.size(1), -1) + decoder_inputs = torch.cat([decoder_inputs, spk_embeds], dim=-1) + decoder_inputs = self.reduce_proj(decoder_inputs) + + # (B, num_mels, T_dec) + T_dec = torch.div(feature_lengths, int(self.encoder_down_factor), rounding_mode='floor') + mel_outputs, predicted_stop, alignments = self.decoder( + decoder_inputs, speech, T_dec) + ## Post-processing + mel_outputs_postnet = self.postnet(mel_outputs.transpose(1, 2)).transpose(1, 2) + mel_outputs_postnet = mel_outputs + mel_outputs_postnet + if output_att_ws: + return self.parse_output( + [mel_outputs, mel_outputs_postnet, predicted_stop, alignments], speech_lengths) + else: + return self.parse_output( + [mel_outputs, mel_outputs_postnet, predicted_stop], speech_lengths) + + # return mel_outputs, mel_outputs_postnet + + def inference( + self, + bottle_neck_features: torch.Tensor, + logf0_uv: torch.Tensor = None, + spembs: torch.Tensor = None, + ): + decoder_inputs = self.bnf_prenet(bottle_neck_features.transpose(1, 2)).transpose(1, 2) + logf0_uv = self.pitch_convs(logf0_uv.transpose(1, 2)).transpose(1, 2) + decoder_inputs = decoder_inputs + logf0_uv + + assert spembs is not None + spk_embeds = F.normalize( + spembs).unsqueeze(1).expand(-1, decoder_inputs.size(1), -1) + bottle_neck_features = torch.cat([decoder_inputs, spk_embeds], dim=-1) + bottle_neck_features = 
self.reduce_proj(bottle_neck_features) + + ## Decoder + if bottle_neck_features.size(0) > 1: + mel_outputs, alignments = self.decoder.inference_batched(bottle_neck_features) + else: + mel_outputs, alignments = self.decoder.inference(bottle_neck_features,) + ## Post-processing + mel_outputs_postnet = self.postnet(mel_outputs.transpose(1, 2)).transpose(1, 2) + mel_outputs_postnet = mel_outputs + mel_outputs_postnet + # outputs = mel_outputs_postnet[0] + + return mel_outputs[0], mel_outputs_postnet[0], alignments[0] + +def load_model(train_config, model_file, device=None): + + if device is None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + model_config = HpsYaml(train_config) + ppg2mel_model = MelDecoderMOLv2( + **model_config["model"] + ).to(device) + ckpt = torch.load(model_file, map_location=device) + ppg2mel_model.load_state_dict(ckpt["model"]) + ppg2mel_model.eval() + return ppg2mel_model diff --git a/ppg2mel/preprocess.py b/ppg2mel/preprocess.py new file mode 100644 index 0000000..6da9054 --- /dev/null +++ b/ppg2mel/preprocess.py @@ -0,0 +1,112 @@ + +import os +import torch +import numpy as np +from tqdm import tqdm +from pathlib import Path +import soundfile +import resampy + +from ppg_extractor import load_model +import encoder.inference as Encoder +from encoder.audio import preprocess_wav +from encoder import audio +from utils.f0_utils import compute_f0 + +from torch.multiprocessing import Pool, cpu_count +from functools import partial + +SAMPLE_RATE=16000 + +def _compute_bnf( + wav: any, + output_fpath: str, + device: torch.device, + ppg_model_local: any, +): + """ + Compute CTC-Attention Seq2seq ASR encoder bottle-neck features (BNF). + """ + ppg_model_local.to(device) + wav_tensor = torch.from_numpy(wav).float().to(device).unsqueeze(0) + wav_length = torch.LongTensor([wav.shape[0]]).to(device) + with torch.no_grad(): + bnf = ppg_model_local(wav_tensor, wav_length) + bnf_npy = bnf.squeeze(0).cpu().numpy() + np.save(output_fpath, bnf_npy, allow_pickle=False) + return bnf_npy, len(bnf_npy) + +def _compute_f0_from_wav(wav, output_fpath): + """Compute merged f0 values.""" + f0 = compute_f0(wav, SAMPLE_RATE) + np.save(output_fpath, f0, allow_pickle=False) + return f0, len(f0) + +def _compute_spkEmbed(wav, output_fpath, encoder_model_local, device): + Encoder.set_model(encoder_model_local) + # Compute where to split the utterance into partials and pad if necessary + wave_slices, mel_slices = Encoder.compute_partial_slices(len(wav), rate=1.3, min_pad_coverage=0.75) + max_wave_length = wave_slices[-1].stop + if max_wave_length >= len(wav): + wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant") + + # Split the utterance into partials + frames = audio.wav_to_mel_spectrogram(wav) + frames_batch = np.array([frames[s] for s in mel_slices]) + partial_embeds = Encoder.embed_frames_batch(frames_batch) + + # Compute the utterance embedding from the partial embeddings + raw_embed = np.mean(partial_embeds, axis=0) + embed = raw_embed / np.linalg.norm(raw_embed, 2) + + np.save(output_fpath, embed, allow_pickle=False) + return embed, len(embed) + +def preprocess_one(wav_path, out_dir, device, ppg_model_local, encoder_model_local): + # wav = preprocess_wav(wav_path) + # try: + wav, sr = soundfile.read(wav_path) + if len(wav) < sr: + return None, sr, len(wav) + if sr != SAMPLE_RATE: + wav = resampy.resample(wav, sr, SAMPLE_RATE) + sr = SAMPLE_RATE + utt_id = os.path.basename(wav_path).rstrip(".wav") + + _, length_bnf = 
_compute_bnf(output_fpath=f"{out_dir}/bnf/{utt_id}.ling_feat.npy", wav=wav, device=device, ppg_model_local=ppg_model_local) + _, length_f0 = _compute_f0_from_wav(output_fpath=f"{out_dir}/f0/{utt_id}.f0.npy", wav=wav) + _, length_embed = _compute_spkEmbed(output_fpath=f"{out_dir}/embed/{utt_id}.npy", device=device, encoder_model_local=encoder_model_local, wav=wav) + +def preprocess_dataset(datasets_root, dataset, out_dir, n_processes, ppg_encoder_model_fpath, speaker_encoder_model): + # Glob wav files + wav_file_list = sorted(Path(f"{datasets_root}/{dataset}").glob("**/*.wav")) + print(f"Globbed {len(wav_file_list)} wav files.") + + out_dir.joinpath("bnf").mkdir(exist_ok=True, parents=True) + out_dir.joinpath("f0").mkdir(exist_ok=True, parents=True) + out_dir.joinpath("embed").mkdir(exist_ok=True, parents=True) + ppg_model_local = load_model(ppg_encoder_model_fpath, "cpu") + encoder_model_local = Encoder.load_model(speaker_encoder_model, "cpu") + if n_processes is None: + n_processes = cpu_count() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + func = partial(preprocess_one, out_dir=out_dir, ppg_model_local=ppg_model_local, encoder_model_local=encoder_model_local, device=device) + job = Pool(n_processes).imap(func, wav_file_list) + list(tqdm(job, "Preprocessing", len(wav_file_list), unit="wav")) + + # finish processing and mark + t_fid_file = out_dir.joinpath("train_fidlist.txt").open("w", encoding="utf-8") + d_fid_file = out_dir.joinpath("dev_fidlist.txt").open("w", encoding="utf-8") + e_fid_file = out_dir.joinpath("eval_fidlist.txt").open("w", encoding="utf-8") + for file in sorted(out_dir.joinpath("f0").glob("*.npy")): + id = os.path.basename(file).split(".f0.npy")[0] + if id.endswith("01"): + d_fid_file.write(id + "\n") + elif id.endswith("09"): + e_fid_file.write(id + "\n") + else: + t_fid_file.write(id + "\n") + t_fid_file.close() + d_fid_file.close() + e_fid_file.close() diff --git a/ppg2mel/rnn_decoder_mol.py b/ppg2mel/rnn_decoder_mol.py new file mode 100644 index 0000000..9d48d7b --- /dev/null +++ b/ppg2mel/rnn_decoder_mol.py @@ -0,0 +1,374 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from .utils.mol_attention import MOLAttention +from .utils.basic_layers import Linear +from .utils.vc_utils import get_mask_from_lengths + + +class DecoderPrenet(nn.Module): + def __init__(self, in_dim, sizes): + super().__init__() + in_sizes = [in_dim] + sizes[:-1] + self.layers = nn.ModuleList( + [Linear(in_size, out_size, bias=False) + for (in_size, out_size) in zip(in_sizes, sizes)]) + + def forward(self, x): + for linear in self.layers: + x = F.dropout(F.relu(linear(x)), p=0.5, training=True) + return x + + +class Decoder(nn.Module): + """Mixture of Logistic (MoL) attention-based RNN Decoder.""" + def __init__( + self, + enc_dim, + num_mels, + frames_per_step, + attention_rnn_dim, + decoder_rnn_dim, + prenet_dims, + num_mixtures, + encoder_down_factor=1, + num_decoder_rnn_layer=1, + use_stop_tokens=False, + concat_context_to_last=False, + ): + super().__init__() + self.enc_dim = enc_dim + self.encoder_down_factor = encoder_down_factor + self.num_mels = num_mels + self.frames_per_step = frames_per_step + self.attention_rnn_dim = attention_rnn_dim + self.decoder_rnn_dim = decoder_rnn_dim + self.prenet_dims = prenet_dims + self.use_stop_tokens = use_stop_tokens + self.num_decoder_rnn_layer = num_decoder_rnn_layer + self.concat_context_to_last = concat_context_to_last + + # Mel prenet + self.prenet = DecoderPrenet(num_mels, 
prenet_dims) + self.prenet_pitch = DecoderPrenet(num_mels, prenet_dims) + + # Attention RNN + self.attention_rnn = nn.LSTMCell( + prenet_dims[-1] + enc_dim, + attention_rnn_dim + ) + + # Attention + self.attention_layer = MOLAttention( + attention_rnn_dim, + r=frames_per_step/encoder_down_factor, + M=num_mixtures, + ) + + # Decoder RNN + self.decoder_rnn_layers = nn.ModuleList() + for i in range(num_decoder_rnn_layer): + if i == 0: + self.decoder_rnn_layers.append( + nn.LSTMCell( + enc_dim + attention_rnn_dim, + decoder_rnn_dim)) + else: + self.decoder_rnn_layers.append( + nn.LSTMCell( + decoder_rnn_dim, + decoder_rnn_dim)) + # self.decoder_rnn = nn.LSTMCell( + # 2 * enc_dim + attention_rnn_dim, + # decoder_rnn_dim + # ) + if concat_context_to_last: + self.linear_projection = Linear( + enc_dim + decoder_rnn_dim, + num_mels * frames_per_step + ) + else: + self.linear_projection = Linear( + decoder_rnn_dim, + num_mels * frames_per_step + ) + + + # Stop-token layer + if self.use_stop_tokens: + if concat_context_to_last: + self.stop_layer = Linear( + enc_dim + decoder_rnn_dim, 1, bias=True, w_init_gain="sigmoid" + ) + else: + self.stop_layer = Linear( + decoder_rnn_dim, 1, bias=True, w_init_gain="sigmoid" + ) + + + def get_go_frame(self, memory): + B = memory.size(0) + go_frame = torch.zeros((B, self.num_mels), dtype=torch.float, + device=memory.device) + return go_frame + + def initialize_decoder_states(self, memory, mask): + device = next(self.parameters()).device + B = memory.size(0) + + # attention rnn states + self.attention_hidden = torch.zeros( + (B, self.attention_rnn_dim), device=device) + self.attention_cell = torch.zeros( + (B, self.attention_rnn_dim), device=device) + + # decoder rnn states + self.decoder_hiddens = [] + self.decoder_cells = [] + for i in range(self.num_decoder_rnn_layer): + self.decoder_hiddens.append( + torch.zeros((B, self.decoder_rnn_dim), + device=device) + ) + self.decoder_cells.append( + torch.zeros((B, self.decoder_rnn_dim), + device=device) + ) + # self.decoder_hidden = torch.zeros( + # (B, self.decoder_rnn_dim), device=device) + # self.decoder_cell = torch.zeros( + # (B, self.decoder_rnn_dim), device=device) + + self.attention_context = torch.zeros( + (B, self.enc_dim), device=device) + + self.memory = memory + # self.processed_memory = self.attention_layer.memory_layer(memory) + self.mask = mask + + def parse_decoder_inputs(self, decoder_inputs): + """Prepare decoder inputs, i.e. gt mel + Args: + decoder_inputs:(B, T_out, n_mel_channels) inputs used for teacher-forced training. 
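+            Returns:
+                decoder_inputs: (T_out//r, B, num_mels) time-major frames; only the
+                    last mel frame of each group of r frames is kept as the
+                    teacher-forcing input.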
+ """ + decoder_inputs = decoder_inputs.reshape( + decoder_inputs.size(0), + int(decoder_inputs.size(1)/self.frames_per_step), -1) + # (B, T_out//r, r*num_mels) -> (T_out//r, B, r*num_mels) + decoder_inputs = decoder_inputs.transpose(0, 1) + # (T_out//r, B, num_mels) + decoder_inputs = decoder_inputs[:,:,-self.num_mels:] + return decoder_inputs + + def parse_decoder_outputs(self, mel_outputs, alignments, stop_outputs): + """ Prepares decoder outputs for output + Args: + mel_outputs: + alignments: + """ + # (T_out//r, B, T_enc) -> (B, T_out//r, T_enc) + alignments = torch.stack(alignments).transpose(0, 1) + # (T_out//r, B) -> (B, T_out//r) + if stop_outputs is not None: + if alignments.size(0) == 1: + stop_outputs = torch.stack(stop_outputs).unsqueeze(0) + else: + stop_outputs = torch.stack(stop_outputs).transpose(0, 1) + stop_outputs = stop_outputs.contiguous() + # (T_out//r, B, num_mels*r) -> (B, T_out//r, num_mels*r) + mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous() + # decouple frames per step + # (B, T_out, num_mels) + mel_outputs = mel_outputs.view( + mel_outputs.size(0), -1, self.num_mels) + return mel_outputs, alignments, stop_outputs + + def attend(self, decoder_input): + cell_input = torch.cat((decoder_input, self.attention_context), -1) + self.attention_hidden, self.attention_cell = self.attention_rnn( + cell_input, (self.attention_hidden, self.attention_cell)) + self.attention_context, attention_weights = self.attention_layer( + self.attention_hidden, self.memory, None, self.mask) + + decoder_rnn_input = torch.cat( + (self.attention_hidden, self.attention_context), -1) + + return decoder_rnn_input, self.attention_context, attention_weights + + def decode(self, decoder_input): + for i in range(self.num_decoder_rnn_layer): + if i == 0: + self.decoder_hiddens[i], self.decoder_cells[i] = self.decoder_rnn_layers[i]( + decoder_input, (self.decoder_hiddens[i], self.decoder_cells[i])) + else: + self.decoder_hiddens[i], self.decoder_cells[i] = self.decoder_rnn_layers[i]( + self.decoder_hiddens[i-1], (self.decoder_hiddens[i], self.decoder_cells[i])) + return self.decoder_hiddens[-1] + + def forward(self, memory, mel_inputs, memory_lengths): + """ Decoder forward pass for training + Args: + memory: (B, T_enc, enc_dim) Encoder outputs + decoder_inputs: (B, T, num_mels) Decoder inputs for teacher forcing. + memory_lengths: (B, ) Encoder output lengths for attention masking. + Returns: + mel_outputs: (B, T, num_mels) mel outputs from the decoder + alignments: (B, T//r, T_enc) attention weights. 
+ """ + # [1, B, num_mels] + go_frame = self.get_go_frame(memory).unsqueeze(0) + # [T//r, B, num_mels] + mel_inputs = self.parse_decoder_inputs(mel_inputs) + # [T//r + 1, B, num_mels] + mel_inputs = torch.cat((go_frame, mel_inputs), dim=0) + # [T//r + 1, B, prenet_dim] + decoder_inputs = self.prenet(mel_inputs) + # decoder_inputs_pitch = self.prenet_pitch(decoder_inputs__) + + self.initialize_decoder_states( + memory, mask=~get_mask_from_lengths(memory_lengths), + ) + + self.attention_layer.init_states(memory) + # self.attention_layer_pitch.init_states(memory_pitch) + + mel_outputs, alignments = [], [] + if self.use_stop_tokens: + stop_outputs = [] + else: + stop_outputs = None + while len(mel_outputs) < decoder_inputs.size(0) - 1: + decoder_input = decoder_inputs[len(mel_outputs)] + # decoder_input_pitch = decoder_inputs_pitch[len(mel_outputs)] + + decoder_rnn_input, context, attention_weights = self.attend(decoder_input) + + decoder_rnn_output = self.decode(decoder_rnn_input) + if self.concat_context_to_last: + decoder_rnn_output = torch.cat( + (decoder_rnn_output, context), dim=1) + + mel_output = self.linear_projection(decoder_rnn_output) + if self.use_stop_tokens: + stop_output = self.stop_layer(decoder_rnn_output) + stop_outputs += [stop_output.squeeze()] + mel_outputs += [mel_output.squeeze(1)] #? perhaps don't need squeeze + alignments += [attention_weights] + # alignments_pitch += [attention_weights_pitch] + + mel_outputs, alignments, stop_outputs = self.parse_decoder_outputs( + mel_outputs, alignments, stop_outputs) + if stop_outputs is None: + return mel_outputs, alignments + else: + return mel_outputs, stop_outputs, alignments + + def inference(self, memory, stop_threshold=0.5): + """ Decoder inference + Args: + memory: (1, T_enc, D_enc) Encoder outputs + Returns: + mel_outputs: mel outputs from the decoder + alignments: sequence of attention weights from the decoder + """ + # [1, num_mels] + decoder_input = self.get_go_frame(memory) + + self.initialize_decoder_states(memory, mask=None) + + self.attention_layer.init_states(memory) + + mel_outputs, alignments = [], [] + # NOTE(sx): heuristic + max_decoder_step = memory.size(1)*self.encoder_down_factor//self.frames_per_step + min_decoder_step = memory.size(1)*self.encoder_down_factor // self.frames_per_step - 5 + while True: + decoder_input = self.prenet(decoder_input) + + decoder_input_final, context, alignment = self.attend(decoder_input) + + #mel_output, stop_output, alignment = self.decode(decoder_input) + decoder_rnn_output = self.decode(decoder_input_final) + if self.concat_context_to_last: + decoder_rnn_output = torch.cat( + (decoder_rnn_output, context), dim=1) + + mel_output = self.linear_projection(decoder_rnn_output) + stop_output = self.stop_layer(decoder_rnn_output) + + mel_outputs += [mel_output.squeeze(1)] + alignments += [alignment] + + if torch.sigmoid(stop_output.data) > stop_threshold and len(mel_outputs) >= min_decoder_step: + break + if len(mel_outputs) >= max_decoder_step: + # print("Warning! 
Decoding steps reaches max decoder steps.") + break + + decoder_input = mel_output[:,-self.num_mels:] + + + mel_outputs, alignments, _ = self.parse_decoder_outputs( + mel_outputs, alignments, None) + + return mel_outputs, alignments + + def inference_batched(self, memory, stop_threshold=0.5): + """ Decoder inference + Args: + memory: (B, T_enc, D_enc) Encoder outputs + Returns: + mel_outputs: mel outputs from the decoder + alignments: sequence of attention weights from the decoder + """ + # [1, num_mels] + decoder_input = self.get_go_frame(memory) + + self.initialize_decoder_states(memory, mask=None) + + self.attention_layer.init_states(memory) + + mel_outputs, alignments = [], [] + stop_outputs = [] + # NOTE(sx): heuristic + max_decoder_step = memory.size(1)*self.encoder_down_factor//self.frames_per_step + min_decoder_step = memory.size(1)*self.encoder_down_factor // self.frames_per_step - 5 + while True: + decoder_input = self.prenet(decoder_input) + + decoder_input_final, context, alignment = self.attend(decoder_input) + + #mel_output, stop_output, alignment = self.decode(decoder_input) + decoder_rnn_output = self.decode(decoder_input_final) + if self.concat_context_to_last: + decoder_rnn_output = torch.cat( + (decoder_rnn_output, context), dim=1) + + mel_output = self.linear_projection(decoder_rnn_output) + # (B, 1) + stop_output = self.stop_layer(decoder_rnn_output) + stop_outputs += [stop_output.squeeze()] + # stop_outputs.append(stop_output) + + mel_outputs += [mel_output.squeeze(1)] + alignments += [alignment] + # print(stop_output.shape) + if torch.all(torch.sigmoid(stop_output.squeeze().data) > stop_threshold) \ + and len(mel_outputs) >= min_decoder_step: + break + if len(mel_outputs) >= max_decoder_step: + # print("Warning! Decoding steps reaches max decoder steps.") + break + + decoder_input = mel_output[:,-self.num_mels:] + + + mel_outputs, alignments, stop_outputs = self.parse_decoder_outputs( + mel_outputs, alignments, stop_outputs) + mel_outputs_stacked = [] + for mel, stop_logit in zip(mel_outputs, stop_outputs): + idx = np.argwhere(torch.sigmoid(stop_logit.cpu()) > stop_threshold)[0][0].item() + mel_outputs_stacked.append(mel[:idx,:]) + mel_outputs = torch.cat(mel_outputs_stacked, dim=0).unsqueeze(0) + return mel_outputs, alignments diff --git a/ppg2mel/train.py b/ppg2mel/train.py new file mode 100644 index 0000000..fed7501 --- /dev/null +++ b/ppg2mel/train.py @@ -0,0 +1,67 @@ +import sys +import torch +import argparse +import numpy as np +from utils.load_yaml import HpsYaml +from ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver + +# For reproducibility, comment these may speed up training +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False + +def main(): + # Arguments + parser = argparse.ArgumentParser(description= + 'Training PPG2Mel VC model.') + parser.add_argument('--config', type=str, + help='Path to experiment config, e.g., config/vc.yaml') + parser.add_argument('--name', default=None, type=str, help='Name for logging.') + parser.add_argument('--logdir', default='log/', type=str, + help='Logging path.', required=False) + parser.add_argument('--ckpdir', default='ckpt/', type=str, + help='Checkpoint path.', required=False) + parser.add_argument('--outdir', default='result/', type=str, + help='Decode output path.', required=False) + parser.add_argument('--load', default=None, type=str, + help='Load pre-trained model (for training only)', required=False) + parser.add_argument('--warm_start', action='store_true', + 
help='Load model weights only, ignore specified layers.') + parser.add_argument('--seed', default=0, type=int, + help='Random seed for reproducable results.', required=False) + parser.add_argument('--njobs', default=8, type=int, + help='Number of threads for dataloader/decoding.', required=False) + parser.add_argument('--cpu', action='store_true', help='Disable GPU training.') + parser.add_argument('--no-pin', action='store_true', + help='Disable pin-memory for dataloader') + parser.add_argument('--test', action='store_true', help='Test the model.') + parser.add_argument('--no-msg', action='store_true', help='Hide all messages.') + parser.add_argument('--finetune', action='store_true', help='Finetune model') + parser.add_argument('--oneshotvc', action='store_true', help='Oneshot VC model') + parser.add_argument('--bilstm', action='store_true', help='BiLSTM VC model') + parser.add_argument('--lsa', action='store_true', help='Use location-sensitive attention (LSA)') + + ### + + paras = parser.parse_args() + setattr(paras, 'gpu', not paras.cpu) + setattr(paras, 'pin_memory', not paras.no_pin) + setattr(paras, 'verbose', not paras.no_msg) + # Make the config dict dot visitable + config = HpsYaml(paras.config) + + np.random.seed(paras.seed) + torch.manual_seed(paras.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(paras.seed) + + print(">>> OneShot VC training ...") + mode = "train" + solver = Solver(config, paras, mode) + solver.load_data() + solver.set_model() + solver.exec() + print(">>> Oneshot VC train finished!") + sys.exit(0) + +if __name__ == "__main__": + main() diff --git a/ppg2mel/train/__init__.py b/ppg2mel/train/__init__.py new file mode 100644 index 0000000..4287ca8 --- /dev/null +++ b/ppg2mel/train/__init__.py @@ -0,0 +1 @@ +# \ No newline at end of file diff --git a/ppg2mel/train/loss.py b/ppg2mel/train/loss.py new file mode 100644 index 0000000..301248c --- /dev/null +++ b/ppg2mel/train/loss.py @@ -0,0 +1,50 @@ +from typing import Dict +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..utils.nets_utils import make_pad_mask + + +class MaskedMSELoss(nn.Module): + def __init__(self, frames_per_step): + super().__init__() + self.frames_per_step = frames_per_step + self.mel_loss_criterion = nn.MSELoss(reduction='none') + # self.loss = nn.MSELoss() + self.stop_loss_criterion = nn.BCEWithLogitsLoss(reduction='none') + + def get_mask(self, lengths, max_len=None): + # lengths: [B,] + if max_len is None: + max_len = torch.max(lengths) + batch_size = lengths.size(0) + seq_range = torch.arange(0, max_len).long() + seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len).to(lengths.device) + seq_length_expand = lengths.unsqueeze(1).expand_as(seq_range_expand) + return (seq_range_expand < seq_length_expand).float() + + def forward(self, mel_pred, mel_pred_postnet, mel_trg, lengths, + stop_target, stop_pred): + ## process stop_target + B = stop_target.size(0) + stop_target = stop_target.reshape(B, -1, self.frames_per_step)[:, :, 0] + stop_lengths = torch.ceil(lengths.float() / self.frames_per_step).long() + stop_mask = self.get_mask(stop_lengths, int(mel_trg.size(1)/self.frames_per_step)) + + mel_trg.requires_grad = False + # (B, T, 1) + mel_mask = self.get_mask(lengths, mel_trg.size(1)).unsqueeze(-1) + # (B, T, D) + mel_mask = mel_mask.expand_as(mel_trg) + mel_loss_pre = (self.mel_loss_criterion(mel_pred, mel_trg) * mel_mask).sum() / mel_mask.sum() + mel_loss_post = 
(self.mel_loss_criterion(mel_pred_postnet, mel_trg) * mel_mask).sum() / mel_mask.sum() + + mel_loss = mel_loss_pre + mel_loss_post + + # stop token loss + stop_loss = torch.sum(self.stop_loss_criterion(stop_pred, stop_target) * stop_mask) / stop_mask.sum() + + return mel_loss, stop_loss diff --git a/ppg2mel/train/optim.py b/ppg2mel/train/optim.py new file mode 100644 index 0000000..62533b9 --- /dev/null +++ b/ppg2mel/train/optim.py @@ -0,0 +1,45 @@ +import torch +import numpy as np + + +class Optimizer(): + def __init__(self, parameters, optimizer, lr, eps, lr_scheduler, + **kwargs): + + # Setup torch optimizer + self.opt_type = optimizer + self.init_lr = lr + self.sch_type = lr_scheduler + opt = getattr(torch.optim, optimizer) + if lr_scheduler == 'warmup': + warmup_step = 4000.0 + init_lr = lr + self.lr_scheduler = lambda step: init_lr * warmup_step ** 0.5 * \ + np.minimum((step+1)*warmup_step**-1.5, (step+1)**-0.5) + self.opt = opt(parameters, lr=1.0) + else: + self.lr_scheduler = None + self.opt = opt(parameters, lr=lr, eps=eps) # ToDo: 1e-8 better? + + def get_opt_state_dict(self): + return self.opt.state_dict() + + def load_opt_state_dict(self, state_dict): + self.opt.load_state_dict(state_dict) + + def pre_step(self, step): + if self.lr_scheduler is not None: + cur_lr = self.lr_scheduler(step) + for param_group in self.opt.param_groups: + param_group['lr'] = cur_lr + else: + cur_lr = self.init_lr + self.opt.zero_grad() + return cur_lr + + def step(self): + self.opt.step() + + def create_msg(self): + return ['Optim.Info.| Algo. = {}\t| Lr = {}\t (schedule = {})' + .format(self.opt_type, self.init_lr, self.sch_type)] diff --git a/ppg2mel/train/option.py b/ppg2mel/train/option.py new file mode 100644 index 0000000..f66c600 --- /dev/null +++ b/ppg2mel/train/option.py @@ -0,0 +1,10 @@ +# Default parameters which will be imported by solver +default_hparas = { + 'GRAD_CLIP': 5.0, # Grad. clip threshold + 'PROGRESS_STEP': 100, # Std. output refresh freq. 
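+    # NOTE: BaseSolver (ppg2mel/train/solver.py) copies every entry of this dict onto
+    # the solver instance via setattr(self, k, v), so e.g. self.GRAD_CLIP and
+    # self.PROGRESS_STEP are read from here.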
+ # Decode steps for objective validation (step = ratio*input_txt_len) + 'DEV_STEP_RATIO': 1.2, + # Number of examples (alignment/text) to show in tensorboard + 'DEV_N_EXAMPLE': 4, + 'TB_FLUSH_FREQ': 180 # Update frequency of tensorboard (secs) +} diff --git a/ppg2mel/train/solver.py b/ppg2mel/train/solver.py new file mode 100644 index 0000000..264a91c --- /dev/null +++ b/ppg2mel/train/solver.py @@ -0,0 +1,216 @@ +import os +import sys +import abc +import math +import yaml +import torch +from torch.utils.tensorboard import SummaryWriter + +from .option import default_hparas +from utils.util import human_format, Timer +from utils.load_yaml import HpsYaml + + +class BaseSolver(): + ''' + Prototype Solver for all kinds of tasks + Arguments + config - yaml-styled config + paras - argparse outcome + mode - "train"/"test" + ''' + + def __init__(self, config, paras, mode="train"): + # General Settings + self.config = config # load from yaml file + self.paras = paras # command line args + self.mode = mode # 'train' or 'test' + for k, v in default_hparas.items(): + setattr(self, k, v) + self.device = torch.device('cuda') if self.paras.gpu and torch.cuda.is_available() \ + else torch.device('cpu') + + # Name experiment + self.exp_name = paras.name + if self.exp_name is None: + if 'exp_name' in self.config: + self.exp_name = self.config.exp_name + else: + # By default, exp is named after config file + self.exp_name = paras.config.split('/')[-1].replace('.yaml', '') + if mode == 'train': + self.exp_name += '_seed{}'.format(paras.seed) + + + if mode == 'train': + # Filepath setup + os.makedirs(paras.ckpdir, exist_ok=True) + self.ckpdir = os.path.join(paras.ckpdir, self.exp_name) + os.makedirs(self.ckpdir, exist_ok=True) + + # Logger settings + self.logdir = os.path.join(paras.logdir, self.exp_name) + self.log = SummaryWriter( + self.logdir, flush_secs=self.TB_FLUSH_FREQ) + self.timer = Timer() + + # Hyper-parameters + self.step = 0 + self.valid_step = config.hparas.valid_step + self.max_step = config.hparas.max_step + + self.verbose('Exp. name : {}'.format(self.exp_name)) + self.verbose('Loading data... large corpus may took a while.') + + # elif mode == 'test': + # # Output path + # os.makedirs(paras.outdir, exist_ok=True) + # self.ckpdir = os.path.join(paras.outdir, self.exp_name) + + # Load training config to get acoustic feat and build model + # self.src_config = HpsYaml(config.src.config) + # self.paras.load = config.src.ckpt + + # self.verbose('Evaluating result of tr. 
config @ {}'.format( + # config.src.config)) + + def backward(self, loss): + ''' + Standard backward step with self.timer and debugger + Arguments + loss - the loss to perform loss.backward() + ''' + self.timer.set() + loss.backward() + grad_norm = torch.nn.utils.clip_grad_norm_( + self.model.parameters(), self.GRAD_CLIP) + if math.isnan(grad_norm): + self.verbose('Error : grad norm is NaN @ step '+str(self.step)) + else: + self.optimizer.step() + self.timer.cnt('bw') + return grad_norm + + def load_ckpt(self): + ''' Load ckpt if --load option is specified ''' + if self.paras.load is not None: + if self.paras.warm_start: + self.verbose(f"Warm starting model from checkpoint {self.paras.load}.") + ckpt = torch.load( + self.paras.load, map_location=self.device if self.mode == 'train' + else 'cpu') + model_dict = ckpt['model'] + if len(self.config.model.ignore_layers) > 0: + model_dict = {k:v for k, v in model_dict.items() + if k not in self.config.model.ignore_layers} + dummy_dict = self.model.state_dict() + dummy_dict.update(model_dict) + model_dict = dummy_dict + self.model.load_state_dict(model_dict) + else: + # Load weights + ckpt = torch.load( + self.paras.load, map_location=self.device if self.mode == 'train' + else 'cpu') + self.model.load_state_dict(ckpt['model']) + + # Load task-dependent items + if self.mode == 'train': + self.step = ckpt['global_step'] + self.optimizer.load_opt_state_dict(ckpt['optimizer']) + self.verbose('Load ckpt from {}, restarting at step {}'.format( + self.paras.load, self.step)) + else: + for k, v in ckpt.items(): + if type(v) is float: + metric, score = k, v + self.model.eval() + self.verbose('Evaluation target = {} (recorded {} = {:.2f} %)'.format( + self.paras.load, metric, score)) + + def verbose(self, msg): + ''' Verbose function for print information to stdout''' + if self.paras.verbose: + if type(msg) == list: + for m in msg: + print('[INFO]', m.ljust(100)) + else: + print('[INFO]', msg.ljust(100)) + + def progress(self, msg): + ''' Verbose function for updating progress on stdout (do not include newline) ''' + if self.paras.verbose: + sys.stdout.write("\033[K") # Clear line + print('[{}] {}'.format(human_format(self.step), msg), end='\r') + + def write_log(self, log_name, log_dict): + ''' + Write log to TensorBoard + log_name - Name of tensorboard variable + log_value - / Value of variable (e.g. dict of losses), passed if value = None + ''' + if type(log_dict) is dict: + log_dict = {key: val for key, val in log_dict.items() if ( + val is not None and not math.isnan(val))} + if log_dict is None: + pass + elif len(log_dict) > 0: + if 'align' in log_name or 'spec' in log_name: + img, form = log_dict + self.log.add_image( + log_name, img, global_step=self.step, dataformats=form) + elif 'text' in log_name or 'hyp' in log_name: + self.log.add_text(log_name, log_dict, self.step) + else: + self.log.add_scalars(log_name, log_dict, self.step) + + def save_checkpoint(self, f_name, metric, score, show_msg=True): + '''' + Ckpt saver + f_name - the name of ckpt file (w/o prefix) to store, overwrite if existed + score - The value of metric used to evaluate model + ''' + ckpt_path = os.path.join(self.ckpdir, f_name) + full_dict = { + "model": self.model.state_dict(), + "optimizer": self.optimizer.get_opt_state_dict(), + "global_step": self.step, + metric: score + } + + torch.save(full_dict, ckpt_path) + if show_msg: + self.verbose("Saved checkpoint (step = {}, {} = {:.2f}) and status @ {}". 
+ format(human_format(self.step), metric, score, ckpt_path)) + + + # ----------------------------------- Abtract Methods ------------------------------------------ # + @abc.abstractmethod + def load_data(self): + ''' + Called by main to load all data + After this call, data related attributes should be setup (e.g. self.tr_set, self.dev_set) + No return value + ''' + raise NotImplementedError + + @abc.abstractmethod + def set_model(self): + ''' + Called by main to set models + After this call, model related attributes should be setup (e.g. self.l2_loss) + The followings MUST be setup + - self.model (torch.nn.Module) + - self.optimizer (src.Optimizer), + init. w/ self.optimizer = src.Optimizer(self.model.parameters(),**self.config['hparas']) + Loading pre-trained model should also be performed here + No return value + ''' + raise NotImplementedError + + @abc.abstractmethod + def exec(self): + ''' + Called by main to execute training/inference + ''' + raise NotImplementedError diff --git a/ppg2mel/train/train_linglf02mel_seq2seq_oneshotvc.py b/ppg2mel/train/train_linglf02mel_seq2seq_oneshotvc.py new file mode 100644 index 0000000..daf1c6a --- /dev/null +++ b/ppg2mel/train/train_linglf02mel_seq2seq_oneshotvc.py @@ -0,0 +1,288 @@ +import os, sys +# sys.path.append('/home/shaunxliu/projects/nnsp') +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +from matplotlib.ticker import MaxNLocator +import torch +from torch.utils.data import DataLoader +import numpy as np +from .solver import BaseSolver +from utils.data_load import OneshotVcDataset, MultiSpkVcCollate +# from src.rnn_ppg2mel import BiRnnPpg2MelModel +# from src.mel_decoder_mol_encAddlf0 import MelDecoderMOL +from .loss import MaskedMSELoss +from .optim import Optimizer +from utils.util import human_format +from ppg2mel import MelDecoderMOLv2 + + +class Solver(BaseSolver): + """Customized Solver.""" + def __init__(self, config, paras, mode): + super().__init__(config, paras, mode) + self.num_att_plots = 5 + self.att_ws_dir = f"{self.logdir}/att_ws" + os.makedirs(self.att_ws_dir, exist_ok=True) + self.best_loss = np.inf + + def fetch_data(self, data): + """Move data to device""" + data = [i.to(self.device) for i in data] + return data + + def load_data(self): + """ Load data for training/validation/plotting.""" + train_dataset = OneshotVcDataset( + meta_file=self.config.data.train_fid_list, + vctk_ppg_dir=self.config.data.vctk_ppg_dir, + libri_ppg_dir=self.config.data.libri_ppg_dir, + vctk_f0_dir=self.config.data.vctk_f0_dir, + libri_f0_dir=self.config.data.libri_f0_dir, + vctk_wav_dir=self.config.data.vctk_wav_dir, + libri_wav_dir=self.config.data.libri_wav_dir, + vctk_spk_dvec_dir=self.config.data.vctk_spk_dvec_dir, + libri_spk_dvec_dir=self.config.data.libri_spk_dvec_dir, + ppg_file_ext=self.config.data.ppg_file_ext, + min_max_norm_mel=self.config.data.min_max_norm_mel, + mel_min=self.config.data.mel_min, + mel_max=self.config.data.mel_max, + ) + dev_dataset = OneshotVcDataset( + meta_file=self.config.data.dev_fid_list, + vctk_ppg_dir=self.config.data.vctk_ppg_dir, + libri_ppg_dir=self.config.data.libri_ppg_dir, + vctk_f0_dir=self.config.data.vctk_f0_dir, + libri_f0_dir=self.config.data.libri_f0_dir, + vctk_wav_dir=self.config.data.vctk_wav_dir, + libri_wav_dir=self.config.data.libri_wav_dir, + vctk_spk_dvec_dir=self.config.data.vctk_spk_dvec_dir, + libri_spk_dvec_dir=self.config.data.libri_spk_dvec_dir, + ppg_file_ext=self.config.data.ppg_file_ext, + min_max_norm_mel=self.config.data.min_max_norm_mel, + 
mel_min=self.config.data.mel_min, + mel_max=self.config.data.mel_max, + ) + self.train_dataloader = DataLoader( + train_dataset, + num_workers=self.paras.njobs, + shuffle=True, + batch_size=self.config.hparas.batch_size, + pin_memory=False, + drop_last=True, + collate_fn=MultiSpkVcCollate(self.config.model.frames_per_step, + use_spk_dvec=True), + ) + self.dev_dataloader = DataLoader( + dev_dataset, + num_workers=self.paras.njobs, + shuffle=False, + batch_size=self.config.hparas.batch_size, + pin_memory=False, + drop_last=False, + collate_fn=MultiSpkVcCollate(self.config.model.frames_per_step, + use_spk_dvec=True), + ) + self.plot_dataloader = DataLoader( + dev_dataset, + num_workers=self.paras.njobs, + shuffle=False, + batch_size=1, + pin_memory=False, + drop_last=False, + collate_fn=MultiSpkVcCollate(self.config.model.frames_per_step, + use_spk_dvec=True, + give_uttids=True), + ) + msg = "Have prepared training set and dev set." + self.verbose(msg) + + def load_pretrained_params(self): + print("Load pretrained model from: ", self.config.data.pretrain_model_file) + ignore_layer_prefixes = ["speaker_embedding_table"] + pretrain_model_file = self.config.data.pretrain_model_file + pretrain_ckpt = torch.load( + pretrain_model_file, map_location=self.device + )["model"] + model_dict = self.model.state_dict() + print(self.model) + + # 1. filter out unnecessrary keys + for prefix in ignore_layer_prefixes: + pretrain_ckpt = {k : v + for k, v in pretrain_ckpt.items() if not k.startswith(prefix) + } + # 2. overwrite entries in the existing state dict + model_dict.update(pretrain_ckpt) + + # 3. load the new state dict + self.model.load_state_dict(model_dict) + + def set_model(self): + """Setup model and optimizer""" + # Model + print("[INFO] Model name: ", self.config["model_name"]) + self.model = MelDecoderMOLv2( + **self.config["model"] + ).to(self.device) + # self.load_pretrained_params() + + # model_params = [{'params': self.model.spk_embedding.weight}] + model_params = [{'params': self.model.parameters()}] + + # Loss criterion + self.loss_criterion = MaskedMSELoss(self.config.model.frames_per_step) + + # Optimizer + self.optimizer = Optimizer(model_params, **self.config["hparas"]) + self.verbose(self.optimizer.create_msg()) + + # Automatically load pre-trained model if self.paras.load is given + self.load_ckpt() + + def exec(self): + self.verbose("Total training steps {}.".format( + human_format(self.max_step))) + + mel_loss = None + n_epochs = 0 + # Set as current time + self.timer.set() + + while self.step < self.max_step: + for data in self.train_dataloader: + # Pre-step: updata lr_rate and do zero_grad + lr_rate = self.optimizer.pre_step(self.step) + total_loss = 0 + # data to device + ppgs, lf0_uvs, mels, in_lengths, \ + out_lengths, spk_ids, stop_tokens = self.fetch_data(data) + self.timer.cnt("rd") + mel_outputs, mel_outputs_postnet, predicted_stop = self.model( + ppgs, + in_lengths, + mels, + out_lengths, + lf0_uvs, + spk_ids + ) + mel_loss, stop_loss = self.loss_criterion( + mel_outputs, + mel_outputs_postnet, + mels, + out_lengths, + stop_tokens, + predicted_stop + ) + loss = mel_loss + stop_loss + + self.timer.cnt("fw") + + # Back-prop + grad_norm = self.backward(loss) + self.step += 1 + + # Logger + if (self.step == 1) or (self.step % self.PROGRESS_STEP == 0): + self.progress("Tr|loss:{:.4f},mel-loss:{:.4f},stop-loss:{:.4f}|Grad.Norm-{:.2f}|{}" + .format(loss.cpu().item(), mel_loss.cpu().item(), + stop_loss.cpu().item(), grad_norm, self.timer.show())) + self.write_log('loss', 
{'tr/loss': loss, + 'tr/mel-loss': mel_loss, + 'tr/stop-loss': stop_loss}) + + # Validation + if (self.step == 1) or (self.step % self.valid_step == 0): + self.validate() + + # End of step + # https://github.com/pytorch/pytorch/issues/13246#issuecomment-529185354 + torch.cuda.empty_cache() + self.timer.set() + if self.step > self.max_step: + break + n_epochs += 1 + self.log.close() + + def validate(self): + self.model.eval() + dev_loss, dev_mel_loss, dev_stop_loss = 0.0, 0.0, 0.0 + + for i, data in enumerate(self.dev_dataloader): + self.progress('Valid step - {}/{}'.format(i+1, len(self.dev_dataloader))) + # Fetch data + ppgs, lf0_uvs, mels, in_lengths, \ + out_lengths, spk_ids, stop_tokens = self.fetch_data(data) + with torch.no_grad(): + mel_outputs, mel_outputs_postnet, predicted_stop = self.model( + ppgs, + in_lengths, + mels, + out_lengths, + lf0_uvs, + spk_ids + ) + mel_loss, stop_loss = self.loss_criterion( + mel_outputs, + mel_outputs_postnet, + mels, + out_lengths, + stop_tokens, + predicted_stop + ) + loss = mel_loss + stop_loss + + dev_loss += loss.cpu().item() + dev_mel_loss += mel_loss.cpu().item() + dev_stop_loss += stop_loss.cpu().item() + + dev_loss = dev_loss / (i + 1) + dev_mel_loss = dev_mel_loss / (i + 1) + dev_stop_loss = dev_stop_loss / (i + 1) + self.save_checkpoint(f'step_{self.step}.pth', 'loss', dev_loss, show_msg=False) + if dev_loss < self.best_loss: + self.best_loss = dev_loss + self.save_checkpoint(f'best_loss_step_{self.step}.pth', 'loss', dev_loss) + self.write_log('loss', {'dv/loss': dev_loss, + 'dv/mel-loss': dev_mel_loss, + 'dv/stop-loss': dev_stop_loss}) + + # plot attention + for i, data in enumerate(self.plot_dataloader): + if i == self.num_att_plots: + break + # Fetch data + ppgs, lf0_uvs, mels, in_lengths, \ + out_lengths, spk_ids, stop_tokens = self.fetch_data(data[:-1]) + fid = data[-1][0] + with torch.no_grad(): + _, _, _, att_ws = self.model( + ppgs, + in_lengths, + mels, + out_lengths, + lf0_uvs, + spk_ids, + output_att_ws=True + ) + att_ws = att_ws.squeeze(0).cpu().numpy() + att_ws = att_ws[None] + w, h = plt.figaspect(1.0 / len(att_ws)) + fig = plt.Figure(figsize=(w * 1.3, h * 1.3)) + axes = fig.subplots(1, len(att_ws)) + if len(att_ws) == 1: + axes = [axes] + + for ax, aw in zip(axes, att_ws): + ax.imshow(aw.astype(np.float32), aspect="auto") + ax.set_title(f"{fid}") + ax.set_xlabel("Input") + ax.set_ylabel("Output") + ax.xaxis.set_major_locator(MaxNLocator(integer=True)) + ax.yaxis.set_major_locator(MaxNLocator(integer=True)) + fig_name = f"{self.att_ws_dir}/{fid}_step{self.step}.png" + fig.savefig(fig_name) + + # Resume training + self.model.train() + diff --git a/ppg2mel/utils/abs_model.py b/ppg2mel/utils/abs_model.py new file mode 100644 index 0000000..b6d27a6 --- /dev/null +++ b/ppg2mel/utils/abs_model.py @@ -0,0 +1,23 @@ +from abc import ABC +from abc import abstractmethod + +import torch + +class AbsMelDecoder(torch.nn.Module, ABC): + """The abstract PPG-based voice conversion class + This "model" is one of mediator objects for "Task" class. 
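+    Concrete implementations (here, MelDecoderMOLv2 in ppg2mel/__init__.py) map PPG
+    bottle-neck features, log-F0/U-V contours and speaker embeddings to mel spectrograms.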
+ + """ + + @abstractmethod + def forward( + self, + bottle_neck_features: torch.Tensor, + feature_lengths: torch.Tensor, + speech: torch.Tensor, + speech_lengths: torch.Tensor, + logf0_uv: torch.Tensor = None, + spembs: torch.Tensor = None, + styleembs: torch.Tensor = None, + ) -> torch.Tensor: + raise NotImplementedError diff --git a/ppg2mel/utils/basic_layers.py b/ppg2mel/utils/basic_layers.py new file mode 100644 index 0000000..45d80f1 --- /dev/null +++ b/ppg2mel/utils/basic_layers.py @@ -0,0 +1,79 @@ +import torch +from torch import nn +from torch.nn import functional as F +from torch.autograd import Function + +def tile(x, count, dim=0): + """ + Tiles x on dimension dim count times. + """ + perm = list(range(len(x.size()))) + if dim != 0: + perm[0], perm[dim] = perm[dim], perm[0] + x = x.permute(perm).contiguous() + out_size = list(x.size()) + out_size[0] *= count + batch = x.size(0) + x = x.view(batch, -1) \ + .transpose(0, 1) \ + .repeat(count, 1) \ + .transpose(0, 1) \ + .contiguous() \ + .view(*out_size) + if dim != 0: + x = x.permute(perm).contiguous() + return x + +class Linear(torch.nn.Module): + def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'): + super(Linear, self).__init__() + self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias) + + torch.nn.init.xavier_uniform_( + self.linear_layer.weight, + gain=torch.nn.init.calculate_gain(w_init_gain)) + + def forward(self, x): + return self.linear_layer(x) + +class Conv1d(torch.nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, + padding=None, dilation=1, bias=True, w_init_gain='linear', param=None): + super(Conv1d, self).__init__() + if padding is None: + assert(kernel_size % 2 == 1) + padding = int(dilation * (kernel_size - 1)/2) + + self.conv = torch.nn.Conv1d(in_channels, out_channels, + kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, + bias=bias) + torch.nn.init.xavier_uniform_( + self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param)) + + def forward(self, x): + # x: BxDxT + return self.conv(x) + + + +def tile(x, count, dim=0): + """ + Tiles x on dimension dim count times. 
+ """ + perm = list(range(len(x.size()))) + if dim != 0: + perm[0], perm[dim] = perm[dim], perm[0] + x = x.permute(perm).contiguous() + out_size = list(x.size()) + out_size[0] *= count + batch = x.size(0) + x = x.view(batch, -1) \ + .transpose(0, 1) \ + .repeat(count, 1) \ + .transpose(0, 1) \ + .contiguous() \ + .view(*out_size) + if dim != 0: + x = x.permute(perm).contiguous() + return x diff --git a/ppg2mel/utils/cnn_postnet.py b/ppg2mel/utils/cnn_postnet.py new file mode 100644 index 0000000..1980cdd --- /dev/null +++ b/ppg2mel/utils/cnn_postnet.py @@ -0,0 +1,52 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from .basic_layers import Linear, Conv1d + + +class Postnet(nn.Module): + """Postnet + - Five 1-d convolution with 512 channels and kernel size 5 + """ + def __init__(self, num_mels=80, + num_layers=5, + hidden_dim=512, + kernel_size=5): + super(Postnet, self).__init__() + self.convolutions = nn.ModuleList() + + self.convolutions.append( + nn.Sequential( + Conv1d( + num_mels, hidden_dim, + kernel_size=kernel_size, stride=1, + padding=int((kernel_size - 1) / 2), + dilation=1, w_init_gain='tanh'), + nn.BatchNorm1d(hidden_dim))) + + for i in range(1, num_layers - 1): + self.convolutions.append( + nn.Sequential( + Conv1d( + hidden_dim, + hidden_dim, + kernel_size=kernel_size, stride=1, + padding=int((kernel_size - 1) / 2), + dilation=1, w_init_gain='tanh'), + nn.BatchNorm1d(hidden_dim))) + + self.convolutions.append( + nn.Sequential( + Conv1d( + hidden_dim, num_mels, + kernel_size=kernel_size, stride=1, + padding=int((kernel_size - 1) / 2), + dilation=1, w_init_gain='linear'), + nn.BatchNorm1d(num_mels))) + + def forward(self, x): + # x: (B, num_mels, T_dec) + for i in range(len(self.convolutions) - 1): + x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training) + x = F.dropout(self.convolutions[-1](x), 0.5, self.training) + return x diff --git a/ppg2mel/utils/mol_attention.py b/ppg2mel/utils/mol_attention.py new file mode 100644 index 0000000..8aa91f8 --- /dev/null +++ b/ppg2mel/utils/mol_attention.py @@ -0,0 +1,123 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class MOLAttention(nn.Module): + """ Discretized Mixture of Logistic (MOL) attention. + C.f. Section 5 of "MelNet: A Generative Model for Audio in the Frequency Domain" and + GMMv2b model in "Location-relative attention mechanisms for robust long-form speech synthesis". + """ + def __init__( + self, + query_dim, + r=1, + M=5, + ): + """ + Args: + query_dim: attention_rnn_dim. + M: number of mixtures. 
+ """ + super().__init__() + if r < 1: + self.r = float(r) + else: + self.r = int(r) + self.M = M + self.score_mask_value = 0.0 # -float("inf") + self.eps = 1e-5 + # Position arrary for encoder time steps + self.J = None + # Query layer: [w, sigma,] + self.query_layer = torch.nn.Sequential( + nn.Linear(query_dim, 256, bias=True), + nn.ReLU(), + nn.Linear(256, 3*M, bias=True) + ) + self.mu_prev = None + self.initialize_bias() + + def initialize_bias(self): + """Initialize sigma and Delta.""" + # sigma + torch.nn.init.constant_(self.query_layer[2].bias[self.M:2*self.M], 1.0) + # Delta: softplus(1.8545) = 2.0; softplus(3.9815) = 4.0; softplus(0.5413) = 1.0 + # softplus(-0.432) = 0.5003 + if self.r == 2: + torch.nn.init.constant_(self.query_layer[2].bias[2*self.M:3*self.M], 1.8545) + elif self.r == 4: + torch.nn.init.constant_(self.query_layer[2].bias[2*self.M:3*self.M], 3.9815) + elif self.r == 1: + torch.nn.init.constant_(self.query_layer[2].bias[2*self.M:3*self.M], 0.5413) + else: + torch.nn.init.constant_(self.query_layer[2].bias[2*self.M:3*self.M], -0.432) + + + def init_states(self, memory): + """Initialize mu_prev and J. + This function should be called by the decoder before decoding one batch. + Args: + memory: (B, T, D_enc) encoder output. + """ + B, T_enc, _ = memory.size() + device = memory.device + self.J = torch.arange(0, T_enc + 2.0).to(device) + 0.5 # NOTE: for discretize usage + # self.J = memory.new_tensor(np.arange(T_enc), dtype=torch.float) + self.mu_prev = torch.zeros(B, self.M).to(device) + + def forward(self, att_rnn_h, memory, memory_pitch=None, mask=None): + """ + att_rnn_h: attetion rnn hidden state. + memory: encoder outputs (B, T_enc, D). + mask: binary mask for padded data (B, T_enc). + """ + # [B, 3M] + mixture_params = self.query_layer(att_rnn_h) + + # [B, M] + w_hat = mixture_params[:, :self.M] + sigma_hat = mixture_params[:, self.M:2*self.M] + Delta_hat = mixture_params[:, 2*self.M:3*self.M] + + # print("w_hat: ", w_hat) + # print("sigma_hat: ", sigma_hat) + # print("Delta_hat: ", Delta_hat) + + # Dropout to de-correlate attention heads + w_hat = F.dropout(w_hat, p=0.5, training=self.training) # NOTE(sx): needed? 
+ + # Mixture parameters + w = torch.softmax(w_hat, dim=-1) + self.eps + sigma = F.softplus(sigma_hat) + self.eps + Delta = F.softplus(Delta_hat) + mu_cur = self.mu_prev + Delta + # print("w:", w) + j = self.J[:memory.size(1) + 1] + + # Attention weights + # CDF of logistic distribution + phi_t = w.unsqueeze(-1) * (1 / (1 + torch.sigmoid( + (mu_cur.unsqueeze(-1) - j) / sigma.unsqueeze(-1)))) + # print("phi_t:", phi_t) + + # Discretize attention weights + # (B, T_enc + 1) + alpha_t = torch.sum(phi_t, dim=1) + alpha_t = alpha_t[:, 1:] - alpha_t[:, :-1] + alpha_t[alpha_t == 0] = self.eps + # print("alpha_t: ", alpha_t.size()) + # Apply masking + if mask is not None: + alpha_t.data.masked_fill_(mask, self.score_mask_value) + + context = torch.bmm(alpha_t.unsqueeze(1), memory).squeeze(1) + if memory_pitch is not None: + context_pitch = torch.bmm(alpha_t.unsqueeze(1), memory_pitch).squeeze(1) + + self.mu_prev = mu_cur + + if memory_pitch is not None: + return context, context_pitch, alpha_t + return context, alpha_t + diff --git a/ppg2mel/utils/nets_utils.py b/ppg2mel/utils/nets_utils.py new file mode 100644 index 0000000..098e3b4 --- /dev/null +++ b/ppg2mel/utils/nets_utils.py @@ -0,0 +1,451 @@ +# -*- coding: utf-8 -*- + +"""Network related utility tools.""" + +import logging +from typing import Dict + +import numpy as np +import torch + + +def to_device(m, x): + """Send tensor into the device of the module. + + Args: + m (torch.nn.Module): Torch module. + x (Tensor): Torch tensor. + + Returns: + Tensor: Torch tensor located in the same place as torch module. + + """ + assert isinstance(m, torch.nn.Module) + device = next(m.parameters()).device + return x.to(device) + + +def pad_list(xs, pad_value): + """Perform padding for the list of tensors. + + Args: + xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)]. + pad_value (float): Value for padding. + + Returns: + Tensor: Padded tensor (B, Tmax, `*`). + + Examples: + >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)] + >>> x + [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])] + >>> pad_list(x, 0) + tensor([[1., 1., 1., 1.], + [1., 1., 0., 0.], + [1., 0., 0., 0.]]) + + """ + n_batch = len(xs) + max_len = max(x.size(0) for x in xs) + pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value) + + for i in range(n_batch): + pad[i, :xs[i].size(0)] = xs[i] + + return pad + + +def make_pad_mask(lengths, xs=None, length_dim=-1): + """Make mask tensor containing indices of padded part. + + Args: + lengths (LongTensor or List): Batch of lengths (B,). + xs (Tensor, optional): The reference tensor. If set, masks will be the same shape as this tensor. + length_dim (int, optional): Dimension indicator of the above tensor. See the example. + + Returns: + Tensor: Mask tensor containing indices of padded part. + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (including 1.2) + + Examples: + With only lengths. + + >>> lengths = [5, 3, 2] + >>> make_non_pad_mask(lengths) + masks = [[0, 0, 0, 0 ,0], + [0, 0, 0, 1, 1], + [0, 0, 1, 1, 1]] + + With the reference tensor. 
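+        (lengths is still [5, 3, 2]; the mask is expanded to the shape of xs,
+        with 1 marking the padded positions.)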
+ + >>> xs = torch.zeros((3, 2, 4)) + >>> make_pad_mask(lengths, xs) + tensor([[[0, 0, 0, 0], + [0, 0, 0, 0]], + [[0, 0, 0, 1], + [0, 0, 0, 1]], + [[0, 0, 1, 1], + [0, 0, 1, 1]]], dtype=torch.uint8) + >>> xs = torch.zeros((3, 2, 6)) + >>> make_pad_mask(lengths, xs) + tensor([[[0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1]], + [[0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1]], + [[0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8) + + With the reference tensor and dimension indicator. + + >>> xs = torch.zeros((3, 6, 6)) + >>> make_pad_mask(lengths, xs, 1) + tensor([[[0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1]], + [[0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1]], + [[0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8) + >>> make_pad_mask(lengths, xs, 2) + tensor([[[0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1]], + [[0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1]], + [[0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8) + + """ + if length_dim == 0: + raise ValueError('length_dim cannot be 0: {}'.format(length_dim)) + + if not isinstance(lengths, list): + lengths = lengths.tolist() + bs = int(len(lengths)) + if xs is None: + maxlen = int(max(lengths)) + else: + maxlen = xs.size(length_dim) + + seq_range = torch.arange(0, maxlen, dtype=torch.int64) + seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen) + seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1) + mask = seq_range_expand >= seq_length_expand + + if xs is not None: + assert xs.size(0) == bs, (xs.size(0), bs) + + if length_dim < 0: + length_dim = xs.dim() + length_dim + # ind = (:, None, ..., None, :, , None, ..., None) + ind = tuple(slice(None) if i in (0, length_dim) else None + for i in range(xs.dim())) + mask = mask[ind].expand_as(xs).to(xs.device) + return mask + + +def make_non_pad_mask(lengths, xs=None, length_dim=-1): + """Make mask tensor containing indices of non-padded part. + + Args: + lengths (LongTensor or List): Batch of lengths (B,). + xs (Tensor, optional): The reference tensor. If set, masks will be the same shape as this tensor. + length_dim (int, optional): Dimension indicator of the above tensor. See the example. + + Returns: + ByteTensor: mask tensor containing indices of padded part. + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (including 1.2) + + Examples: + With only lengths. + + >>> lengths = [5, 3, 2] + >>> make_non_pad_mask(lengths) + masks = [[1, 1, 1, 1 ,1], + [1, 1, 1, 0, 0], + [1, 1, 0, 0, 0]] + + With the reference tensor. + + >>> xs = torch.zeros((3, 2, 4)) + >>> make_non_pad_mask(lengths, xs) + tensor([[[1, 1, 1, 1], + [1, 1, 1, 1]], + [[1, 1, 1, 0], + [1, 1, 1, 0]], + [[1, 1, 0, 0], + [1, 1, 0, 0]]], dtype=torch.uint8) + >>> xs = torch.zeros((3, 2, 6)) + >>> make_non_pad_mask(lengths, xs) + tensor([[[1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0]], + [[1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0]], + [[1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8) + + With the reference tensor and dimension indicator. 
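+        (same xs shapes as in make_pad_mask above, with the values inverted.)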
+ + >>> xs = torch.zeros((3, 6, 6)) + >>> make_non_pad_mask(lengths, xs, 1) + tensor([[[1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0]], + [[1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0]], + [[1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8) + >>> make_non_pad_mask(lengths, xs, 2) + tensor([[[1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0]], + [[1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0]], + [[1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8) + + """ + return ~make_pad_mask(lengths, xs, length_dim) + + +def mask_by_length(xs, lengths, fill=0): + """Mask tensor according to length. + + Args: + xs (Tensor): Batch of input tensor (B, `*`). + lengths (LongTensor or List): Batch of lengths (B,). + fill (int or float): Value to fill masked part. + + Returns: + Tensor: Batch of masked input tensor (B, `*`). + + Examples: + >>> x = torch.arange(5).repeat(3, 1) + 1 + >>> x + tensor([[1, 2, 3, 4, 5], + [1, 2, 3, 4, 5], + [1, 2, 3, 4, 5]]) + >>> lengths = [5, 3, 2] + >>> mask_by_length(x, lengths) + tensor([[1, 2, 3, 4, 5], + [1, 2, 3, 0, 0], + [1, 2, 0, 0, 0]]) + + """ + assert xs.size(0) == len(lengths) + ret = xs.data.new(*xs.size()).fill_(fill) + for i, l in enumerate(lengths): + ret[i, :l] = xs[i, :l] + return ret + + +def th_accuracy(pad_outputs, pad_targets, ignore_label): + """Calculate accuracy. + + Args: + pad_outputs (Tensor): Prediction tensors (B * Lmax, D). + pad_targets (LongTensor): Target label tensors (B, Lmax, D). + ignore_label (int): Ignore label id. + + Returns: + float: Accuracy value (0.0 - 1.0). + + """ + pad_pred = pad_outputs.view( + pad_targets.size(0), + pad_targets.size(1), + pad_outputs.size(1)).argmax(2) + mask = pad_targets != ignore_label + numerator = torch.sum(pad_pred.masked_select(mask) == pad_targets.masked_select(mask)) + denominator = torch.sum(mask) + return float(numerator) / float(denominator) + + +def to_torch_tensor(x): + """Change to torch.Tensor or ComplexTensor from numpy.ndarray. + + Args: + x: Inputs. It should be one of numpy.ndarray, Tensor, ComplexTensor, and dict. + + Returns: + Tensor or ComplexTensor: Type converted inputs. 
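+        (Complex numpy arrays and {'real': ..., 'imag': ...} dicts become
+        ComplexTensor; torch tensors are returned unchanged.)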
+ + Examples: + >>> xs = np.ones(3, dtype=np.float32) + >>> xs = to_torch_tensor(xs) + tensor([1., 1., 1.]) + >>> xs = torch.ones(3, 4, 5) + >>> assert to_torch_tensor(xs) is xs + >>> xs = {'real': xs, 'imag': xs} + >>> to_torch_tensor(xs) + ComplexTensor( + Real: + tensor([1., 1., 1.]) + Imag; + tensor([1., 1., 1.]) + ) + + """ + # If numpy, change to torch tensor + if isinstance(x, np.ndarray): + if x.dtype.kind == 'c': + # Dynamically importing because torch_complex requires python3 + from torch_complex.tensor import ComplexTensor + return ComplexTensor(x) + else: + return torch.from_numpy(x) + + # If {'real': ..., 'imag': ...}, convert to ComplexTensor + elif isinstance(x, dict): + # Dynamically importing because torch_complex requires python3 + from torch_complex.tensor import ComplexTensor + + if 'real' not in x or 'imag' not in x: + raise ValueError("has 'real' and 'imag' keys: {}".format(list(x))) + # Relative importing because of using python3 syntax + return ComplexTensor(x['real'], x['imag']) + + # If torch.Tensor, as it is + elif isinstance(x, torch.Tensor): + return x + + else: + error = ("x must be numpy.ndarray, torch.Tensor or a dict like " + "{{'real': torch.Tensor, 'imag': torch.Tensor}}, " + "but got {}".format(type(x))) + try: + from torch_complex.tensor import ComplexTensor + except Exception: + # If PY2 + raise ValueError(error) + else: + # If PY3 + if isinstance(x, ComplexTensor): + return x + else: + raise ValueError(error) + + +def get_subsample(train_args, mode, arch): + """Parse the subsampling factors from the training args for the specified `mode` and `arch`. + + Args: + train_args: argument Namespace containing options. + mode: one of ('asr', 'mt', 'st') + arch: one of ('rnn', 'rnn-t', 'rnn_mix', 'rnn_mulenc', 'transformer') + + Returns: + np.ndarray / List[np.ndarray]: subsampling factors. + """ + if arch == 'transformer': + return np.array([1]) + + elif mode == 'mt' and arch == 'rnn': + # +1 means input (+1) and layers outputs (train_args.elayer) + subsample = np.ones(train_args.elayers + 1, dtype=np.int) + logging.warning('Subsampling is not performed for machine translation.') + logging.info('subsample: ' + ' '.join([str(x) for x in subsample])) + return subsample + + elif (mode == 'asr' and arch in ('rnn', 'rnn-t')) or \ + (mode == 'mt' and arch == 'rnn') or \ + (mode == 'st' and arch == 'rnn'): + subsample = np.ones(train_args.elayers + 1, dtype=np.int) + if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"): + ss = train_args.subsample.split("_") + for j in range(min(train_args.elayers + 1, len(ss))): + subsample[j] = int(ss[j]) + else: + logging.warning( + 'Subsampling is not performed for vgg*. It is performed in max pooling layers at CNN.') + logging.info('subsample: ' + ' '.join([str(x) for x in subsample])) + return subsample + + elif mode == 'asr' and arch == 'rnn_mix': + subsample = np.ones(train_args.elayers_sd + train_args.elayers + 1, dtype=np.int) + if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"): + ss = train_args.subsample.split("_") + for j in range(min(train_args.elayers_sd + train_args.elayers + 1, len(ss))): + subsample[j] = int(ss[j]) + else: + logging.warning( + 'Subsampling is not performed for vgg*. 
It is performed in max pooling layers at CNN.') + logging.info('subsample: ' + ' '.join([str(x) for x in subsample])) + return subsample + + elif mode == 'asr' and arch == 'rnn_mulenc': + subsample_list = [] + for idx in range(train_args.num_encs): + subsample = np.ones(train_args.elayers[idx] + 1, dtype=np.int) + if train_args.etype[idx].endswith("p") and not train_args.etype[idx].startswith("vgg"): + ss = train_args.subsample[idx].split("_") + for j in range(min(train_args.elayers[idx] + 1, len(ss))): + subsample[j] = int(ss[j]) + else: + logging.warning( + 'Encoder %d: Subsampling is not performed for vgg*. ' + 'It is performed in max pooling layers at CNN.', idx + 1) + logging.info('subsample: ' + ' '.join([str(x) for x in subsample])) + subsample_list.append(subsample) + return subsample_list + + else: + raise ValueError('Invalid options: mode={}, arch={}'.format(mode, arch)) + + +def rename_state_dict(old_prefix: str, new_prefix: str, state_dict: Dict[str, torch.Tensor]): + """Replace keys of old prefix with new prefix in state dict.""" + # need this list not to break the dict iterator + old_keys = [k for k in state_dict if k.startswith(old_prefix)] + if len(old_keys) > 0: + logging.warning(f'Rename: {old_prefix} -> {new_prefix}') + for k in old_keys: + v = state_dict.pop(k) + new_k = k.replace(old_prefix, new_prefix) + state_dict[new_k] = v diff --git a/ppg2mel/utils/vc_utils.py b/ppg2mel/utils/vc_utils.py new file mode 100644 index 0000000..e2b6bf0 --- /dev/null +++ b/ppg2mel/utils/vc_utils.py @@ -0,0 +1,22 @@ +import torch + + +def gcd(a, b): + """Greatest common divisor.""" + a, b = (a, b) if a >=b else (b, a) + if a%b == 0: + return b + else : + return gcd(b, a%b) + +def lcm(a, b): + """Least common multiple""" + return a * b // gcd(a, b) + +def get_mask_from_lengths(lengths, max_len=None): + if max_len is None: + max_len = torch.max(lengths).item() + ids = torch.arange(0, max_len, out=torch.cuda.LongTensor(max_len)) + mask = (ids < lengths.unsqueeze(1)).bool() + return mask + diff --git a/ppg_extractor/__init__.py b/ppg_extractor/__init__.py new file mode 100644 index 0000000..42a3983 --- /dev/null +++ b/ppg_extractor/__init__.py @@ -0,0 +1,102 @@ +import argparse +import torch +from pathlib import Path +import yaml + +from .frontend import DefaultFrontend +from .utterance_mvn import UtteranceMVN +from .encoder.conformer_encoder import ConformerEncoder + +_model = None # type: PPGModel +_device = None + +class PPGModel(torch.nn.Module): + def __init__( + self, + frontend, + normalizer, + encoder, + ): + super().__init__() + self.frontend = frontend + self.normalize = normalizer + self.encoder = encoder + + def forward(self, speech, speech_lengths): + """ + + Args: + speech (tensor): (B, L) + speech_lengths (tensor): (B, ) + + Returns: + bottle_neck_feats (tensor): (B, L//hop_size, 144) + + """ + feats, feats_lengths = self._extract_feats(speech, speech_lengths) + feats, feats_lengths = self.normalize(feats, feats_lengths) + encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths) + return encoder_out + + def _extract_feats( + self, speech: torch.Tensor, speech_lengths: torch.Tensor + ): + assert speech_lengths.dim() == 1, speech_lengths.shape + + # for data-parallel + speech = speech[:, : speech_lengths.max()] + + if self.frontend is not None: + # Frontend + # e.g. 
STFT and Feature extract + # data_loader may send time-domain signal in this case + # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim) + feats, feats_lengths = self.frontend(speech, speech_lengths) + else: + # No frontend and no feature extract + feats, feats_lengths = speech, speech_lengths + return feats, feats_lengths + + def extract_from_wav(self, src_wav): + src_wav_tensor = torch.from_numpy(src_wav).unsqueeze(0).float().to(_device) + src_wav_lengths = torch.LongTensor([len(src_wav)]).to(_device) + return self(src_wav_tensor, src_wav_lengths) + + +def build_model(args): + normalizer = UtteranceMVN(**args.normalize_conf) + frontend = DefaultFrontend(**args.frontend_conf) + encoder = ConformerEncoder(input_size=80, **args.encoder_conf) + model = PPGModel(frontend, normalizer, encoder) + + return model + + +def load_model(model_file, device=None): + global _model, _device + + if device is None: + _device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + else: + _device = device + # search a config file + model_config_fpaths = list(model_file.parent.rglob("*.yaml")) + config_file = model_config_fpaths[0] + with config_file.open("r", encoding="utf-8") as f: + args = yaml.safe_load(f) + + args = argparse.Namespace(**args) + + model = build_model(args) + model_state_dict = model.state_dict() + + ckpt_state_dict = torch.load(model_file, map_location=_device) + ckpt_state_dict = {k:v for k,v in ckpt_state_dict.items() if 'encoder' in k} + + model_state_dict.update(ckpt_state_dict) + model.load_state_dict(model_state_dict) + + _model = model.eval().to(_device) + return _model + + diff --git a/ppg_extractor/e2e_asr_common.py b/ppg_extractor/e2e_asr_common.py new file mode 100644 index 0000000..b67f9f1 --- /dev/null +++ b/ppg_extractor/e2e_asr_common.py @@ -0,0 +1,398 @@ +#!/usr/bin/env python3 + +# Copyright 2017 Johns Hopkins University (Shinji Watanabe) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Common functions for ASR.""" + +import argparse +import editdistance +import json +import logging +import numpy as np +import six +import sys + +from itertools import groupby + + +def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))): + """End detection. + + desribed in Eq. (50) of S. Watanabe et al + "Hybrid CTC/Attention Architecture for End-to-End Speech Recognition" + + :param ended_hyps: + :param i: + :param M: + :param D_end: + :return: + """ + if len(ended_hyps) == 0: + return False + count = 0 + best_hyp = sorted(ended_hyps, key=lambda x: x['score'], reverse=True)[0] + for m in six.moves.range(M): + # get ended_hyps with their length is i - m + hyp_length = i - m + hyps_same_length = [x for x in ended_hyps if len(x['yseq']) == hyp_length] + if len(hyps_same_length) > 0: + best_hyp_same_length = sorted(hyps_same_length, key=lambda x: x['score'], reverse=True)[0] + if best_hyp_same_length['score'] - best_hyp['score'] < D_end: + count += 1 + + if count == M: + return True + else: + return False + + +# TODO(takaaki-hori): add different smoothing methods +def label_smoothing_dist(odim, lsm_type, transcript=None, blank=0): + """Obtain label distribution for loss smoothing. 
+ + :param odim: + :param lsm_type: + :param blank: + :param transcript: + :return: + """ + if transcript is not None: + with open(transcript, 'rb') as f: + trans_json = json.load(f)['utts'] + + if lsm_type == 'unigram': + assert transcript is not None, 'transcript is required for %s label smoothing' % lsm_type + labelcount = np.zeros(odim) + for k, v in trans_json.items(): + ids = np.array([int(n) for n in v['output'][0]['tokenid'].split()]) + # to avoid an error when there is no text in an uttrance + if len(ids) > 0: + labelcount[ids] += 1 + labelcount[odim - 1] = len(transcript) # count + labelcount[labelcount == 0] = 1 # flooring + labelcount[blank] = 0 # remove counts for blank + labeldist = labelcount.astype(np.float32) / np.sum(labelcount) + else: + logging.error( + "Error: unexpected label smoothing type: %s" % lsm_type) + sys.exit() + + return labeldist + + +def get_vgg2l_odim(idim, in_channel=3, out_channel=128, downsample=True): + """Return the output size of the VGG frontend. + + :param in_channel: input channel size + :param out_channel: output channel size + :return: output size + :rtype int + """ + idim = idim / in_channel + if downsample: + idim = np.ceil(np.array(idim, dtype=np.float32) / 2) # 1st max pooling + idim = np.ceil(np.array(idim, dtype=np.float32) / 2) # 2nd max pooling + return int(idim) * out_channel # numer of channels + + +class ErrorCalculator(object): + """Calculate CER and WER for E2E_ASR and CTC models during training. + + :param y_hats: numpy array with predicted text + :param y_pads: numpy array with true (target) text + :param char_list: + :param sym_space: + :param sym_blank: + :return: + """ + + def __init__(self, char_list, sym_space, sym_blank, report_cer=False, report_wer=False, + trans_type="char"): + """Construct an ErrorCalculator object.""" + super(ErrorCalculator, self).__init__() + + self.report_cer = report_cer + self.report_wer = report_wer + self.trans_type = trans_type + self.char_list = char_list + self.space = sym_space + self.blank = sym_blank + self.idx_blank = self.char_list.index(self.blank) + if self.space in self.char_list: + self.idx_space = self.char_list.index(self.space) + else: + self.idx_space = None + + def __call__(self, ys_hat, ys_pad, is_ctc=False): + """Calculate sentence-level WER/CER score. + + :param torch.Tensor ys_hat: prediction (batch, seqlen) + :param torch.Tensor ys_pad: reference (batch, seqlen) + :param bool is_ctc: calculate CER score for CTC + :return: sentence-level WER score + :rtype float + :return: sentence-level CER score + :rtype float + """ + cer, wer = None, None + if is_ctc: + return self.calculate_cer_ctc(ys_hat, ys_pad) + elif not self.report_cer and not self.report_wer: + return cer, wer + + seqs_hat, seqs_true = self.convert_to_char(ys_hat, ys_pad) + if self.report_cer: + cer = self.calculate_cer(seqs_hat, seqs_true) + + if self.report_wer: + wer = self.calculate_wer(seqs_hat, seqs_true) + return cer, wer + + def calculate_cer_ctc(self, ys_hat, ys_pad): + """Calculate sentence-level CER score for CTC. 
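+        Repeated CTC outputs are collapsed (groupby) and blank/space symbols
+        are dropped before the character edit distance is computed.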
+ + :param torch.Tensor ys_hat: prediction (batch, seqlen) + :param torch.Tensor ys_pad: reference (batch, seqlen) + :return: average sentence-level CER score + :rtype float + """ + cers, char_ref_lens = [], [] + for i, y in enumerate(ys_hat): + y_hat = [x[0] for x in groupby(y)] + y_true = ys_pad[i] + seq_hat, seq_true = [], [] + for idx in y_hat: + idx = int(idx) + if idx != -1 and idx != self.idx_blank and idx != self.idx_space: + seq_hat.append(self.char_list[int(idx)]) + + for idx in y_true: + idx = int(idx) + if idx != -1 and idx != self.idx_blank and idx != self.idx_space: + seq_true.append(self.char_list[int(idx)]) + if self.trans_type == "char": + hyp_chars = "".join(seq_hat) + ref_chars = "".join(seq_true) + else: + hyp_chars = " ".join(seq_hat) + ref_chars = " ".join(seq_true) + + if len(ref_chars) > 0: + cers.append(editdistance.eval(hyp_chars, ref_chars)) + char_ref_lens.append(len(ref_chars)) + + cer_ctc = float(sum(cers)) / sum(char_ref_lens) if cers else None + return cer_ctc + + def convert_to_char(self, ys_hat, ys_pad): + """Convert index to character. + + :param torch.Tensor seqs_hat: prediction (batch, seqlen) + :param torch.Tensor seqs_true: reference (batch, seqlen) + :return: token list of prediction + :rtype list + :return: token list of reference + :rtype list + """ + seqs_hat, seqs_true = [], [] + for i, y_hat in enumerate(ys_hat): + y_true = ys_pad[i] + eos_true = np.where(y_true == -1)[0] + eos_true = eos_true[0] if len(eos_true) > 0 else len(y_true) + # To avoid wrong higher WER than the one obtained from the decoding + # eos from y_true is used to mark the eos in y_hat + # because of that y_hats has not padded outs with -1. + seq_hat = [self.char_list[int(idx)] for idx in y_hat[:eos_true]] + seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1] + # seq_hat_text = "".join(seq_hat).replace(self.space, ' ') + seq_hat_text = " ".join(seq_hat).replace(self.space, ' ') + seq_hat_text = seq_hat_text.replace(self.blank, '') + # seq_true_text = "".join(seq_true).replace(self.space, ' ') + seq_true_text = " ".join(seq_true).replace(self.space, ' ') + seqs_hat.append(seq_hat_text) + seqs_true.append(seq_true_text) + return seqs_hat, seqs_true + + def calculate_cer(self, seqs_hat, seqs_true): + """Calculate sentence-level CER score. + + :param list seqs_hat: prediction + :param list seqs_true: reference + :return: average sentence-level CER score + :rtype float + """ + char_eds, char_ref_lens = [], [] + for i, seq_hat_text in enumerate(seqs_hat): + seq_true_text = seqs_true[i] + hyp_chars = seq_hat_text.replace(' ', '') + ref_chars = seq_true_text.replace(' ', '') + char_eds.append(editdistance.eval(hyp_chars, ref_chars)) + char_ref_lens.append(len(ref_chars)) + return float(sum(char_eds)) / sum(char_ref_lens) + + def calculate_wer(self, seqs_hat, seqs_true): + """Calculate sentence-level WER score. + + :param list seqs_hat: prediction + :param list seqs_true: reference + :return: average sentence-level WER score + :rtype float + """ + word_eds, word_ref_lens = [], [] + for i, seq_hat_text in enumerate(seqs_hat): + seq_true_text = seqs_true[i] + hyp_words = seq_hat_text.split() + ref_words = seq_true_text.split() + word_eds.append(editdistance.eval(hyp_words, ref_words)) + word_ref_lens.append(len(ref_words)) + return float(sum(word_eds)) / sum(word_ref_lens) + + +class ErrorCalculatorTrans(object): + """Calculate CER and WER for transducer models. 
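+    Hypotheses are obtained by running the decoder's greedy search (beam_size
+    of 1) or beam search on the encoder outputs, then scored against the
+    reference transcriptions.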
+ + Args: + decoder (nn.Module): decoder module + args (Namespace): argument Namespace containing options + report_cer (boolean): compute CER option + report_wer (boolean): compute WER option + + """ + + def __init__(self, decoder, args, report_cer=False, report_wer=False): + """Construct an ErrorCalculator object for transducer model.""" + super(ErrorCalculatorTrans, self).__init__() + + self.dec = decoder + + recog_args = {'beam_size': args.beam_size, + 'nbest': args.nbest, + 'space': args.sym_space, + 'score_norm_transducer': args.score_norm_transducer} + + self.recog_args = argparse.Namespace(**recog_args) + + self.char_list = args.char_list + self.space = args.sym_space + self.blank = args.sym_blank + + self.report_cer = args.report_cer + self.report_wer = args.report_wer + + def __call__(self, hs_pad, ys_pad): + """Calculate sentence-level WER/CER score for transducer models. + + Args: + hs_pad (torch.Tensor): batch of padded input sequence (batch, T, D) + ys_pad (torch.Tensor): reference (batch, seqlen) + + Returns: + (float): sentence-level CER score + (float): sentence-level WER score + + """ + cer, wer = None, None + + if not self.report_cer and not self.report_wer: + return cer, wer + + batchsize = int(hs_pad.size(0)) + batch_nbest = [] + + for b in six.moves.range(batchsize): + if self.recog_args.beam_size == 1: + nbest_hyps = self.dec.recognize(hs_pad[b], self.recog_args) + else: + nbest_hyps = self.dec.recognize_beam(hs_pad[b], self.recog_args) + batch_nbest.append(nbest_hyps) + + ys_hat = [nbest_hyp[0]['yseq'][1:] for nbest_hyp in batch_nbest] + + seqs_hat, seqs_true = self.convert_to_char(ys_hat, ys_pad.cpu()) + + if self.report_cer: + cer = self.calculate_cer(seqs_hat, seqs_true) + + if self.report_wer: + wer = self.calculate_wer(seqs_hat, seqs_true) + + return cer, wer + + def convert_to_char(self, ys_hat, ys_pad): + """Convert index to character. + + Args: + ys_hat (torch.Tensor): prediction (batch, seqlen) + ys_pad (torch.Tensor): reference (batch, seqlen) + + Returns: + (list): token list of prediction + (list): token list of reference + + """ + seqs_hat, seqs_true = [], [] + + for i, y_hat in enumerate(ys_hat): + y_true = ys_pad[i] + + eos_true = np.where(y_true == -1)[0] + eos_true = eos_true[0] if len(eos_true) > 0 else len(y_true) + + seq_hat = [self.char_list[int(idx)] for idx in y_hat[:eos_true]] + seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1] + + seq_hat_text = "".join(seq_hat).replace(self.space, ' ') + seq_hat_text = seq_hat_text.replace(self.blank, '') + seq_true_text = "".join(seq_true).replace(self.space, ' ') + + seqs_hat.append(seq_hat_text) + seqs_true.append(seq_true_text) + + return seqs_hat, seqs_true + + def calculate_cer(self, seqs_hat, seqs_true): + """Calculate sentence-level CER score for transducer model. + + Args: + seqs_hat (torch.Tensor): prediction (batch, seqlen) + seqs_true (torch.Tensor): reference (batch, seqlen) + + Returns: + (float): average sentence-level CER score + + """ + char_eds, char_ref_lens = [], [] + + for i, seq_hat_text in enumerate(seqs_hat): + seq_true_text = seqs_true[i] + hyp_chars = seq_hat_text.replace(' ', '') + ref_chars = seq_true_text.replace(' ', '') + + char_eds.append(editdistance.eval(hyp_chars, ref_chars)) + char_ref_lens.append(len(ref_chars)) + + return float(sum(char_eds)) / sum(char_ref_lens) + + def calculate_wer(self, seqs_hat, seqs_true): + """Calculate sentence-level WER score for transducer model. 
+ + Args: + seqs_hat (torch.Tensor): prediction (batch, seqlen) + seqs_true (torch.Tensor): reference (batch, seqlen) + + Returns: + (float): average sentence-level WER score + + """ + word_eds, word_ref_lens = [], [] + + for i, seq_hat_text in enumerate(seqs_hat): + seq_true_text = seqs_true[i] + hyp_words = seq_hat_text.split() + ref_words = seq_true_text.split() + + word_eds.append(editdistance.eval(hyp_words, ref_words)) + word_ref_lens.append(len(ref_words)) + + return float(sum(word_eds)) / sum(word_ref_lens) diff --git a/ppg_extractor/encoder/__init__.py b/ppg_extractor/encoder/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/ppg_extractor/encoder/attention.py b/ppg_extractor/encoder/attention.py new file mode 100644 index 0000000..4e7a0d5 --- /dev/null +++ b/ppg_extractor/encoder/attention.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Multi-Head Attention layer definition.""" + +import math + +import numpy +import torch +from torch import nn + + +class MultiHeadedAttention(nn.Module): + """Multi-Head Attention layer. + + :param int n_head: the number of head s + :param int n_feat: the number of features + :param float dropout_rate: dropout rate + + """ + + def __init__(self, n_head, n_feat, dropout_rate): + """Construct an MultiHeadedAttention object.""" + super(MultiHeadedAttention, self).__init__() + assert n_feat % n_head == 0 + # We assume d_v always equals d_k + self.d_k = n_feat // n_head + self.h = n_head + self.linear_q = nn.Linear(n_feat, n_feat) + self.linear_k = nn.Linear(n_feat, n_feat) + self.linear_v = nn.Linear(n_feat, n_feat) + self.linear_out = nn.Linear(n_feat, n_feat) + self.attn = None + self.dropout = nn.Dropout(p=dropout_rate) + + def forward_qkv(self, query, key, value): + """Transform query, key and value. + + :param torch.Tensor query: (batch, time1, size) + :param torch.Tensor key: (batch, time2, size) + :param torch.Tensor value: (batch, time2, size) + :return torch.Tensor transformed query, key and value + + """ + n_batch = query.size(0) + q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) + k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) + v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) + q = q.transpose(1, 2) # (batch, head, time1, d_k) + k = k.transpose(1, 2) # (batch, head, time2, d_k) + v = v.transpose(1, 2) # (batch, head, time2, d_k) + + return q, k, v + + def forward_attention(self, value, scores, mask): + """Compute attention context vector. 
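+        Applies softmax over the scores (with padded positions masked out and
+        their weights zeroed), then returns the attention-weighted sum of
+        `value`, projected by `linear_out`.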
+ + :param torch.Tensor value: (batch, head, time2, size) + :param torch.Tensor scores: (batch, head, time1, time2) + :param torch.Tensor mask: (batch, 1, time2) or (batch, time1, time2) + :return torch.Tensor transformed `value` (batch, time1, d_model) + weighted by the attention score (batch, time1, time2) + + """ + n_batch = value.size(0) + if mask is not None: + mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) + min_value = float( + numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min + ) + scores = scores.masked_fill(mask, min_value) + self.attn = torch.softmax(scores, dim=-1).masked_fill( + mask, 0.0 + ) # (batch, head, time1, time2) + else: + self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) + + p_attn = self.dropout(self.attn) + x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) + x = ( + x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) + ) # (batch, time1, d_model) + + return self.linear_out(x) # (batch, time1, d_model) + + def forward(self, query, key, value, mask): + """Compute 'Scaled Dot Product Attention'. + + :param torch.Tensor query: (batch, time1, size) + :param torch.Tensor key: (batch, time2, size) + :param torch.Tensor value: (batch, time2, size) + :param torch.Tensor mask: (batch, 1, time2) or (batch, time1, time2) + :param torch.nn.Dropout dropout: + :return torch.Tensor: attention output (batch, time1, d_model) + """ + q, k, v = self.forward_qkv(query, key, value) + scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) + return self.forward_attention(v, scores, mask) + + +class RelPositionMultiHeadedAttention(MultiHeadedAttention): + """Multi-Head Attention layer with relative position encoding. + + Paper: https://arxiv.org/abs/1901.02860 + + :param int n_head: the number of head s + :param int n_feat: the number of features + :param float dropout_rate: dropout rate + + """ + + def __init__(self, n_head, n_feat, dropout_rate): + """Construct an RelPositionMultiHeadedAttention object.""" + super().__init__(n_head, n_feat, dropout_rate) + # linear transformation for positional ecoding + self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k)) + self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k)) + torch.nn.init.xavier_uniform_(self.pos_bias_u) + torch.nn.init.xavier_uniform_(self.pos_bias_v) + + def rel_shift(self, x, zero_triu=False): + """Compute relative positinal encoding. + + :param torch.Tensor x: (batch, time, size) + :param bool zero_triu: return the lower triangular part of the matrix + """ + zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype) + x_padded = torch.cat([zero_pad, x], dim=-1) + + x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2)) + x = x_padded[:, :, 1:].view_as(x) + + if zero_triu: + ones = torch.ones((x.size(2), x.size(3))) + x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :] + + return x + + def forward(self, query, key, value, pos_emb, mask): + """Compute 'Scaled Dot Product Attention' with rel. positional encoding. 
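+        Follows the Transformer-XL decomposition: a content term
+        (q + pos_bias_u) @ k^T and a position term (q + pos_bias_v) @ p^T,
+        where p is the projected positional embedding; the position term is
+        rel-shifted before both are summed and scaled by 1/sqrt(d_k).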
+ + :param torch.Tensor query: (batch, time1, size) + :param torch.Tensor key: (batch, time2, size) + :param torch.Tensor value: (batch, time2, size) + :param torch.Tensor pos_emb: (batch, time1, size) + :param torch.Tensor mask: (batch, time1, time2) + :param torch.nn.Dropout dropout: + :return torch.Tensor: attention output (batch, time1, d_model) + """ + q, k, v = self.forward_qkv(query, key, value) + q = q.transpose(1, 2) # (batch, time1, head, d_k) + + n_batch_pos = pos_emb.size(0) + p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) + p = p.transpose(1, 2) # (batch, head, time1, d_k) + + # (batch, head, time1, d_k) + q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) + # (batch, head, time1, d_k) + q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) + + # compute attention score + # first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # (batch, head, time1, time2) + matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) + + # compute matrix b and matrix d + # (batch, head, time1, time2) + matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) + matrix_bd = self.rel_shift(matrix_bd) + + scores = (matrix_ac + matrix_bd) / math.sqrt( + self.d_k + ) # (batch, head, time1, time2) + + return self.forward_attention(v, scores, mask) diff --git a/ppg_extractor/encoder/conformer_encoder.py b/ppg_extractor/encoder/conformer_encoder.py new file mode 100644 index 0000000..d31e97a --- /dev/null +++ b/ppg_extractor/encoder/conformer_encoder.py @@ -0,0 +1,262 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Encoder definition.""" + +import logging +import torch +from typing import Callable +from typing import Collection +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple + +from .convolution import ConvolutionModule +from .encoder_layer import EncoderLayer +from ..nets_utils import get_activation, make_pad_mask +from .vgg import VGG2L +from .attention import MultiHeadedAttention, RelPositionMultiHeadedAttention +from .embedding import PositionalEncoding, ScaledPositionalEncoding, RelPositionalEncoding +from .layer_norm import LayerNorm +from .multi_layer_conv import Conv1dLinear, MultiLayeredConv1d +from .positionwise_feed_forward import PositionwiseFeedForward +from .repeat import repeat +from .subsampling import Conv2dNoSubsampling, Conv2dSubsampling + + +class ConformerEncoder(torch.nn.Module): + """Conformer encoder module. + + :param int idim: input dim + :param int attention_dim: dimention of attention + :param int attention_heads: the number of heads of multi head attention + :param int linear_units: the number of units of position-wise feed forward + :param int num_blocks: the number of decoder blocks + :param float dropout_rate: dropout rate + :param float attention_dropout_rate: dropout rate in attention + :param float positional_dropout_rate: dropout rate after adding positional encoding + :param str or torch.nn.Module input_layer: input layer type + :param bool normalize_before: whether to use layer_norm before the first block + :param bool concat_after: whether to concat attention layer's input and output + if True, additional linear will be applied. + i.e. x -> x + linear(concat(x, att(x))) + if False, no additional linear will be applied. i.e. 
x -> x + att(x) + :param str positionwise_layer_type: linear of conv1d + :param int positionwise_conv_kernel_size: kernel size of positionwise conv1d layer + :param str encoder_pos_enc_layer_type: encoder positional encoding layer type + :param str encoder_attn_layer_type: encoder attention layer type + :param str activation_type: encoder activation function type + :param bool macaron_style: whether to use macaron style for positionwise layer + :param bool use_cnn_module: whether to use convolution module + :param int cnn_module_kernel: kernerl size of convolution module + :param int padding_idx: padding_idx for input_layer=embed + """ + + def __init__( + self, + input_size, + attention_dim=256, + attention_heads=4, + linear_units=2048, + num_blocks=6, + dropout_rate=0.1, + positional_dropout_rate=0.1, + attention_dropout_rate=0.0, + input_layer="conv2d", + normalize_before=True, + concat_after=False, + positionwise_layer_type="linear", + positionwise_conv_kernel_size=1, + macaron_style=False, + pos_enc_layer_type="abs_pos", + selfattention_layer_type="selfattn", + activation_type="swish", + use_cnn_module=False, + cnn_module_kernel=31, + padding_idx=-1, + no_subsample=False, + subsample_by_2=False, + ): + """Construct an Encoder object.""" + super().__init__() + + self._output_size = attention_dim + idim = input_size + + activation = get_activation(activation_type) + if pos_enc_layer_type == "abs_pos": + pos_enc_class = PositionalEncoding + elif pos_enc_layer_type == "scaled_abs_pos": + pos_enc_class = ScaledPositionalEncoding + elif pos_enc_layer_type == "rel_pos": + assert selfattention_layer_type == "rel_selfattn" + pos_enc_class = RelPositionalEncoding + else: + raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type) + + if input_layer == "linear": + self.embed = torch.nn.Sequential( + torch.nn.Linear(idim, attention_dim), + torch.nn.LayerNorm(attention_dim), + torch.nn.Dropout(dropout_rate), + pos_enc_class(attention_dim, positional_dropout_rate), + ) + elif input_layer == "conv2d": + logging.info("Encoder input layer type: conv2d") + if no_subsample: + self.embed = Conv2dNoSubsampling( + idim, + attention_dim, + dropout_rate, + pos_enc_class(attention_dim, positional_dropout_rate), + ) + else: + self.embed = Conv2dSubsampling( + idim, + attention_dim, + dropout_rate, + pos_enc_class(attention_dim, positional_dropout_rate), + subsample_by_2, # NOTE(Sx): added by songxiang + ) + elif input_layer == "vgg2l": + self.embed = VGG2L(idim, attention_dim) + elif input_layer == "embed": + self.embed = torch.nn.Sequential( + torch.nn.Embedding(idim, attention_dim, padding_idx=padding_idx), + pos_enc_class(attention_dim, positional_dropout_rate), + ) + elif isinstance(input_layer, torch.nn.Module): + self.embed = torch.nn.Sequential( + input_layer, + pos_enc_class(attention_dim, positional_dropout_rate), + ) + elif input_layer is None: + self.embed = torch.nn.Sequential( + pos_enc_class(attention_dim, positional_dropout_rate) + ) + else: + raise ValueError("unknown input_layer: " + input_layer) + self.normalize_before = normalize_before + if positionwise_layer_type == "linear": + positionwise_layer = PositionwiseFeedForward + positionwise_layer_args = ( + attention_dim, + linear_units, + dropout_rate, + activation, + ) + elif positionwise_layer_type == "conv1d": + positionwise_layer = MultiLayeredConv1d + positionwise_layer_args = ( + attention_dim, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + ) + elif positionwise_layer_type == "conv1d-linear": + 
positionwise_layer = Conv1dLinear + positionwise_layer_args = ( + attention_dim, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + ) + else: + raise NotImplementedError("Support only linear or conv1d.") + + if selfattention_layer_type == "selfattn": + logging.info("encoder self-attention layer type = self-attention") + encoder_selfattn_layer = MultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + attention_dim, + attention_dropout_rate, + ) + elif selfattention_layer_type == "rel_selfattn": + assert pos_enc_layer_type == "rel_pos" + encoder_selfattn_layer = RelPositionMultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + attention_dim, + attention_dropout_rate, + ) + else: + raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type) + + convolution_layer = ConvolutionModule + convolution_layer_args = (attention_dim, cnn_module_kernel, activation) + + self.encoders = repeat( + num_blocks, + lambda lnum: EncoderLayer( + attention_dim, + encoder_selfattn_layer(*encoder_selfattn_layer_args), + positionwise_layer(*positionwise_layer_args), + positionwise_layer(*positionwise_layer_args) if macaron_style else None, + convolution_layer(*convolution_layer_args) if use_cnn_module else None, + dropout_rate, + normalize_before, + concat_after, + ), + ) + if self.normalize_before: + self.after_norm = LayerNorm(attention_dim) + + def output_size(self) -> int: + return self._output_size + + def forward( + self, + xs_pad: torch.Tensor, + ilens: torch.Tensor, + prev_states: torch.Tensor = None, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Args: + xs_pad: input tensor (B, L, D) + ilens: input lengths (B) + prev_states: Not to be used now. + Returns: + Position embedded tensor and mask + """ + masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) + + if isinstance(self.embed, (Conv2dSubsampling, Conv2dNoSubsampling, VGG2L)): + # print(xs_pad.shape) + xs_pad, masks = self.embed(xs_pad, masks) + # print(xs_pad[0].size()) + else: + xs_pad = self.embed(xs_pad) + xs_pad, masks = self.encoders(xs_pad, masks) + if isinstance(xs_pad, tuple): + xs_pad = xs_pad[0] + + if self.normalize_before: + xs_pad = self.after_norm(xs_pad) + olens = masks.squeeze(1).sum(1) + return xs_pad, olens, None + + # def forward(self, xs, masks): + # """Encode input sequence. + + # :param torch.Tensor xs: input tensor + # :param torch.Tensor masks: input mask + # :return: position embedded tensor and mask + # :rtype Tuple[torch.Tensor, torch.Tensor]: + # """ + # if isinstance(self.embed, (Conv2dSubsampling, VGG2L)): + # xs, masks = self.embed(xs, masks) + # else: + # xs = self.embed(xs) + + # xs, masks = self.encoders(xs, masks) + # if isinstance(xs, tuple): + # xs = xs[0] + + # if self.normalize_before: + # xs = self.after_norm(xs) + # return xs, masks diff --git a/ppg_extractor/encoder/convolution.py b/ppg_extractor/encoder/convolution.py new file mode 100644 index 0000000..2d2c399 --- /dev/null +++ b/ppg_extractor/encoder/convolution.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2020 Johns Hopkins University (Shinji Watanabe) +# Northwestern Polytechnical University (Pengcheng Guo) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""ConvolutionModule definition.""" + +from torch import nn + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Conformer model. 
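+    The forward pass is: pointwise conv (doubling channels) + GLU, depthwise
+    conv, BatchNorm + activation, and a final pointwise conv, all along the
+    time axis; input and output are (batch, time, channels).
+
+    Example (illustrative sketch):
+        >>> m = ConvolutionModule(channels=256, kernel_size=31)
+        >>> x = torch.randn(2, 100, 256)  # (batch, time, channels)
+        >>> m(x).shape
+        torch.Size([2, 100, 256])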
+ + :param int channels: channels of cnn + :param int kernel_size: kernerl size of cnn + + """ + + def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True): + """Construct an ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + + self.pointwise_conv1 = nn.Conv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.depthwise_conv = nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + groups=channels, + bias=bias, + ) + self.norm = nn.BatchNorm1d(channels) + self.pointwise_conv2 = nn.Conv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.activation = activation + + def forward(self, x): + """Compute convolution module. + + :param torch.Tensor x: (batch, time, size) + :return torch.Tensor: convoluted `value` (batch, time, d_model) + """ + # exchange the temporal dimension and the feature dimension + x = x.transpose(1, 2) + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channel, dim) + x = nn.functional.glu(x, dim=1) # (batch, channel, dim) + + # 1D Depthwise Conv + x = self.depthwise_conv(x) + x = self.activation(self.norm(x)) + + x = self.pointwise_conv2(x) + + return x.transpose(1, 2) diff --git a/ppg_extractor/encoder/embedding.py b/ppg_extractor/encoder/embedding.py new file mode 100644 index 0000000..fa3199c --- /dev/null +++ b/ppg_extractor/encoder/embedding.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Positonal Encoding Module.""" + +import math + +import torch + + +def _pre_hook( + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, +): + """Perform pre-hook in load_state_dict for backward compatibility. + + Note: + We saved self.pe until v.0.5.2 but we have omitted it later. + Therefore, we remove the item "pe" from `state_dict` for backward compatibility. + + """ + k = prefix + "pe" + if k in state_dict: + state_dict.pop(k) + + +class PositionalEncoding(torch.nn.Module): + """Positional encoding. 
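+    Standard sinusoidal encoding: in forward() the input is scaled by
+    sqrt(d_model) before the fixed (non-trainable) table is added, followed
+    by dropout.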
+ + :param int d_model: embedding dim + :param float dropout_rate: dropout rate + :param int max_len: maximum input length + :param reverse: whether to reverse the input position + + """ + + def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False): + """Construct an PositionalEncoding object.""" + super(PositionalEncoding, self).__init__() + self.d_model = d_model + self.reverse = reverse + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + self._register_load_state_dict_pre_hook(_pre_hook) + + def extend_pe(self, x): + """Reset the positional encodings.""" + if self.pe is not None: + if self.pe.size(1) >= x.size(1): + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(x.size(1), self.d_model) + if self.reverse: + position = torch.arange( + x.size(1) - 1, -1, -1.0, dtype=torch.float32 + ).unsqueeze(1) + else: + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) + * -(math.log(10000.0) / self.d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.pe = pe.to(device=x.device, dtype=x.dtype) + + def forward(self, x: torch.Tensor): + """Add positional encoding. + + Args: + x (torch.Tensor): Input. Its shape is (batch, time, ...) + + Returns: + torch.Tensor: Encoded tensor. Its shape is (batch, time, ...) + + """ + self.extend_pe(x) + x = x * self.xscale + self.pe[:, : x.size(1)] + return self.dropout(x) + + +class ScaledPositionalEncoding(PositionalEncoding): + """Scaled positional encoding module. + + See also: Sec. 3.2 https://arxiv.org/pdf/1809.08895.pdf + + """ + + def __init__(self, d_model, dropout_rate, max_len=5000): + """Initialize class. + + :param int d_model: embedding dim + :param float dropout_rate: dropout rate + :param int max_len: maximum input length + + """ + super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len) + self.alpha = torch.nn.Parameter(torch.tensor(1.0)) + + def reset_parameters(self): + """Reset parameters.""" + self.alpha.data = torch.tensor(1.0) + + def forward(self, x): + """Add positional encoding. + + Args: + x (torch.Tensor): Input. Its shape is (batch, time, ...) + + Returns: + torch.Tensor: Encoded tensor. Its shape is (batch, time, ...) + + """ + self.extend_pe(x) + x = x + self.alpha * self.pe[:, : x.size(1)] + return self.dropout(x) + + +class RelPositionalEncoding(PositionalEncoding): + """Relitive positional encoding module. + + See : Appendix B in https://arxiv.org/abs/1901.02860 + + :param int d_model: embedding dim + :param float dropout_rate: dropout rate + :param int max_len: maximum input length + + """ + + def __init__(self, d_model, dropout_rate, max_len=5000): + """Initialize class. + + :param int d_model: embedding dim + :param float dropout_rate: dropout rate + :param int max_len: maximum input length + + """ + super().__init__(d_model, dropout_rate, max_len, reverse=True) + + def forward(self, x): + """Compute positional encoding. + + Args: + x (torch.Tensor): Input. Its shape is (batch, time, ...) + + Returns: + torch.Tensor: x. Its shape is (batch, time, ...) + torch.Tensor: pos_emb. Its shape is (1, time, ...) 
+ + """ + self.extend_pe(x) + x = x * self.xscale + pos_emb = self.pe[:, : x.size(1)] + return self.dropout(x), self.dropout(pos_emb) diff --git a/ppg_extractor/encoder/encoder.py b/ppg_extractor/encoder/encoder.py new file mode 100644 index 0000000..6b92c01 --- /dev/null +++ b/ppg_extractor/encoder/encoder.py @@ -0,0 +1,217 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Encoder definition.""" + +import logging +import torch + +from espnet.nets.pytorch_backend.conformer.convolution import ConvolutionModule +from espnet.nets.pytorch_backend.conformer.encoder_layer import EncoderLayer +from espnet.nets.pytorch_backend.nets_utils import get_activation +from espnet.nets.pytorch_backend.transducer.vgg import VGG2L +from espnet.nets.pytorch_backend.transformer.attention import ( + MultiHeadedAttention, # noqa: H301 + RelPositionMultiHeadedAttention, # noqa: H301 +) +from espnet.nets.pytorch_backend.transformer.embedding import ( + PositionalEncoding, # noqa: H301 + ScaledPositionalEncoding, # noqa: H301 + RelPositionalEncoding, # noqa: H301 +) +from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm +from espnet.nets.pytorch_backend.transformer.multi_layer_conv import Conv1dLinear +from espnet.nets.pytorch_backend.transformer.multi_layer_conv import MultiLayeredConv1d +from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import ( + PositionwiseFeedForward, # noqa: H301 +) +from espnet.nets.pytorch_backend.transformer.repeat import repeat +from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling + + +class Encoder(torch.nn.Module): + """Conformer encoder module. + + :param int idim: input dim + :param int attention_dim: dimention of attention + :param int attention_heads: the number of heads of multi head attention + :param int linear_units: the number of units of position-wise feed forward + :param int num_blocks: the number of decoder blocks + :param float dropout_rate: dropout rate + :param float attention_dropout_rate: dropout rate in attention + :param float positional_dropout_rate: dropout rate after adding positional encoding + :param str or torch.nn.Module input_layer: input layer type + :param bool normalize_before: whether to use layer_norm before the first block + :param bool concat_after: whether to concat attention layer's input and output + if True, additional linear will be applied. + i.e. x -> x + linear(concat(x, att(x))) + if False, no additional linear will be applied. i.e. 
x -> x + att(x) + :param str positionwise_layer_type: linear of conv1d + :param int positionwise_conv_kernel_size: kernel size of positionwise conv1d layer + :param str encoder_pos_enc_layer_type: encoder positional encoding layer type + :param str encoder_attn_layer_type: encoder attention layer type + :param str activation_type: encoder activation function type + :param bool macaron_style: whether to use macaron style for positionwise layer + :param bool use_cnn_module: whether to use convolution module + :param int cnn_module_kernel: kernerl size of convolution module + :param int padding_idx: padding_idx for input_layer=embed + """ + + def __init__( + self, + idim, + attention_dim=256, + attention_heads=4, + linear_units=2048, + num_blocks=6, + dropout_rate=0.1, + positional_dropout_rate=0.1, + attention_dropout_rate=0.0, + input_layer="conv2d", + normalize_before=True, + concat_after=False, + positionwise_layer_type="linear", + positionwise_conv_kernel_size=1, + macaron_style=False, + pos_enc_layer_type="abs_pos", + selfattention_layer_type="selfattn", + activation_type="swish", + use_cnn_module=False, + cnn_module_kernel=31, + padding_idx=-1, + ): + """Construct an Encoder object.""" + super(Encoder, self).__init__() + + activation = get_activation(activation_type) + if pos_enc_layer_type == "abs_pos": + pos_enc_class = PositionalEncoding + elif pos_enc_layer_type == "scaled_abs_pos": + pos_enc_class = ScaledPositionalEncoding + elif pos_enc_layer_type == "rel_pos": + assert selfattention_layer_type == "rel_selfattn" + pos_enc_class = RelPositionalEncoding + else: + raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type) + + if input_layer == "linear": + self.embed = torch.nn.Sequential( + torch.nn.Linear(idim, attention_dim), + torch.nn.LayerNorm(attention_dim), + torch.nn.Dropout(dropout_rate), + pos_enc_class(attention_dim, positional_dropout_rate), + ) + elif input_layer == "conv2d": + self.embed = Conv2dSubsampling( + idim, + attention_dim, + dropout_rate, + pos_enc_class(attention_dim, positional_dropout_rate), + ) + elif input_layer == "vgg2l": + self.embed = VGG2L(idim, attention_dim) + elif input_layer == "embed": + self.embed = torch.nn.Sequential( + torch.nn.Embedding(idim, attention_dim, padding_idx=padding_idx), + pos_enc_class(attention_dim, positional_dropout_rate), + ) + elif isinstance(input_layer, torch.nn.Module): + self.embed = torch.nn.Sequential( + input_layer, + pos_enc_class(attention_dim, positional_dropout_rate), + ) + elif input_layer is None: + self.embed = torch.nn.Sequential( + pos_enc_class(attention_dim, positional_dropout_rate) + ) + else: + raise ValueError("unknown input_layer: " + input_layer) + self.normalize_before = normalize_before + if positionwise_layer_type == "linear": + positionwise_layer = PositionwiseFeedForward + positionwise_layer_args = ( + attention_dim, + linear_units, + dropout_rate, + activation, + ) + elif positionwise_layer_type == "conv1d": + positionwise_layer = MultiLayeredConv1d + positionwise_layer_args = ( + attention_dim, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + ) + elif positionwise_layer_type == "conv1d-linear": + positionwise_layer = Conv1dLinear + positionwise_layer_args = ( + attention_dim, + linear_units, + positionwise_conv_kernel_size, + dropout_rate, + ) + else: + raise NotImplementedError("Support only linear or conv1d.") + + if selfattention_layer_type == "selfattn": + logging.info("encoder self-attention layer type = self-attention") + encoder_selfattn_layer = 
MultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + attention_dim, + attention_dropout_rate, + ) + elif selfattention_layer_type == "rel_selfattn": + assert pos_enc_layer_type == "rel_pos" + encoder_selfattn_layer = RelPositionMultiHeadedAttention + encoder_selfattn_layer_args = ( + attention_heads, + attention_dim, + attention_dropout_rate, + ) + else: + raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type) + + convolution_layer = ConvolutionModule + convolution_layer_args = (attention_dim, cnn_module_kernel, activation) + + self.encoders = repeat( + num_blocks, + lambda lnum: EncoderLayer( + attention_dim, + encoder_selfattn_layer(*encoder_selfattn_layer_args), + positionwise_layer(*positionwise_layer_args), + positionwise_layer(*positionwise_layer_args) if macaron_style else None, + convolution_layer(*convolution_layer_args) if use_cnn_module else None, + dropout_rate, + normalize_before, + concat_after, + ), + ) + if self.normalize_before: + self.after_norm = LayerNorm(attention_dim) + + def forward(self, xs, masks): + """Encode input sequence. + + :param torch.Tensor xs: input tensor + :param torch.Tensor masks: input mask + :return: position embedded tensor and mask + :rtype Tuple[torch.Tensor, torch.Tensor]: + """ + if isinstance(self.embed, (Conv2dSubsampling, VGG2L)): + xs, masks = self.embed(xs, masks) + else: + xs = self.embed(xs) + + xs, masks = self.encoders(xs, masks) + if isinstance(xs, tuple): + xs = xs[0] + + if self.normalize_before: + xs = self.after_norm(xs) + return xs, masks diff --git a/ppg_extractor/encoder/encoder_layer.py b/ppg_extractor/encoder/encoder_layer.py new file mode 100644 index 0000000..750a32e --- /dev/null +++ b/ppg_extractor/encoder/encoder_layer.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2020 Johns Hopkins University (Shinji Watanabe) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Encoder self-attention layer definition.""" + +import torch + +from torch import nn + +from .layer_norm import LayerNorm + + +class EncoderLayer(nn.Module): + """Encoder layer module. + + :param int size: input dim + :param espnet.nets.pytorch_backend.transformer.attention. + MultiHeadedAttention self_attn: self attention module + RelPositionMultiHeadedAttention self_attn: self attention module + :param espnet.nets.pytorch_backend.transformer.positionwise_feed_forward. + PositionwiseFeedForward feed_forward: + feed forward module + :param espnet.nets.pytorch_backend.transformer.positionwise_feed_forward + for macaron style + PositionwiseFeedForward feed_forward: + feed forward module + :param espnet.nets.pytorch_backend.conformer.convolution. + ConvolutionModule feed_foreard: + feed forward module + :param float dropout_rate: dropout rate + :param bool normalize_before: whether to use layer_norm before the first block + :param bool concat_after: whether to concat attention layer's input and output + if True, additional linear will be applied. + i.e. x -> x + linear(concat(x, att(x))) + if False, no additional linear will be applied. i.e. 
x -> x + att(x) + + """ + + def __init__( + self, + size, + self_attn, + feed_forward, + feed_forward_macaron, + conv_module, + dropout_rate, + normalize_before=True, + concat_after=False, + ): + """Construct an EncoderLayer object.""" + super(EncoderLayer, self).__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.feed_forward_macaron = feed_forward_macaron + self.conv_module = conv_module + self.norm_ff = LayerNorm(size) # for the FNN module + self.norm_mha = LayerNorm(size) # for the MHA module + if feed_forward_macaron is not None: + self.norm_ff_macaron = LayerNorm(size) + self.ff_scale = 0.5 + else: + self.ff_scale = 1.0 + if self.conv_module is not None: + self.norm_conv = LayerNorm(size) # for the CNN module + self.norm_final = LayerNorm(size) # for the final output of the block + self.dropout = nn.Dropout(dropout_rate) + self.size = size + self.normalize_before = normalize_before + self.concat_after = concat_after + if self.concat_after: + self.concat_linear = nn.Linear(size + size, size) + + def forward(self, x_input, mask, cache=None): + """Compute encoded features. + + :param torch.Tensor x_input: encoded source features, w/o pos_emb + tuple((batch, max_time_in, size), (1, max_time_in, size)) + or (batch, max_time_in, size) + :param torch.Tensor mask: mask for x (batch, max_time_in) + :param torch.Tensor cache: cache for x (batch, max_time_in - 1, size) + :rtype: Tuple[torch.Tensor, torch.Tensor] + """ + if isinstance(x_input, tuple): + x, pos_emb = x_input[0], x_input[1] + else: + x, pos_emb = x_input, None + + # whether to use macaron style + if self.feed_forward_macaron is not None: + residual = x + if self.normalize_before: + x = self.norm_ff_macaron(x) + x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x)) + if not self.normalize_before: + x = self.norm_ff_macaron(x) + + # multi-headed self-attention module + residual = x + if self.normalize_before: + x = self.norm_mha(x) + + if cache is None: + x_q = x + else: + assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size) + x_q = x[:, -1:, :] + residual = residual[:, -1:, :] + mask = None if mask is None else mask[:, -1:, :] + + if pos_emb is not None: + x_att = self.self_attn(x_q, x, x, pos_emb, mask) + else: + x_att = self.self_attn(x_q, x, x, mask) + + if self.concat_after: + x_concat = torch.cat((x, x_att), dim=-1) + x = residual + self.concat_linear(x_concat) + else: + x = residual + self.dropout(x_att) + if not self.normalize_before: + x = self.norm_mha(x) + + # convolution module + if self.conv_module is not None: + residual = x + if self.normalize_before: + x = self.norm_conv(x) + x = residual + self.dropout(self.conv_module(x)) + if not self.normalize_before: + x = self.norm_conv(x) + + # feed forward module + residual = x + if self.normalize_before: + x = self.norm_ff(x) + x = residual + self.ff_scale * self.dropout(self.feed_forward(x)) + if not self.normalize_before: + x = self.norm_ff(x) + + if self.conv_module is not None: + x = self.norm_final(x) + + if cache is not None: + x = torch.cat([cache, x], dim=1) + + if pos_emb is not None: + return (x, pos_emb), mask + + return x, mask diff --git a/ppg_extractor/encoder/layer_norm.py b/ppg_extractor/encoder/layer_norm.py new file mode 100644 index 0000000..db8be30 --- /dev/null +++ b/ppg_extractor/encoder/layer_norm.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Layer normalization module.""" 
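A minimal sketch of how the Conformer Encoder defined above might be exercised on a dummy mel batch, assuming espnet is installed (it is listed in requirements.txt); the hyper-parameters, import path, and tensor shapes below are illustrative assumptions, not values taken from the shipped ppg_extractor config:

import torch
from ppg_extractor.encoder.encoder import Encoder

# Illustrative settings: 80-dim log-mel input, default conv2d front-end,
# absolute positional encoding, plain self-attention.
enc = Encoder(idim=80, attention_dim=256, attention_heads=4,
              linear_units=2048, num_blocks=6)
xs = torch.randn(2, 200, 80)                      # (batch, frames, mel bins)
masks = torch.ones(2, 1, 200, dtype=torch.bool)   # True marks valid frames
out, out_masks = enc(xs, masks)                   # conv2d front-end subsamples time ~4x
print(out.shape)                                  # torch.Size([2, 49, 256])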
+ +import torch + + +class LayerNorm(torch.nn.LayerNorm): + """Layer normalization module. + + :param int nout: output dim size + :param int dim: dimension to be normalized + """ + + def __init__(self, nout, dim=-1): + """Construct an LayerNorm object.""" + super(LayerNorm, self).__init__(nout, eps=1e-12) + self.dim = dim + + def forward(self, x): + """Apply layer normalization. + + :param torch.Tensor x: input tensor + :return: layer normalized tensor + :rtype torch.Tensor + """ + if self.dim == -1: + return super(LayerNorm, self).forward(x) + return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1) diff --git a/ppg_extractor/encoder/multi_layer_conv.py b/ppg_extractor/encoder/multi_layer_conv.py new file mode 100644 index 0000000..fdb7fe7 --- /dev/null +++ b/ppg_extractor/encoder/multi_layer_conv.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Tomoki Hayashi +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Layer modules for FFT block in FastSpeech (Feed-forward Transformer).""" + +import torch + + +class MultiLayeredConv1d(torch.nn.Module): + """Multi-layered conv1d for Transformer block. + + This is a module of multi-leyered conv1d designed + to replace positionwise feed-forward network + in Transforner block, which is introduced in + `FastSpeech: Fast, Robust and Controllable Text to Speech`_. + + .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`: + https://arxiv.org/pdf/1905.09263.pdf + + """ + + def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate): + """Initialize MultiLayeredConv1d module. + + Args: + in_chans (int): Number of input channels. + hidden_chans (int): Number of hidden channels. + kernel_size (int): Kernel size of conv1d. + dropout_rate (float): Dropout rate. + + """ + super(MultiLayeredConv1d, self).__init__() + self.w_1 = torch.nn.Conv1d( + in_chans, + hidden_chans, + kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + ) + self.w_2 = torch.nn.Conv1d( + hidden_chans, + in_chans, + kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + ) + self.dropout = torch.nn.Dropout(dropout_rate) + + def forward(self, x): + """Calculate forward propagation. + + Args: + x (Tensor): Batch of input tensors (B, ..., in_chans). + + Returns: + Tensor: Batch of output tensors (B, ..., hidden_chans). + + """ + x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1) + return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1) + + +class Conv1dLinear(torch.nn.Module): + """Conv1D + Linear for Transformer block. + + A variant of MultiLayeredConv1d, which replaces second conv-layer to linear. + + """ + + def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate): + """Initialize Conv1dLinear module. + + Args: + in_chans (int): Number of input channels. + hidden_chans (int): Number of hidden channels. + kernel_size (int): Kernel size of conv1d. + dropout_rate (float): Dropout rate. + + """ + super(Conv1dLinear, self).__init__() + self.w_1 = torch.nn.Conv1d( + in_chans, + hidden_chans, + kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + ) + self.w_2 = torch.nn.Linear(hidden_chans, in_chans) + self.dropout = torch.nn.Dropout(dropout_rate) + + def forward(self, x): + """Calculate forward propagation. + + Args: + x (Tensor): Batch of input tensors (B, ..., in_chans). + + Returns: + Tensor: Batch of output tensors (B, ..., hidden_chans). 
+ + """ + x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1) + return self.w_2(self.dropout(x)) diff --git a/ppg_extractor/encoder/positionwise_feed_forward.py b/ppg_extractor/encoder/positionwise_feed_forward.py new file mode 100644 index 0000000..7a9237a --- /dev/null +++ b/ppg_extractor/encoder/positionwise_feed_forward.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Positionwise feed forward layer definition.""" + +import torch + + +class PositionwiseFeedForward(torch.nn.Module): + """Positionwise feed forward layer. + + :param int idim: input dimenstion + :param int hidden_units: number of hidden units + :param float dropout_rate: dropout rate + + """ + + def __init__(self, idim, hidden_units, dropout_rate, activation=torch.nn.ReLU()): + """Construct an PositionwiseFeedForward object.""" + super(PositionwiseFeedForward, self).__init__() + self.w_1 = torch.nn.Linear(idim, hidden_units) + self.w_2 = torch.nn.Linear(hidden_units, idim) + self.dropout = torch.nn.Dropout(dropout_rate) + self.activation = activation + + def forward(self, x): + """Forward funciton.""" + return self.w_2(self.dropout(self.activation(self.w_1(x)))) diff --git a/ppg_extractor/encoder/repeat.py b/ppg_extractor/encoder/repeat.py new file mode 100644 index 0000000..7a8af6c --- /dev/null +++ b/ppg_extractor/encoder/repeat.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Repeat the same layer definition.""" + +import torch + + +class MultiSequential(torch.nn.Sequential): + """Multi-input multi-output torch.nn.Sequential.""" + + def forward(self, *args): + """Repeat.""" + for m in self: + args = m(*args) + return args + + +def repeat(N, fn): + """Repeat module N times. + + :param int N: repeat time + :param function fn: function to generate module + :return: repeated modules + :rtype: MultiSequential + """ + return MultiSequential(*[fn(n) for n in range(N)]) diff --git a/ppg_extractor/encoder/subsampling.py b/ppg_extractor/encoder/subsampling.py new file mode 100644 index 0000000..e754126 --- /dev/null +++ b/ppg_extractor/encoder/subsampling.py @@ -0,0 +1,218 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2019 Shigeki Karita +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Subsampling layer definition.""" +import logging +import torch + +from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding + + +class Conv2dSubsampling(torch.nn.Module): + """Convolutional 2D subsampling (to 1/4 length or 1/2 length). 
+ + :param int idim: input dim + :param int odim: output dim + :param flaot dropout_rate: dropout rate + :param torch.nn.Module pos_enc: custom position encoding layer + + """ + + def __init__(self, idim, odim, dropout_rate, pos_enc=None, + subsample_by_2=False, + ): + """Construct an Conv2dSubsampling object.""" + super(Conv2dSubsampling, self).__init__() + self.subsample_by_2 = subsample_by_2 + if subsample_by_2: + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, odim, kernel_size=5, stride=1, padding=2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, kernel_size=4, stride=2, padding=1), + torch.nn.ReLU(), + ) + self.out = torch.nn.Sequential( + torch.nn.Linear(odim * (idim // 2), odim), + pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate), + ) + else: + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, odim, kernel_size=4, stride=2, padding=1), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, kernel_size=4, stride=2, padding=1), + torch.nn.ReLU(), + ) + self.out = torch.nn.Sequential( + torch.nn.Linear(odim * (idim // 4), odim), + pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate), + ) + + def forward(self, x, x_mask): + """Subsample x. + + :param torch.Tensor x: input tensor + :param torch.Tensor x_mask: input mask + :return: subsampled x and mask + :rtype Tuple[torch.Tensor, torch.Tensor] + + """ + x = x.unsqueeze(1) # (b, c, t, f) + x = self.conv(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + if x_mask is None: + return x, None + if self.subsample_by_2: + return x, x_mask[:, :, ::2] + else: + return x, x_mask[:, :, ::2][:, :, ::2] + + def __getitem__(self, key): + """Subsample x. + + When reset_parameters() is called, if use_scaled_pos_enc is used, + return the positioning encoding. + + """ + if key != -1: + raise NotImplementedError("Support only `-1` (for `reset_parameters`).") + return self.out[key] + + +class Conv2dNoSubsampling(torch.nn.Module): + """Convolutional 2D without subsampling. + + :param int idim: input dim + :param int odim: output dim + :param flaot dropout_rate: dropout rate + :param torch.nn.Module pos_enc: custom position encoding layer + + """ + + def __init__(self, idim, odim, dropout_rate, pos_enc=None): + """Construct an Conv2dSubsampling object.""" + super().__init__() + logging.info("Encoder does not do down-sample on mel-spectrogram.") + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, odim, kernel_size=5, stride=1, padding=2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, kernel_size=5, stride=1, padding=2), + torch.nn.ReLU(), + ) + self.out = torch.nn.Sequential( + torch.nn.Linear(odim * idim, odim), + pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate), + ) + + def forward(self, x, x_mask): + """Subsample x. + + :param torch.Tensor x: input tensor + :param torch.Tensor x_mask: input mask + :return: subsampled x and mask + :rtype Tuple[torch.Tensor, torch.Tensor] + + """ + x = x.unsqueeze(1) # (b, c, t, f) + x = self.conv(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + if x_mask is None: + return x, None + return x, x_mask + + def __getitem__(self, key): + """Subsample x. + + When reset_parameters() is called, if use_scaled_pos_enc is used, + return the positioning encoding. 
+ + """ + if key != -1: + raise NotImplementedError("Support only `-1` (for `reset_parameters`).") + return self.out[key] + + +class Conv2dSubsampling6(torch.nn.Module): + """Convolutional 2D subsampling (to 1/6 length). + + :param int idim: input dim + :param int odim: output dim + :param flaot dropout_rate: dropout rate + + """ + + def __init__(self, idim, odim, dropout_rate): + """Construct an Conv2dSubsampling object.""" + super(Conv2dSubsampling6, self).__init__() + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, odim, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 5, 3), + torch.nn.ReLU(), + ) + self.out = torch.nn.Sequential( + torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), odim), + PositionalEncoding(odim, dropout_rate), + ) + + def forward(self, x, x_mask): + """Subsample x. + + :param torch.Tensor x: input tensor + :param torch.Tensor x_mask: input mask + :return: subsampled x and mask + :rtype Tuple[torch.Tensor, torch.Tensor] + """ + x = x.unsqueeze(1) # (b, c, t, f) + x = self.conv(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + if x_mask is None: + return x, None + return x, x_mask[:, :, :-2:2][:, :, :-4:3] + + +class Conv2dSubsampling8(torch.nn.Module): + """Convolutional 2D subsampling (to 1/8 length). + + :param int idim: input dim + :param int odim: output dim + :param flaot dropout_rate: dropout rate + + """ + + def __init__(self, idim, odim, dropout_rate): + """Construct an Conv2dSubsampling object.""" + super(Conv2dSubsampling8, self).__init__() + self.conv = torch.nn.Sequential( + torch.nn.Conv2d(1, odim, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 3, 2), + torch.nn.ReLU(), + torch.nn.Conv2d(odim, odim, 3, 2), + torch.nn.ReLU(), + ) + self.out = torch.nn.Sequential( + torch.nn.Linear(odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim), + PositionalEncoding(odim, dropout_rate), + ) + + def forward(self, x, x_mask): + """Subsample x. + + :param torch.Tensor x: input tensor + :param torch.Tensor x_mask: input mask + :return: subsampled x and mask + :rtype Tuple[torch.Tensor, torch.Tensor] + """ + x = x.unsqueeze(1) # (b, c, t, f) + x = self.conv(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + if x_mask is None: + return x, None + return x, x_mask[:, :, :-2:2][:, :, :-2:2][:, :, :-2:2] diff --git a/ppg_extractor/encoder/swish.py b/ppg_extractor/encoder/swish.py new file mode 100644 index 0000000..c53a7a9 --- /dev/null +++ b/ppg_extractor/encoder/swish.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +# Copyright 2020 Johns Hopkins University (Shinji Watanabe) +# Northwestern Polytechnical University (Pengcheng Guo) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +"""Swish() activation function for Conformer.""" + +import torch + + +class Swish(torch.nn.Module): + """Construct an Swish object.""" + + def forward(self, x): + """Return Swich activation function.""" + return x * torch.sigmoid(x) diff --git a/ppg_extractor/encoder/vgg.py b/ppg_extractor/encoder/vgg.py new file mode 100644 index 0000000..5ca1c65 --- /dev/null +++ b/ppg_extractor/encoder/vgg.py @@ -0,0 +1,77 @@ +"""VGG2L definition for transformer-transducer.""" + +import torch + + +class VGG2L(torch.nn.Module): + """VGG2L module for transformer-transducer encoder.""" + + def __init__(self, idim, odim): + """Construct a VGG2L object. 
+ + Args: + idim (int): dimension of inputs + odim (int): dimension of outputs + + """ + super(VGG2L, self).__init__() + + self.vgg2l = torch.nn.Sequential( + torch.nn.Conv2d(1, 64, 3, stride=1, padding=1), + torch.nn.ReLU(), + torch.nn.Conv2d(64, 64, 3, stride=1, padding=1), + torch.nn.ReLU(), + torch.nn.MaxPool2d((3, 2)), + torch.nn.Conv2d(64, 128, 3, stride=1, padding=1), + torch.nn.ReLU(), + torch.nn.Conv2d(128, 128, 3, stride=1, padding=1), + torch.nn.ReLU(), + torch.nn.MaxPool2d((2, 2)), + ) + + self.output = torch.nn.Linear(128 * ((idim // 2) // 2), odim) + + def forward(self, x, x_mask): + """VGG2L forward for x. + + Args: + x (torch.Tensor): input torch (B, T, idim) + x_mask (torch.Tensor): (B, 1, T) + + Returns: + x (torch.Tensor): input torch (B, sub(T), attention_dim) + x_mask (torch.Tensor): (B, 1, sub(T)) + + """ + x = x.unsqueeze(1) + x = self.vgg2l(x) + + b, c, t, f = x.size() + + x = self.output(x.transpose(1, 2).contiguous().view(b, t, c * f)) + + if x_mask is None: + return x, None + else: + x_mask = self.create_new_mask(x_mask, x) + + return x, x_mask + + def create_new_mask(self, x_mask, x): + """Create a subsampled version of x_mask. + + Args: + x_mask (torch.Tensor): (B, 1, T) + x (torch.Tensor): (B, sub(T), attention_dim) + + Returns: + x_mask (torch.Tensor): (B, 1, sub(T)) + + """ + x_t1 = x_mask.size(2) - (x_mask.size(2) % 3) + x_mask = x_mask[:, :, :x_t1][:, :, ::3] + + x_t2 = x_mask.size(2) - (x_mask.size(2) % 2) + x_mask = x_mask[:, :, :x_t2][:, :, ::2] + + return x_mask diff --git a/ppg_extractor/encoders.py b/ppg_extractor/encoders.py new file mode 100644 index 0000000..526140f --- /dev/null +++ b/ppg_extractor/encoders.py @@ -0,0 +1,298 @@ +import logging +import six + +import numpy as np +import torch +import torch.nn.functional as F +from torch.nn.utils.rnn import pack_padded_sequence +from torch.nn.utils.rnn import pad_packed_sequence + +from .e2e_asr_common import get_vgg2l_odim +from .nets_utils import make_pad_mask, to_device + + +class RNNP(torch.nn.Module): + """RNN with projection layer module + + :param int idim: dimension of inputs + :param int elayers: number of encoder layers + :param int cdim: number of rnn units (resulted in cdim * 2 if bidirectional) + :param int hdim: number of projection units + :param np.ndarray subsample: list of subsampling numbers + :param float dropout: dropout rate + :param str typ: The RNN type + """ + + def __init__(self, idim, elayers, cdim, hdim, subsample, dropout, typ="blstm"): + super(RNNP, self).__init__() + bidir = typ[0] == "b" + for i in six.moves.range(elayers): + if i == 0: + inputdim = idim + else: + inputdim = hdim + rnn = torch.nn.LSTM(inputdim, cdim, dropout=dropout, num_layers=1, bidirectional=bidir, + batch_first=True) if "lstm" in typ \ + else torch.nn.GRU(inputdim, cdim, dropout=dropout, num_layers=1, bidirectional=bidir, batch_first=True) + setattr(self, "%s%d" % ("birnn" if bidir else "rnn", i), rnn) + # bottleneck layer to merge + if bidir: + setattr(self, "bt%d" % i, torch.nn.Linear(2 * cdim, hdim)) + else: + setattr(self, "bt%d" % i, torch.nn.Linear(cdim, hdim)) + + self.elayers = elayers + self.cdim = cdim + self.subsample = subsample + self.typ = typ + self.bidir = bidir + + def forward(self, xs_pad, ilens, prev_state=None): + """RNNP forward + + :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim) + :param torch.Tensor ilens: batch of lengths of input sequences (B) + :param torch.Tensor prev_state: batch of previous RNN states + :return: batch of hidden state 
sequences (B, Tmax, hdim) + :rtype: torch.Tensor + """ + logging.debug(self.__class__.__name__ + ' input lengths: ' + str(ilens)) + elayer_states = [] + for layer in six.moves.range(self.elayers): + xs_pack = pack_padded_sequence(xs_pad, ilens, batch_first=True, enforce_sorted=False) + rnn = getattr(self, ("birnn" if self.bidir else "rnn") + str(layer)) + rnn.flatten_parameters() + if prev_state is not None and rnn.bidirectional: + prev_state = reset_backward_rnn_state(prev_state) + ys, states = rnn(xs_pack, hx=None if prev_state is None else prev_state[layer]) + elayer_states.append(states) + # ys: utt list of frame x cdim x 2 (2: means bidirectional) + ys_pad, ilens = pad_packed_sequence(ys, batch_first=True) + sub = self.subsample[layer + 1] + if sub > 1: + ys_pad = ys_pad[:, ::sub] + ilens = [int(i + 1) // sub for i in ilens] + # (sum _utt frame_utt) x dim + projected = getattr(self, 'bt' + str(layer) + )(ys_pad.contiguous().view(-1, ys_pad.size(2))) + if layer == self.elayers - 1: + xs_pad = projected.view(ys_pad.size(0), ys_pad.size(1), -1) + else: + xs_pad = torch.tanh(projected.view(ys_pad.size(0), ys_pad.size(1), -1)) + + return xs_pad, ilens, elayer_states # x: utt list of frame x dim + + +class RNN(torch.nn.Module): + """RNN module + + :param int idim: dimension of inputs + :param int elayers: number of encoder layers + :param int cdim: number of rnn units (resulted in cdim * 2 if bidirectional) + :param int hdim: number of final projection units + :param float dropout: dropout rate + :param str typ: The RNN type + """ + + def __init__(self, idim, elayers, cdim, hdim, dropout, typ="blstm"): + super(RNN, self).__init__() + bidir = typ[0] == "b" + self.nbrnn = torch.nn.LSTM(idim, cdim, elayers, batch_first=True, + dropout=dropout, bidirectional=bidir) if "lstm" in typ \ + else torch.nn.GRU(idim, cdim, elayers, batch_first=True, dropout=dropout, + bidirectional=bidir) + if bidir: + self.l_last = torch.nn.Linear(cdim * 2, hdim) + else: + self.l_last = torch.nn.Linear(cdim, hdim) + self.typ = typ + + def forward(self, xs_pad, ilens, prev_state=None): + """RNN forward + + :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D) + :param torch.Tensor ilens: batch of lengths of input sequences (B) + :param torch.Tensor prev_state: batch of previous RNN states + :return: batch of hidden state sequences (B, Tmax, eprojs) + :rtype: torch.Tensor + """ + logging.debug(self.__class__.__name__ + ' input lengths: ' + str(ilens)) + xs_pack = pack_padded_sequence(xs_pad, ilens, batch_first=True) + self.nbrnn.flatten_parameters() + if prev_state is not None and self.nbrnn.bidirectional: + # We assume that when previous state is passed, it means that we're streaming the input + # and therefore cannot propagate backward BRNN state (otherwise it goes in the wrong direction) + prev_state = reset_backward_rnn_state(prev_state) + ys, states = self.nbrnn(xs_pack, hx=prev_state) + # ys: utt list of frame x cdim x 2 (2: means bidirectional) + ys_pad, ilens = pad_packed_sequence(ys, batch_first=True) + # (sum _utt frame_utt) x dim + projected = torch.tanh(self.l_last( + ys_pad.contiguous().view(-1, ys_pad.size(2)))) + xs_pad = projected.view(ys_pad.size(0), ys_pad.size(1), -1) + return xs_pad, ilens, states # x: utt list of frame x dim + + +def reset_backward_rnn_state(states): + """Sets backward BRNN states to zeroes - useful in processing of sliding windows over the inputs""" + if isinstance(states, (list, tuple)): + for state in states: + state[1::2] = 0. + else: + states[1::2] = 0. 
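# NOTE: PyTorch stacks bidirectional RNN hidden/cell states as
# [fwd_layer0, bwd_layer0, fwd_layer1, bwd_layer1, ...] along dim 0, so the
# slice [1::2] above selects every backward-direction state; zeroing them keeps
# stale right-to-left context from leaking into the next chunk when streaming.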
+ return states + + +class VGG2L(torch.nn.Module): + """VGG-like module + + :param int in_channel: number of input channels + """ + + def __init__(self, in_channel=1, downsample=True): + super(VGG2L, self).__init__() + # CNN layer (VGG motivated) + self.conv1_1 = torch.nn.Conv2d(in_channel, 64, 3, stride=1, padding=1) + self.conv1_2 = torch.nn.Conv2d(64, 64, 3, stride=1, padding=1) + self.conv2_1 = torch.nn.Conv2d(64, 128, 3, stride=1, padding=1) + self.conv2_2 = torch.nn.Conv2d(128, 128, 3, stride=1, padding=1) + + self.in_channel = in_channel + self.downsample = downsample + if downsample: + self.stride = 2 + else: + self.stride = 1 + + def forward(self, xs_pad, ilens, **kwargs): + """VGG2L forward + + :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D) + :param torch.Tensor ilens: batch of lengths of input sequences (B) + :return: batch of padded hidden state sequences (B, Tmax // 4, 128 * D // 4) if downsample + :rtype: torch.Tensor + """ + logging.debug(self.__class__.__name__ + ' input lengths: ' + str(ilens)) + + # x: utt x frame x dim + # xs_pad = F.pad_sequence(xs_pad) + + # x: utt x 1 (input channel num) x frame x dim + xs_pad = xs_pad.view(xs_pad.size(0), xs_pad.size(1), self.in_channel, + xs_pad.size(2) // self.in_channel).transpose(1, 2) + + # NOTE: max_pool1d ? + xs_pad = F.relu(self.conv1_1(xs_pad)) + xs_pad = F.relu(self.conv1_2(xs_pad)) + if self.downsample: + xs_pad = F.max_pool2d(xs_pad, 2, stride=self.stride, ceil_mode=True) + + xs_pad = F.relu(self.conv2_1(xs_pad)) + xs_pad = F.relu(self.conv2_2(xs_pad)) + if self.downsample: + xs_pad = F.max_pool2d(xs_pad, 2, stride=self.stride, ceil_mode=True) + if torch.is_tensor(ilens): + ilens = ilens.cpu().numpy() + else: + ilens = np.array(ilens, dtype=np.float32) + if self.downsample: + ilens = np.array(np.ceil(ilens / 2), dtype=np.int64) + ilens = np.array( + np.ceil(np.array(ilens, dtype=np.float32) / 2), dtype=np.int64).tolist() + + # x: utt_list of frame (remove zeropaded frames) x (input channel num x dim) + xs_pad = xs_pad.transpose(1, 2) + xs_pad = xs_pad.contiguous().view( + xs_pad.size(0), xs_pad.size(1), xs_pad.size(2) * xs_pad.size(3)) + return xs_pad, ilens, None # no state in this layer + + +class Encoder(torch.nn.Module): + """Encoder module + + :param str etype: type of encoder network + :param int idim: number of dimensions of encoder network + :param int elayers: number of layers of encoder network + :param int eunits: number of lstm units of encoder network + :param int eprojs: number of projection units of encoder network + :param np.ndarray subsample: list of subsampling numbers + :param float dropout: dropout rate + :param int in_channel: number of input channels + """ + + def __init__(self, etype, idim, elayers, eunits, eprojs, subsample, dropout, in_channel=1): + super(Encoder, self).__init__() + typ = etype.lstrip("vgg").rstrip("p") + if typ not in ['lstm', 'gru', 'blstm', 'bgru']: + logging.error("Error: need to specify an appropriate encoder architecture") + + if etype.startswith("vgg"): + if etype[-1] == "p": + self.enc = torch.nn.ModuleList([VGG2L(in_channel), + RNNP(get_vgg2l_odim(idim, in_channel=in_channel), elayers, eunits, + eprojs, + subsample, dropout, typ=typ)]) + logging.info('Use CNN-VGG + ' + typ.upper() + 'P for encoder') + else: + self.enc = torch.nn.ModuleList([VGG2L(in_channel), + RNN(get_vgg2l_odim(idim, in_channel=in_channel), elayers, eunits, + eprojs, + dropout, typ=typ)]) + logging.info('Use CNN-VGG + ' + typ.upper() + ' for encoder') + else: + if etype[-1] == 
"p": + self.enc = torch.nn.ModuleList( + [RNNP(idim, elayers, eunits, eprojs, subsample, dropout, typ=typ)]) + logging.info(typ.upper() + ' with every-layer projection for encoder') + else: + self.enc = torch.nn.ModuleList([RNN(idim, elayers, eunits, eprojs, dropout, typ=typ)]) + logging.info(typ.upper() + ' without projection for encoder') + + def forward(self, xs_pad, ilens, prev_states=None): + """Encoder forward + + :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D) + :param torch.Tensor ilens: batch of lengths of input sequences (B) + :param torch.Tensor prev_state: batch of previous encoder hidden states (?, ...) + :return: batch of hidden state sequences (B, Tmax, eprojs) + :rtype: torch.Tensor + """ + if prev_states is None: + prev_states = [None] * len(self.enc) + assert len(prev_states) == len(self.enc) + + current_states = [] + for module, prev_state in zip(self.enc, prev_states): + xs_pad, ilens, states = module(xs_pad, ilens, prev_state=prev_state) + current_states.append(states) + + # make mask to remove bias value in padded part + mask = to_device(self, make_pad_mask(ilens).unsqueeze(-1)) + + return xs_pad.masked_fill(mask, 0.0), ilens, current_states + + +def encoder_for(args, idim, subsample): + """Instantiates an encoder module given the program arguments + + :param Namespace args: The arguments + :param int or List of integer idim: dimension of input, e.g. 83, or + List of dimensions of inputs, e.g. [83,83] + :param List or List of List subsample: subsample factors, e.g. [1,2,2,1,1], or + List of subsample factors of each encoder. e.g. [[1,2,2,1,1], [1,2,2,1,1]] + :rtype torch.nn.Module + :return: The encoder module + """ + num_encs = getattr(args, "num_encs", 1) # use getattr to keep compatibility + if num_encs == 1: + # compatible with single encoder asr mode + return Encoder(args.etype, idim, args.elayers, args.eunits, args.eprojs, subsample, args.dropout_rate) + elif num_encs >= 1: + enc_list = torch.nn.ModuleList() + for idx in range(num_encs): + enc = Encoder(args.etype[idx], idim[idx], args.elayers[idx], args.eunits[idx], args.eprojs, subsample[idx], + args.dropout_rate[idx]) + enc_list.append(enc) + return enc_list + else: + raise ValueError("Number of encoders needs to be more than one. 
{}".format(num_encs)) diff --git a/ppg_extractor/frontend.py b/ppg_extractor/frontend.py new file mode 100644 index 0000000..32549ed --- /dev/null +++ b/ppg_extractor/frontend.py @@ -0,0 +1,115 @@ +import copy +from typing import Tuple +import numpy as np +import torch +from torch_complex.tensor import ComplexTensor + +from .log_mel import LogMel +from .stft import Stft + + +class DefaultFrontend(torch.nn.Module): + """Conventional frontend structure for ASR + + Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN + """ + + def __init__( + self, + fs: 16000, + n_fft: int = 1024, + win_length: int = 800, + hop_length: int = 160, + center: bool = True, + pad_mode: str = "reflect", + normalized: bool = False, + onesided: bool = True, + n_mels: int = 80, + fmin: int = None, + fmax: int = None, + htk: bool = False, + norm=1, + frontend_conf=None, #Optional[dict] = get_default_kwargs(Frontend), + kaldi_padding_mode=False, + downsample_rate: int = 1, + ): + super().__init__() + self.downsample_rate = downsample_rate + + # Deepcopy (In general, dict shouldn't be used as default arg) + frontend_conf = copy.deepcopy(frontend_conf) + + self.stft = Stft( + n_fft=n_fft, + win_length=win_length, + hop_length=hop_length, + center=center, + pad_mode=pad_mode, + normalized=normalized, + onesided=onesided, + kaldi_padding_mode=kaldi_padding_mode + ) + if frontend_conf is not None: + self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf) + else: + self.frontend = None + + self.logmel = LogMel( + fs=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax, htk=htk, norm=norm, + ) + self.n_mels = n_mels + + def output_size(self) -> int: + return self.n_mels + + def forward( + self, input: torch.Tensor, input_lengths: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + # 1. Domain-conversion: e.g. Stft: time -> time-freq + input_stft, feats_lens = self.stft(input, input_lengths) + + assert input_stft.dim() >= 4, input_stft.shape + # "2" refers to the real/imag parts of Complex + assert input_stft.shape[-1] == 2, input_stft.shape + + # Change torch.Tensor to ComplexTensor + # input_stft: (..., F, 2) -> (..., F) + input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1]) + + # 2. [Option] Speech enhancement + if self.frontend is not None: + assert isinstance(input_stft, ComplexTensor), type(input_stft) + # input_stft: (Batch, Length, [Channel], Freq) + input_stft, _, mask = self.frontend(input_stft, feats_lens) + + # 3. [Multi channel case]: Select a channel + if input_stft.dim() == 4: + # h: (B, T, C, F) -> h: (B, T, F) + if self.training: + # Select 1ch randomly + ch = np.random.randint(input_stft.size(2)) + input_stft = input_stft[:, :, ch, :] + else: + # Use the first channel + input_stft = input_stft[:, :, 0, :] + + # 4. STFT -> Power spectrum + # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F) + input_power = input_stft.real ** 2 + input_stft.imag ** 2 + + # 5. Feature transform e.g. 
Stft -> Log-Mel-Fbank + # input_power: (Batch, [Channel,] Length, Freq) + # -> input_feats: (Batch, Length, Dim) + input_feats, _ = self.logmel(input_power, feats_lens) + + # NOTE(sx): pad + max_len = input_feats.size(1) + if self.downsample_rate > 1 and max_len % self.downsample_rate != 0: + padding = self.downsample_rate - max_len % self.downsample_rate + # print("Logmel: ", input_feats.size()) + input_feats = torch.nn.functional.pad(input_feats, (0, 0, 0, padding), + "constant", 0) + # print("Logmel(after padding): ",input_feats.size()) + feats_lens[torch.argmax(feats_lens)] = max_len + padding + + return input_feats, feats_lens diff --git a/ppg_extractor/log_mel.py b/ppg_extractor/log_mel.py new file mode 100644 index 0000000..1e3b87d --- /dev/null +++ b/ppg_extractor/log_mel.py @@ -0,0 +1,74 @@ +import librosa +import numpy as np +import torch +from typing import Tuple + +from .nets_utils import make_pad_mask + + +class LogMel(torch.nn.Module): + """Convert STFT to fbank feats + + The arguments is same as librosa.filters.mel + + Args: + fs: number > 0 [scalar] sampling rate of the incoming signal + n_fft: int > 0 [scalar] number of FFT components + n_mels: int > 0 [scalar] number of Mel bands to generate + fmin: float >= 0 [scalar] lowest frequency (in Hz) + fmax: float >= 0 [scalar] highest frequency (in Hz). + If `None`, use `fmax = fs / 2.0` + htk: use HTK formula instead of Slaney + norm: {None, 1, np.inf} [scalar] + if 1, divide the triangular mel weights by the width of the mel band + (area normalization). Otherwise, leave all the triangles aiming for + a peak value of 1.0 + + """ + + def __init__( + self, + fs: int = 16000, + n_fft: int = 512, + n_mels: int = 80, + fmin: float = None, + fmax: float = None, + htk: bool = False, + norm=1, + ): + super().__init__() + + fmin = 0 if fmin is None else fmin + fmax = fs / 2 if fmax is None else fmax + _mel_options = dict( + sr=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax, htk=htk, norm=norm + ) + self.mel_options = _mel_options + + # Note(kamo): The mel matrix of librosa is different from kaldi. + melmat = librosa.filters.mel(**_mel_options) + # melmat: (D2, D1) -> (D1, D2) + self.register_buffer("melmat", torch.from_numpy(melmat.T).float()) + inv_mel = np.linalg.pinv(melmat) + self.register_buffer("inv_melmat", torch.from_numpy(inv_mel.T).float()) + + def extra_repr(self): + return ", ".join(f"{k}={v}" for k, v in self.mel_options.items()) + + def forward( + self, feat: torch.Tensor, ilens: torch.Tensor = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2) + mel_feat = torch.matmul(feat, self.melmat) + + logmel_feat = (mel_feat + 1e-20).log() + # Zero padding + if ilens is not None: + logmel_feat = logmel_feat.masked_fill( + make_pad_mask(ilens, logmel_feat, 1), 0.0 + ) + else: + ilens = feat.new_full( + [feat.size(0)], fill_value=feat.size(1), dtype=torch.long + ) + return logmel_feat, ilens diff --git a/ppg_extractor/nets_utils.py b/ppg_extractor/nets_utils.py new file mode 100644 index 0000000..6db064b --- /dev/null +++ b/ppg_extractor/nets_utils.py @@ -0,0 +1,465 @@ +# -*- coding: utf-8 -*- + +"""Network related utility tools.""" + +import logging +from typing import Dict + +import numpy as np +import torch + + +def to_device(m, x): + """Send tensor into the device of the module. + + Args: + m (torch.nn.Module): Torch module. + x (Tensor): Torch tensor. + + Returns: + Tensor: Torch tensor located in the same place as torch module. 
+ + """ + assert isinstance(m, torch.nn.Module) + device = next(m.parameters()).device + return x.to(device) + + +def pad_list(xs, pad_value): + """Perform padding for the list of tensors. + + Args: + xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)]. + pad_value (float): Value for padding. + + Returns: + Tensor: Padded tensor (B, Tmax, `*`). + + Examples: + >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)] + >>> x + [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])] + >>> pad_list(x, 0) + tensor([[1., 1., 1., 1.], + [1., 1., 0., 0.], + [1., 0., 0., 0.]]) + + """ + n_batch = len(xs) + max_len = max(x.size(0) for x in xs) + pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value) + + for i in range(n_batch): + pad[i, :xs[i].size(0)] = xs[i] + + return pad + + +def make_pad_mask(lengths, xs=None, length_dim=-1): + """Make mask tensor containing indices of padded part. + + Args: + lengths (LongTensor or List): Batch of lengths (B,). + xs (Tensor, optional): The reference tensor. If set, masks will be the same shape as this tensor. + length_dim (int, optional): Dimension indicator of the above tensor. See the example. + + Returns: + Tensor: Mask tensor containing indices of padded part. + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (including 1.2) + + Examples: + With only lengths. + + >>> lengths = [5, 3, 2] + >>> make_non_pad_mask(lengths) + masks = [[0, 0, 0, 0 ,0], + [0, 0, 0, 1, 1], + [0, 0, 1, 1, 1]] + + With the reference tensor. + + >>> xs = torch.zeros((3, 2, 4)) + >>> make_pad_mask(lengths, xs) + tensor([[[0, 0, 0, 0], + [0, 0, 0, 0]], + [[0, 0, 0, 1], + [0, 0, 0, 1]], + [[0, 0, 1, 1], + [0, 0, 1, 1]]], dtype=torch.uint8) + >>> xs = torch.zeros((3, 2, 6)) + >>> make_pad_mask(lengths, xs) + tensor([[[0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1]], + [[0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1]], + [[0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8) + + With the reference tensor and dimension indicator. 
+ + >>> xs = torch.zeros((3, 6, 6)) + >>> make_pad_mask(lengths, xs, 1) + tensor([[[0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1]], + [[0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1]], + [[0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8) + >>> make_pad_mask(lengths, xs, 2) + tensor([[[0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1]], + [[0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1], + [0, 0, 0, 1, 1, 1]], + [[0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8) + + """ + if length_dim == 0: + raise ValueError('length_dim cannot be 0: {}'.format(length_dim)) + + if not isinstance(lengths, list): + lengths = lengths.tolist() + bs = int(len(lengths)) + if xs is None: + maxlen = int(max(lengths)) + else: + maxlen = xs.size(length_dim) + + seq_range = torch.arange(0, maxlen, dtype=torch.int64) + seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen) + seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1) + mask = seq_range_expand >= seq_length_expand + + if xs is not None: + assert xs.size(0) == bs, (xs.size(0), bs) + + if length_dim < 0: + length_dim = xs.dim() + length_dim + # ind = (:, None, ..., None, :, , None, ..., None) + ind = tuple(slice(None) if i in (0, length_dim) else None + for i in range(xs.dim())) + mask = mask[ind].expand_as(xs).to(xs.device) + return mask + + +def make_non_pad_mask(lengths, xs=None, length_dim=-1): + """Make mask tensor containing indices of non-padded part. + + Args: + lengths (LongTensor or List): Batch of lengths (B,). + xs (Tensor, optional): The reference tensor. If set, masks will be the same shape as this tensor. + length_dim (int, optional): Dimension indicator of the above tensor. See the example. + + Returns: + ByteTensor: mask tensor containing indices of padded part. + dtype=torch.uint8 in PyTorch 1.2- + dtype=torch.bool in PyTorch 1.2+ (including 1.2) + + Examples: + With only lengths. + + >>> lengths = [5, 3, 2] + >>> make_non_pad_mask(lengths) + masks = [[1, 1, 1, 1 ,1], + [1, 1, 1, 0, 0], + [1, 1, 0, 0, 0]] + + With the reference tensor. + + >>> xs = torch.zeros((3, 2, 4)) + >>> make_non_pad_mask(lengths, xs) + tensor([[[1, 1, 1, 1], + [1, 1, 1, 1]], + [[1, 1, 1, 0], + [1, 1, 1, 0]], + [[1, 1, 0, 0], + [1, 1, 0, 0]]], dtype=torch.uint8) + >>> xs = torch.zeros((3, 2, 6)) + >>> make_non_pad_mask(lengths, xs) + tensor([[[1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0]], + [[1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0]], + [[1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8) + + With the reference tensor and dimension indicator. 
+ + >>> xs = torch.zeros((3, 6, 6)) + >>> make_non_pad_mask(lengths, xs, 1) + tensor([[[1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0]], + [[1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0]], + [[1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8) + >>> make_non_pad_mask(lengths, xs, 2) + tensor([[[1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0]], + [[1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0]], + [[1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8) + + """ + return ~make_pad_mask(lengths, xs, length_dim) + + +def mask_by_length(xs, lengths, fill=0): + """Mask tensor according to length. + + Args: + xs (Tensor): Batch of input tensor (B, `*`). + lengths (LongTensor or List): Batch of lengths (B,). + fill (int or float): Value to fill masked part. + + Returns: + Tensor: Batch of masked input tensor (B, `*`). + + Examples: + >>> x = torch.arange(5).repeat(3, 1) + 1 + >>> x + tensor([[1, 2, 3, 4, 5], + [1, 2, 3, 4, 5], + [1, 2, 3, 4, 5]]) + >>> lengths = [5, 3, 2] + >>> mask_by_length(x, lengths) + tensor([[1, 2, 3, 4, 5], + [1, 2, 3, 0, 0], + [1, 2, 0, 0, 0]]) + + """ + assert xs.size(0) == len(lengths) + ret = xs.data.new(*xs.size()).fill_(fill) + for i, l in enumerate(lengths): + ret[i, :l] = xs[i, :l] + return ret + + +def th_accuracy(pad_outputs, pad_targets, ignore_label): + """Calculate accuracy. + + Args: + pad_outputs (Tensor): Prediction tensors (B * Lmax, D). + pad_targets (LongTensor): Target label tensors (B, Lmax, D). + ignore_label (int): Ignore label id. + + Returns: + float: Accuracy value (0.0 - 1.0). + + """ + pad_pred = pad_outputs.view( + pad_targets.size(0), + pad_targets.size(1), + pad_outputs.size(1)).argmax(2) + mask = pad_targets != ignore_label + numerator = torch.sum(pad_pred.masked_select(mask) == pad_targets.masked_select(mask)) + denominator = torch.sum(mask) + return float(numerator) / float(denominator) + + +def to_torch_tensor(x): + """Change to torch.Tensor or ComplexTensor from numpy.ndarray. + + Args: + x: Inputs. It should be one of numpy.ndarray, Tensor, ComplexTensor, and dict. + + Returns: + Tensor or ComplexTensor: Type converted inputs. 
+ + Examples: + >>> xs = np.ones(3, dtype=np.float32) + >>> xs = to_torch_tensor(xs) + tensor([1., 1., 1.]) + >>> xs = torch.ones(3, 4, 5) + >>> assert to_torch_tensor(xs) is xs + >>> xs = {'real': xs, 'imag': xs} + >>> to_torch_tensor(xs) + ComplexTensor( + Real: + tensor([1., 1., 1.]) + Imag; + tensor([1., 1., 1.]) + ) + + """ + # If numpy, change to torch tensor + if isinstance(x, np.ndarray): + if x.dtype.kind == 'c': + # Dynamically importing because torch_complex requires python3 + from torch_complex.tensor import ComplexTensor + return ComplexTensor(x) + else: + return torch.from_numpy(x) + + # If {'real': ..., 'imag': ...}, convert to ComplexTensor + elif isinstance(x, dict): + # Dynamically importing because torch_complex requires python3 + from torch_complex.tensor import ComplexTensor + + if 'real' not in x or 'imag' not in x: + raise ValueError("has 'real' and 'imag' keys: {}".format(list(x))) + # Relative importing because of using python3 syntax + return ComplexTensor(x['real'], x['imag']) + + # If torch.Tensor, as it is + elif isinstance(x, torch.Tensor): + return x + + else: + error = ("x must be numpy.ndarray, torch.Tensor or a dict like " + "{{'real': torch.Tensor, 'imag': torch.Tensor}}, " + "but got {}".format(type(x))) + try: + from torch_complex.tensor import ComplexTensor + except Exception: + # If PY2 + raise ValueError(error) + else: + # If PY3 + if isinstance(x, ComplexTensor): + return x + else: + raise ValueError(error) + + +def get_subsample(train_args, mode, arch): + """Parse the subsampling factors from the training args for the specified `mode` and `arch`. + + Args: + train_args: argument Namespace containing options. + mode: one of ('asr', 'mt', 'st') + arch: one of ('rnn', 'rnn-t', 'rnn_mix', 'rnn_mulenc', 'transformer') + + Returns: + np.ndarray / List[np.ndarray]: subsampling factors. + """ + if arch == 'transformer': + return np.array([1]) + + elif mode == 'mt' and arch == 'rnn': + # +1 means input (+1) and layers outputs (train_args.elayer) + subsample = np.ones(train_args.elayers + 1, dtype=np.int) + logging.warning('Subsampling is not performed for machine translation.') + logging.info('subsample: ' + ' '.join([str(x) for x in subsample])) + return subsample + + elif (mode == 'asr' and arch in ('rnn', 'rnn-t')) or \ + (mode == 'mt' and arch == 'rnn') or \ + (mode == 'st' and arch == 'rnn'): + subsample = np.ones(train_args.elayers + 1, dtype=np.int) + if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"): + ss = train_args.subsample.split("_") + for j in range(min(train_args.elayers + 1, len(ss))): + subsample[j] = int(ss[j]) + else: + logging.warning( + 'Subsampling is not performed for vgg*. It is performed in max pooling layers at CNN.') + logging.info('subsample: ' + ' '.join([str(x) for x in subsample])) + return subsample + + elif mode == 'asr' and arch == 'rnn_mix': + subsample = np.ones(train_args.elayers_sd + train_args.elayers + 1, dtype=np.int) + if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"): + ss = train_args.subsample.split("_") + for j in range(min(train_args.elayers_sd + train_args.elayers + 1, len(ss))): + subsample[j] = int(ss[j]) + else: + logging.warning( + 'Subsampling is not performed for vgg*. 
It is performed in max pooling layers at CNN.') + logging.info('subsample: ' + ' '.join([str(x) for x in subsample])) + return subsample + + elif mode == 'asr' and arch == 'rnn_mulenc': + subsample_list = [] + for idx in range(train_args.num_encs): + subsample = np.ones(train_args.elayers[idx] + 1, dtype=np.int) + if train_args.etype[idx].endswith("p") and not train_args.etype[idx].startswith("vgg"): + ss = train_args.subsample[idx].split("_") + for j in range(min(train_args.elayers[idx] + 1, len(ss))): + subsample[j] = int(ss[j]) + else: + logging.warning( + 'Encoder %d: Subsampling is not performed for vgg*. ' + 'It is performed in max pooling layers at CNN.', idx + 1) + logging.info('subsample: ' + ' '.join([str(x) for x in subsample])) + subsample_list.append(subsample) + return subsample_list + + else: + raise ValueError('Invalid options: mode={}, arch={}'.format(mode, arch)) + + +def rename_state_dict(old_prefix: str, new_prefix: str, state_dict: Dict[str, torch.Tensor]): + """Replace keys of old prefix with new prefix in state dict.""" + # need this list not to break the dict iterator + old_keys = [k for k in state_dict if k.startswith(old_prefix)] + if len(old_keys) > 0: + logging.warning(f'Rename: {old_prefix} -> {new_prefix}') + for k in old_keys: + v = state_dict.pop(k) + new_k = k.replace(old_prefix, new_prefix) + state_dict[new_k] = v + +def get_activation(act): + """Return activation function.""" + # Lazy load to avoid unused import + from .encoder.swish import Swish + + activation_funcs = { + "hardtanh": torch.nn.Hardtanh, + "relu": torch.nn.ReLU, + "selu": torch.nn.SELU, + "swish": Swish, + } + + return activation_funcs[act]() diff --git a/ppg_extractor/stft.py b/ppg_extractor/stft.py new file mode 100644 index 0000000..06b879e --- /dev/null +++ b/ppg_extractor/stft.py @@ -0,0 +1,118 @@ +from typing import Optional +from typing import Tuple +from typing import Union + +import torch + +from .nets_utils import make_pad_mask + + +class Stft(torch.nn.Module): + def __init__( + self, + n_fft: int = 512, + win_length: Union[int, None] = 512, + hop_length: int = 128, + center: bool = True, + pad_mode: str = "reflect", + normalized: bool = False, + onesided: bool = True, + kaldi_padding_mode=False, + ): + super().__init__() + self.n_fft = n_fft + if win_length is None: + self.win_length = n_fft + else: + self.win_length = win_length + self.hop_length = hop_length + self.center = center + self.pad_mode = pad_mode + self.normalized = normalized + self.onesided = onesided + self.kaldi_padding_mode = kaldi_padding_mode + if self.kaldi_padding_mode: + self.win_length = 400 + + def extra_repr(self): + return ( + f"n_fft={self.n_fft}, " + f"win_length={self.win_length}, " + f"hop_length={self.hop_length}, " + f"center={self.center}, " + f"pad_mode={self.pad_mode}, " + f"normalized={self.normalized}, " + f"onesided={self.onesided}" + ) + + def forward( + self, input: torch.Tensor, ilens: torch.Tensor = None + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """STFT forward function. 
+ + Args: + input: (Batch, Nsamples) or (Batch, Nsample, Channels) + ilens: (Batch) + Returns: + output: (Batch, Frames, Freq, 2) or (Batch, Frames, Channels, Freq, 2) + + """ + bs = input.size(0) + if input.dim() == 3: + multi_channel = True + # input: (Batch, Nsample, Channels) -> (Batch * Channels, Nsample) + input = input.transpose(1, 2).reshape(-1, input.size(1)) + else: + multi_channel = False + + # output: (Batch, Freq, Frames, 2=real_imag) + # or (Batch, Channel, Freq, Frames, 2=real_imag) + if not self.kaldi_padding_mode: + output = torch.stft( + input, + n_fft=self.n_fft, + win_length=self.win_length, + hop_length=self.hop_length, + center=self.center, + pad_mode=self.pad_mode, + normalized=self.normalized, + onesided=self.onesided, + return_complex=False + ) + else: + # NOTE(sx): Use Kaldi-fasion padding, maybe wrong + num_pads = self.n_fft - self.win_length + input = torch.nn.functional.pad(input, (num_pads, 0)) + output = torch.stft( + input, + n_fft=self.n_fft, + win_length=self.win_length, + hop_length=self.hop_length, + center=False, + pad_mode=self.pad_mode, + normalized=self.normalized, + onesided=self.onesided, + return_complex=False + ) + + # output: (Batch, Freq, Frames, 2=real_imag) + # -> (Batch, Frames, Freq, 2=real_imag) + output = output.transpose(1, 2) + if multi_channel: + # output: (Batch * Channel, Frames, Freq, 2=real_imag) + # -> (Batch, Frame, Channel, Freq, 2=real_imag) + output = output.view(bs, -1, output.size(1), output.size(2), 2).transpose( + 1, 2 + ) + + if ilens is not None: + if self.center: + pad = self.win_length // 2 + ilens = ilens + 2 * pad + olens = torch.div(ilens - self.win_length, self.hop_length, rounding_mode='floor') + 1 + # olens = ilens - self.win_length // self.hop_length + 1 + output.masked_fill_(make_pad_mask(olens, output, 1), 0.0) + else: + olens = None + + return output, olens diff --git a/ppg_extractor/utterance_mvn.py b/ppg_extractor/utterance_mvn.py new file mode 100644 index 0000000..37fb0c1 --- /dev/null +++ b/ppg_extractor/utterance_mvn.py @@ -0,0 +1,82 @@ +from typing import Tuple + +import torch + +from .nets_utils import make_pad_mask + + +class UtteranceMVN(torch.nn.Module): + def __init__( + self, norm_means: bool = True, norm_vars: bool = False, eps: float = 1.0e-20, + ): + super().__init__() + self.norm_means = norm_means + self.norm_vars = norm_vars + self.eps = eps + + def extra_repr(self): + return f"norm_means={self.norm_means}, norm_vars={self.norm_vars}" + + def forward( + self, x: torch.Tensor, ilens: torch.Tensor = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward function + + Args: + x: (B, L, ...) 
+ ilens: (B,) + + """ + return utterance_mvn( + x, + ilens, + norm_means=self.norm_means, + norm_vars=self.norm_vars, + eps=self.eps, + ) + + +def utterance_mvn( + x: torch.Tensor, + ilens: torch.Tensor = None, + norm_means: bool = True, + norm_vars: bool = False, + eps: float = 1.0e-20, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Apply utterance mean and variance normalization + + Args: + x: (B, T, D), assumed zero padded + ilens: (B,) + norm_means: + norm_vars: + eps: + + """ + if ilens is None: + ilens = x.new_full([x.size(0)], x.size(1)) + ilens_ = ilens.to(x.device, x.dtype).view(-1, *[1 for _ in range(x.dim() - 1)]) + # Zero padding + if x.requires_grad: + x = x.masked_fill(make_pad_mask(ilens, x, 1), 0.0) + else: + x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0) + # mean: (B, 1, D) + mean = x.sum(dim=1, keepdim=True) / ilens_ + + if norm_means: + x -= mean + + if norm_vars: + var = x.pow(2).sum(dim=1, keepdim=True) / ilens_ + std = torch.clamp(var.sqrt(), min=eps) + x = x / std.sqrt() + return x, ilens + else: + if norm_vars: + y = x - mean + y.masked_fill_(make_pad_mask(ilens, y, 1), 0.0) + var = y.pow(2).sum(dim=1, keepdim=True) / ilens_ + std = torch.clamp(var.sqrt(), min=eps) + x /= std + return x, ilens diff --git a/pre4ppg.py b/pre4ppg.py new file mode 100644 index 0000000..fcfa0fa --- /dev/null +++ b/pre4ppg.py @@ -0,0 +1,49 @@ +from pathlib import Path +import argparse + +from ppg2mel.preprocess import preprocess_dataset +from pathlib import Path +import argparse + +recognized_datasets = [ + "aidatatang_200zh", + "aidatatang_200zh_s", # sample +] + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Preprocesses audio files from datasets, to be used by the " + "ppg2mel model for training.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument("datasets_root", type=Path, help=\ + "Path to the directory containing your datasets.") + parser.add_argument("-d", "--dataset", type=str, default="aidatatang_200zh", help=\ + "Name of the dataset to process, allowing values: aidatatang_200zh.") + parser.add_argument("-o", "--out_dir", type=Path, default=argparse.SUPPRESS, help=\ + "Path to the output directory that will contain the mel spectrograms, the audios and the " + "embeds. Defaults to /PPGVC/ppg2mel/") + parser.add_argument("-n", "--n_processes", type=int, default=8, help=\ + "Number of processes in parallel.") + # parser.add_argument("-s", "--skip_existing", action="store_true", help=\ + # "Whether to overwrite existing files with the same name. Useful if the preprocessing was " + # "interrupted. 
") + # parser.add_argument("--hparams", type=str, default="", help=\ + # "Hyperparameter overrides as a comma-separated list of name-value pairs") + # parser.add_argument("--no_trim", action="store_true", help=\ + # "Preprocess audio without trimming silences (not recommended).") + parser.add_argument("-pf", "--ppg_encoder_model_fpath", type=Path, default="ppg_extractor/saved_models/24epoch.pt", help=\ + "Path your trained ppg encoder model.") + parser.add_argument("-sf", "--speaker_encoder_model", type=Path, default="encoder/saved_models/pretrained_bak_5805000.pt", help=\ + "Path your trained speaker encoder model.") + args = parser.parse_args() + + assert args.dataset in recognized_datasets, 'is not supported, file a issue to propose a new one' + + # Create directories + assert args.datasets_root.exists() + if not hasattr(args, "out_dir"): + args.out_dir = args.datasets_root.joinpath("PPGVC", "ppg2mel") + args.out_dir.mkdir(exist_ok=True, parents=True) + + preprocess_dataset(**vars(args)) diff --git a/requirements.txt b/requirements.txt index 02a3c5e..1091207 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,7 +17,10 @@ webrtcvad; platform_system != "Windows" pypinyin flask flask_wtf -flask_cors +flask_cors==3.0.10 gevent==21.8.0 flask_restx -tensorboard \ No newline at end of file +tensorboard +PyYAML==5.4.1 +torch_complex +espnet \ No newline at end of file diff --git a/run.py b/run.py new file mode 100644 index 0000000..170f9db --- /dev/null +++ b/run.py @@ -0,0 +1,142 @@ +import time +import os +import argparse +import torch +import numpy as np +import glob +from pathlib import Path +from tqdm import tqdm +from ppg_extractor import load_model +import librosa +import soundfile as sf +from utils.load_yaml import HpsYaml + +from encoder.audio import preprocess_wav +from encoder import inference as speacker_encoder +from vocoder.hifigan import inference as vocoder +from ppg2mel import MelDecoderMOLv2 +from utils.f0_utils import compute_f0, f02lf0, compute_mean_std, get_converted_lf0uv + + +def _build_ppg2mel_model(model_config, model_file, device): + ppg2mel_model = MelDecoderMOLv2( + **model_config["model"] + ).to(device) + ckpt = torch.load(model_file, map_location=device) + ppg2mel_model.load_state_dict(ckpt["model"]) + ppg2mel_model.eval() + return ppg2mel_model + + +@torch.no_grad() +def convert(args): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + output_dir = args.output_dir + os.makedirs(output_dir, exist_ok=True) + + step = os.path.basename(args.ppg2mel_model_file)[:-4].split("_")[-1] + + # Build models + print("Load PPG-model, PPG2Mel-model, Vocoder-model...") + ppg_model = load_model( + Path('./ppg_extractor/saved_models/24epoch.pt'), + device, + ) + ppg2mel_model = _build_ppg2mel_model(HpsYaml(args.ppg2mel_model_train_config), args.ppg2mel_model_file, device) + # vocoder.load_model('./vocoder/saved_models/pretrained/g_hifigan.pt', "./vocoder/hifigan/config_16k_.json") + vocoder.load_model('./vocoder/saved_models/24k/g_02830000.pt') + # Data related + ref_wav_path = args.ref_wav_path + ref_wav = preprocess_wav(ref_wav_path) + ref_fid = os.path.basename(ref_wav_path)[:-4] + + # TODO: specify encoder + speacker_encoder.load_model(Path("encoder/saved_models/pretrained_bak_5805000.pt")) + ref_spk_dvec = speacker_encoder.embed_utterance(ref_wav) + ref_spk_dvec = torch.from_numpy(ref_spk_dvec).unsqueeze(0).to(device) + ref_lf0_mean, ref_lf0_std = compute_mean_std(f02lf0(compute_f0(ref_wav))) + + source_file_list = 
sorted(glob.glob(f"{args.wav_dir}/*.wav")) + print(f"Number of source utterances: {len(source_file_list)}.") + + total_rtf = 0.0 + cnt = 0 + for src_wav_path in tqdm(source_file_list): + # Load the audio to a numpy array: + src_wav, _ = librosa.load(src_wav_path, sr=16000) + src_wav_tensor = torch.from_numpy(src_wav).unsqueeze(0).float().to(device) + src_wav_lengths = torch.LongTensor([len(src_wav)]).to(device) + ppg = ppg_model(src_wav_tensor, src_wav_lengths) + + lf0_uv = get_converted_lf0uv(src_wav, ref_lf0_mean, ref_lf0_std, convert=True) + min_len = min(ppg.shape[1], len(lf0_uv)) + + ppg = ppg[:, :min_len] + lf0_uv = lf0_uv[:min_len] + + start = time.time() + _, mel_pred, att_ws = ppg2mel_model.inference( + ppg, + logf0_uv=torch.from_numpy(lf0_uv).unsqueeze(0).float().to(device), + spembs=ref_spk_dvec, + ) + src_fid = os.path.basename(src_wav_path)[:-4] + wav_fname = f"{output_dir}/vc_{src_fid}_ref_{ref_fid}_step{step}.wav" + mel_len = mel_pred.shape[0] + rtf = (time.time() - start) / (0.01 * mel_len) + total_rtf += rtf + cnt += 1 + # continue + mel_pred= mel_pred.transpose(0, 1) + y, output_sample_rate = vocoder.infer_waveform(mel_pred.cpu()) + sf.write(wav_fname, y.squeeze(), output_sample_rate, "PCM_16") + + print("RTF:") + print(total_rtf / cnt) + + +def get_parser(): + parser = argparse.ArgumentParser(description="Conversion from wave input") + parser.add_argument( + "--wav_dir", + type=str, + default=None, + required=True, + help="Source wave directory.", + ) + parser.add_argument( + "--ref_wav_path", + type=str, + required=True, + help="Reference wave file path.", + ) + parser.add_argument( + "--ppg2mel_model_train_config", "-c", + type=str, + default=None, + required=True, + help="Training config file (yaml file)", + ) + parser.add_argument( + "--ppg2mel_model_file", "-m", + type=str, + default=None, + required=True, + help="ppg2mel model checkpoint file path" + ) + parser.add_argument( + "--output_dir", "-o", + type=str, + default="vc_gens_vctk_oneshot", + help="Output folder to save the converted wave." 
+ ) + + return parser + +def main(): + parser = get_parser() + args = parser.parse_args() + convert(args) + +if __name__ == "__main__": + main() diff --git a/toolbox/__init__.py b/toolbox/__init__.py index 827833a..30f2865 100644 --- a/toolbox/__init__.py +++ b/toolbox/__init__.py @@ -3,16 +3,17 @@ from encoder import inference as encoder from synthesizer.inference import Synthesizer from vocoder.wavernn import inference as rnn_vocoder from vocoder.hifigan import inference as gan_vocoder +import ppg_extractor as extractor +import ppg2mel as convertor from pathlib import Path from time import perf_counter as timer from toolbox.utterance import Utterance +from utils.f0_utils import compute_f0, f02lf0, compute_mean_std, get_converted_lf0uv import numpy as np import traceback import sys import torch -import librosa import re -from audioread.exceptions import NoBackendError # 默认使用wavernn vocoder = rnn_vocoder @@ -49,14 +50,20 @@ recognized_datasets = [ MAX_WAVES = 15 class Toolbox: - def __init__(self, datasets_root, enc_models_dir, syn_models_dir, voc_models_dir, seed, no_mp3_support): + def __init__(self, datasets_root, enc_models_dir, syn_models_dir, voc_models_dir, extractor_models_dir, convertor_models_dir, seed, no_mp3_support, vc_mode): self.no_mp3_support = no_mp3_support + self.vc_mode = vc_mode sys.excepthook = self.excepthook self.datasets_root = datasets_root self.utterances = set() self.current_generated = (None, None, None, None) # speaker_name, spec, breaks, wav self.synthesizer = None # type: Synthesizer + + # for ppg-based voice conversion + self.extractor = None + self.convertor = None # ppg2mel + self.current_wav = None self.waves_list = [] self.waves_count = 0 @@ -70,9 +77,9 @@ class Toolbox: self.trim_silences = False # Initialize the events and the interface - self.ui = UI() + self.ui = UI(vc_mode) self.style_idx = 0 - self.reset_ui(enc_models_dir, syn_models_dir, voc_models_dir, seed) + self.reset_ui(enc_models_dir, syn_models_dir, voc_models_dir, extractor_models_dir, convertor_models_dir, seed) self.setup_events() self.ui.start() @@ -96,7 +103,11 @@ class Toolbox: self.ui.encoder_box.currentIndexChanged.connect(self.init_encoder) def func(): self.synthesizer = None - self.ui.synthesizer_box.currentIndexChanged.connect(func) + if self.vc_mode: + self.ui.extractor_box.currentIndexChanged.connect(self.init_extractor) + else: + self.ui.synthesizer_box.currentIndexChanged.connect(func) + self.ui.vocoder_box.currentIndexChanged.connect(self.init_vocoder) # Utterance selection @@ -109,6 +120,11 @@ class Toolbox: self.ui.stop_button.clicked.connect(self.ui.stop) self.ui.record_button.clicked.connect(self.record) + # Source Utterance selection + if self.vc_mode: + func = lambda: self.load_soruce_button(self.ui.selected_utterance) + self.ui.load_soruce_button.clicked.connect(func) + #Audio self.ui.setup_audio_devices(Synthesizer.sample_rate) @@ -120,12 +136,17 @@ class Toolbox: self.ui.waves_cb.currentIndexChanged.connect(self.set_current_wav) # Generation - func = lambda: self.synthesize() or self.vocode() - self.ui.generate_button.clicked.connect(func) - self.ui.synthesize_button.clicked.connect(self.synthesize) self.ui.vocode_button.clicked.connect(self.vocode) self.ui.random_seed_checkbox.clicked.connect(self.update_seed_textbox) + if self.vc_mode: + func = lambda: self.convert() or self.vocode() + self.ui.convert_button.clicked.connect(func) + else: + func = lambda: self.synthesize() or self.vocode() + self.ui.generate_button.clicked.connect(func) + 
self.ui.synthesize_button.clicked.connect(self.synthesize) + # UMAP legend self.ui.clear_button.clicked.connect(self.clear_utterances) @@ -138,9 +159,9 @@ class Toolbox: def replay_last_wav(self): self.ui.play(self.current_wav, Synthesizer.sample_rate) - def reset_ui(self, encoder_models_dir, synthesizer_models_dir, vocoder_models_dir, seed): + def reset_ui(self, encoder_models_dir, synthesizer_models_dir, vocoder_models_dir, extractor_models_dir, convertor_models_dir, seed): self.ui.populate_browser(self.datasets_root, recognized_datasets, 0, True) - self.ui.populate_models(encoder_models_dir, synthesizer_models_dir, vocoder_models_dir) + self.ui.populate_models(encoder_models_dir, synthesizer_models_dir, vocoder_models_dir, extractor_models_dir, convertor_models_dir, self.vc_mode) self.ui.populate_gen_options(seed, self.trim_silences) def load_from_browser(self, fpath=None): @@ -171,7 +192,10 @@ class Toolbox: self.ui.log("Loaded %s" % name) self.add_real_utterance(wav, name, speaker_name) - + + def load_soruce_button(self, utterance: Utterance): + self.selected_source_utterance = utterance + def record(self): wav = self.ui.record_one(encoder.sampling_rate, 5) if wav is None: @@ -196,7 +220,7 @@ class Toolbox: # Add the utterance utterance = Utterance(name, speaker_name, wav, spec, embed, partial_embeds, False) self.utterances.add(utterance) - self.ui.register_utterance(utterance) + self.ui.register_utterance(utterance, self.vc_mode) # Plot it self.ui.draw_embed(embed, name, "current") @@ -269,7 +293,7 @@ class Toolbox: self.ui.set_loading(i, seq_len) if self.ui.current_vocoder_fpath is not None: self.ui.log("") - wav = vocoder.infer_waveform(spec, progress_callback=vocoder_progress) + wav, sample_rate = vocoder.infer_waveform(spec, progress_callback=vocoder_progress) else: self.ui.log("Waveform generation with Griffin-Lim... ") wav = Synthesizer.griffin_lim(spec) @@ -280,7 +304,7 @@ class Toolbox: b_ends = np.cumsum(np.array(breaks) * Synthesizer.hparams.hop_size) b_starts = np.concatenate(([0], b_ends[:-1])) wavs = [wav[start:end] for start, end, in zip(b_starts, b_ends)] - breaks = [np.zeros(int(0.15 * Synthesizer.sample_rate))] * len(breaks) + breaks = [np.zeros(int(0.15 * sample_rate))] * len(breaks) wav = np.concatenate([i for w, b in zip(wavs, breaks) for i in (w, b)]) # Trim excessive silences @@ -289,7 +313,7 @@ class Toolbox: # Play it wav = wav / np.abs(wav).max() * 0.97 - self.ui.play(wav, Synthesizer.sample_rate) + self.ui.play(wav, sample_rate) # Name it (history displayed in combobox) # TODO better naming for the combobox items? 
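# --- Illustrative sketch (not part of the changes above) --------------------
# Minimal sketch of the new vocoder calling convention used in Toolbox.vocode():
# infer_waveform() now returns the waveform together with the vocoder's own
# sample rate, so downstream code (silence breaks, playback, export) no longer
# assumes Synthesizer.sample_rate. The mel below is a random placeholder; the
# checkpoint and config paths are the 24k ones referenced elsewhere in this
# change set.
import numpy as np
from vocoder.hifigan import inference as gan_vocoder

gan_vocoder.load_model("vocoder/saved_models/24k/g_02830000.pt",
                       "vocoder/saved_models/24k/config.json")
mel = np.random.rand(80, 200).astype(np.float32)   # placeholder (n_mels, frames)
wav, sample_rate = gan_vocoder.infer_waveform(mel)
silence = np.zeros(int(0.15 * sample_rate))        # breaks follow the returned rate
# ----------------------------------------------------------------------------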
@@ -331,6 +355,68 @@ class Toolbox: self.ui.draw_embed(embed, name, "generated") self.ui.draw_umap_projections(self.utterances) + def convert(self): + self.ui.log("Extract PPG and Converting...") + self.ui.set_loading(1) + + # Init + if self.convertor is None: + self.init_convertor() + if self.extractor is None: + self.init_extractor() + + src_wav = self.selected_source_utterance.wav + + # Compute the ppg + if not self.extractor is None: + ppg = self.extractor.extract_from_wav(src_wav) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + ref_wav = self.ui.selected_utterance.wav + ref_lf0_mean, ref_lf0_std = compute_mean_std(f02lf0(compute_f0(ref_wav))) + lf0_uv = get_converted_lf0uv(src_wav, ref_lf0_mean, ref_lf0_std, convert=True) + min_len = min(ppg.shape[1], len(lf0_uv)) + ppg = ppg[:, :min_len] + lf0_uv = lf0_uv[:min_len] + _, mel_pred, att_ws = self.convertor.inference( + ppg, + logf0_uv=torch.from_numpy(lf0_uv).unsqueeze(0).float().to(device), + spembs=torch.from_numpy(self.ui.selected_utterance.embed).unsqueeze(0).to(device), + ) + mel_pred= mel_pred.transpose(0, 1) + breaks = [mel_pred.shape[1]] + mel_pred= mel_pred.detach().cpu().numpy() + self.ui.draw_spec(mel_pred, "generated") + self.current_generated = (self.ui.selected_utterance.speaker_name, mel_pred, breaks, None) + self.ui.set_loading(0) + + def init_extractor(self): + if self.ui.current_extractor_fpath is None: + return + model_fpath = self.ui.current_extractor_fpath + self.ui.log("Loading the extractor %s... " % model_fpath) + self.ui.set_loading(1) + start = timer() + self.extractor = extractor.load_model(model_fpath) + self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append") + self.ui.set_loading(0) + + def init_convertor(self): + if self.ui.current_convertor_fpath is None: + return + model_fpath = self.ui.current_convertor_fpath + # search a config file + model_config_fpaths = list(model_fpath.parent.rglob("*.yaml")) + if self.ui.current_convertor_fpath is None: + return + model_config_fpath = model_config_fpaths[0] + self.ui.log("Loading the convertor %s... " % model_fpath) + self.ui.set_loading(1) + start = timer() + self.convertor = convertor.load_model(model_config_fpath, model_fpath) + self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append") + self.ui.set_loading(0) + def init_encoder(self): model_fpath = self.ui.current_encoder_fpath @@ -358,12 +444,16 @@ class Toolbox: # Case of Griffin-lim if model_fpath is None: return - - # Sekect vocoder based on model name + model_config_fpath = None if model_fpath.name[0] == "g": vocoder = gan_vocoder self.ui.log("set hifigan as vocoder") + # search a config file + model_config_fpaths = list(model_fpath.parent.rglob("*.json")) + if self.ui.current_extractor_fpath is None: + return + model_config_fpath = model_config_fpaths[0] else: vocoder = rnn_vocoder self.ui.log("set wavernn as vocoder") @@ -371,7 +461,7 @@ class Toolbox: self.ui.log("Loading the vocoder %s... " % model_fpath) self.ui.set_loading(1) start = timer() - vocoder.load_model(model_fpath) + vocoder.load_model(model_fpath, model_config_fpath) self.ui.log("Done (%dms)." 
% int(1000 * (timer() - start)), "append") self.ui.set_loading(0) diff --git a/toolbox/ui.py b/toolbox/ui.py index 34f8efe..fe51e73 100644 --- a/toolbox/ui.py +++ b/toolbox/ui.py @@ -326,30 +326,51 @@ class UI(QDialog): def current_vocoder_fpath(self): return self.vocoder_box.itemData(self.vocoder_box.currentIndex()) + @property + def current_extractor_fpath(self): + return self.extractor_box.itemData(self.extractor_box.currentIndex()) + + @property + def current_convertor_fpath(self): + return self.convertor_box.itemData(self.convertor_box.currentIndex()) + def populate_models(self, encoder_models_dir: Path, synthesizer_models_dir: Path, - vocoder_models_dir: Path): + vocoder_models_dir: Path, extractor_models_dir: Path, convertor_models_dir: Path, vc_mode: bool): # Encoder encoder_fpaths = list(encoder_models_dir.glob("*.pt")) if len(encoder_fpaths) == 0: raise Exception("No encoder models found in %s" % encoder_models_dir) self.repopulate_box(self.encoder_box, [(f.stem, f) for f in encoder_fpaths]) - # Synthesizer - synthesizer_fpaths = list(synthesizer_models_dir.glob("**/*.pt")) - if len(synthesizer_fpaths) == 0: - raise Exception("No synthesizer models found in %s" % synthesizer_models_dir) - self.repopulate_box(self.synthesizer_box, [(f.stem, f) for f in synthesizer_fpaths]) + if vc_mode: + # Extractor + extractor_fpaths = list(extractor_models_dir.glob("*.pt")) + if len(extractor_fpaths) == 0: + self.log("No extractor models found in %s" % extractor_fpaths) + self.repopulate_box(self.extractor_box, [(f.stem, f) for f in extractor_fpaths]) + + # Convertor + convertor_fpaths = list(convertor_models_dir.glob("*.pth")) + if len(convertor_fpaths) == 0: + self.log("No convertor models found in %s" % convertor_fpaths) + self.repopulate_box(self.convertor_box, [(f.stem, f) for f in convertor_fpaths]) + else: + # Synthesizer + synthesizer_fpaths = list(synthesizer_models_dir.glob("**/*.pt")) + if len(synthesizer_fpaths) == 0: + raise Exception("No synthesizer models found in %s" % synthesizer_models_dir) + self.repopulate_box(self.synthesizer_box, [(f.stem, f) for f in synthesizer_fpaths]) # Vocoder vocoder_fpaths = list(vocoder_models_dir.glob("**/*.pt")) vocoder_items = [(f.stem, f) for f in vocoder_fpaths] + [("Griffin-Lim", None)] self.repopulate_box(self.vocoder_box, vocoder_items) - + @property def selected_utterance(self): return self.utterance_history.itemData(self.utterance_history.currentIndex()) - def register_utterance(self, utterance: Utterance): + def register_utterance(self, utterance: Utterance, vc_mode): self.utterance_history.blockSignals(True) self.utterance_history.insertItem(0, utterance.name, utterance) self.utterance_history.setCurrentIndex(0) @@ -359,8 +380,11 @@ class UI(QDialog): self.utterance_history.removeItem(self.max_saved_utterances) self.play_button.setDisabled(False) - self.generate_button.setDisabled(False) - self.synthesize_button.setDisabled(False) + if vc_mode: + self.convert_button.setDisabled(False) + else: + self.generate_button.setDisabled(False) + self.synthesize_button.setDisabled(False) def log(self, line, mode="newline"): if mode == "newline": @@ -402,7 +426,7 @@ class UI(QDialog): else: self.seed_textbox.setEnabled(False) - def reset_interface(self): + def reset_interface(self, vc_mode): self.draw_embed(None, None, "current") self.draw_embed(None, None, "generated") self.draw_spec(None, "current") @@ -410,14 +434,17 @@ class UI(QDialog): self.draw_umap_projections(set()) self.set_loading(0) self.play_button.setDisabled(True) - 
self.generate_button.setDisabled(True) - self.synthesize_button.setDisabled(True) + if vc_mode: + self.convert_button.setDisabled(True) + else: + self.generate_button.setDisabled(True) + self.synthesize_button.setDisabled(True) self.vocode_button.setDisabled(True) self.replay_wav_button.setDisabled(True) self.export_wav_button.setDisabled(True) [self.log("") for _ in range(self.max_log_lines)] - def __init__(self): + def __init__(self, vc_mode): ## Initialize the application self.app = QApplication(sys.argv) super().__init__(None) @@ -469,7 +496,7 @@ class UI(QDialog): source_groupbox = QGroupBox('Source(源音频)') source_layout = QGridLayout() source_groupbox.setLayout(source_layout) - browser_layout.addWidget(source_groupbox, i, 0, 1, 4) + browser_layout.addWidget(source_groupbox, i, 0, 1, 5) self.dataset_box = QComboBox() source_layout.addWidget(QLabel("Dataset(数据集):"), i, 0) @@ -510,25 +537,35 @@ class UI(QDialog): browser_layout.addWidget(self.play_button, i, 2) self.stop_button = QPushButton("Stop(暂停)") browser_layout.addWidget(self.stop_button, i, 3) + if vc_mode: + self.load_soruce_button = QPushButton("Select(选择为被转换的语音输入)") + browser_layout.addWidget(self.load_soruce_button, i, 4) i += 1 model_groupbox = QGroupBox('Models(模型选择)') model_layout = QHBoxLayout() model_groupbox.setLayout(model_layout) - browser_layout.addWidget(model_groupbox, i, 0, 1, 4) + browser_layout.addWidget(model_groupbox, i, 0, 2, 5) # Model and audio output selection self.encoder_box = QComboBox() model_layout.addWidget(QLabel("Encoder:")) model_layout.addWidget(self.encoder_box) self.synthesizer_box = QComboBox() - model_layout.addWidget(QLabel("Synthesizer:")) - model_layout.addWidget(self.synthesizer_box) + if vc_mode: + self.extractor_box = QComboBox() + model_layout.addWidget(QLabel("Extractor:")) + model_layout.addWidget(self.extractor_box) + self.convertor_box = QComboBox() + model_layout.addWidget(QLabel("Convertor:")) + model_layout.addWidget(self.convertor_box) + else: + model_layout.addWidget(QLabel("Synthesizer:")) + model_layout.addWidget(self.synthesizer_box) self.vocoder_box = QComboBox() model_layout.addWidget(QLabel("Vocoder:")) model_layout.addWidget(self.vocoder_box) - - + #Replay & Save Audio i = 0 output_layout.addWidget(QLabel("Toolbox Output:"), i, 0) @@ -550,7 +587,7 @@ class UI(QDialog): ## Embed & spectrograms vis_layout.addStretch() - + # TODO: add spectrograms for source gridspec_kw = {"width_ratios": [1, 4]} fig, self.current_ax = plt.subplots(1, 2, figsize=(10, 2.25), facecolor="#F0F0F0", gridspec_kw=gridspec_kw) @@ -571,16 +608,23 @@ class UI(QDialog): self.text_prompt = QPlainTextEdit(default_text) gen_layout.addWidget(self.text_prompt, stretch=1) - self.generate_button = QPushButton("Synthesize and vocode") - gen_layout.addWidget(self.generate_button) - - layout = QHBoxLayout() - self.synthesize_button = QPushButton("Synthesize only") - layout.addWidget(self.synthesize_button) + if vc_mode: + layout = QHBoxLayout() + self.convert_button = QPushButton("Extract and Convert") + layout.addWidget(self.convert_button) + gen_layout.addLayout(layout) + else: + self.generate_button = QPushButton("Synthesize and vocode") + gen_layout.addWidget(self.generate_button) + layout = QHBoxLayout() + self.synthesize_button = QPushButton("Synthesize only") + layout.addWidget(self.synthesize_button) + self.vocode_button = QPushButton("Vocode only") layout.addWidget(self.vocode_button) gen_layout.addLayout(layout) + layout_seed = QGridLayout() self.random_seed_checkbox = QCheckBox("Random seed:") 
self.random_seed_checkbox.setToolTip("When checked, makes the synthesizer and vocoder deterministic.") @@ -648,7 +692,7 @@ class UI(QDialog): self.resize(max_size) ## Finalize the display - self.reset_interface() + self.reset_interface(vc_mode) self.show() def start(self): diff --git a/train.py b/train.py new file mode 100644 index 0000000..5a6a06c --- /dev/null +++ b/train.py @@ -0,0 +1,67 @@ +import sys +import torch +import argparse +import numpy as np +from utils.load_yaml import HpsYaml +from ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver + +# For reproducibility, comment these may speed up training +torch.backends.cudnn.deterministic = True +torch.backends.cudnn.benchmark = False + +def main(): + # Arguments + parser = argparse.ArgumentParser(description= + 'Training PPG2Mel VC model.') + parser.add_argument('--config', type=str, + help='Path to experiment config, e.g., config/vc.yaml') + parser.add_argument('--name', default=None, type=str, help='Name for logging.') + parser.add_argument('--logdir', default='log/', type=str, + help='Logging path.', required=False) + parser.add_argument('--ckpdir', default='ppg2mel/saved_models/', type=str, + help='Checkpoint path.', required=False) + parser.add_argument('--outdir', default='result/', type=str, + help='Decode output path.', required=False) + parser.add_argument('--load', default=None, type=str, + help='Load pre-trained model (for training only)', required=False) + parser.add_argument('--warm_start', action='store_true', + help='Load model weights only, ignore specified layers.') + parser.add_argument('--seed', default=0, type=int, + help='Random seed for reproducable results.', required=False) + parser.add_argument('--njobs', default=8, type=int, + help='Number of threads for dataloader/decoding.', required=False) + parser.add_argument('--cpu', action='store_true', help='Disable GPU training.') + parser.add_argument('--no-pin', action='store_true', + help='Disable pin-memory for dataloader') + parser.add_argument('--test', action='store_true', help='Test the model.') + parser.add_argument('--no-msg', action='store_true', help='Hide all messages.') + parser.add_argument('--finetune', action='store_true', help='Finetune model') + parser.add_argument('--oneshotvc', action='store_true', help='Oneshot VC model') + parser.add_argument('--bilstm', action='store_true', help='BiLSTM VC model') + parser.add_argument('--lsa', action='store_true', help='Use location-sensitive attention (LSA)') + + ### + + paras = parser.parse_args() + setattr(paras, 'gpu', not paras.cpu) + setattr(paras, 'pin_memory', not paras.no_pin) + setattr(paras, 'verbose', not paras.no_msg) + # Make the config dict dot visitable + config = HpsYaml(paras.config) + + np.random.seed(paras.seed) + torch.manual_seed(paras.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(paras.seed) + + print(">>> OneShot VC training ...") + mode = "train" + solver = Solver(config, paras, mode) + solver.load_data() + solver.set_model() + solver.exec() + print(">>> Oneshot VC train finished!") + sys.exit(0) + +if __name__ == "__main__": + main() diff --git a/utils/audio_utils.py b/utils/audio_utils.py new file mode 100644 index 0000000..1dbeddb --- /dev/null +++ b/utils/audio_utils.py @@ -0,0 +1,60 @@ + +import torch +import torch.utils.data +from scipy.io.wavfile import read +from librosa.filters import mel as librosa_mel_fn + +MAX_WAV_VALUE = 32768.0 + + +def load_wav(full_path): + sampling_rate, data = read(full_path) + return data, sampling_rate + +def 
_dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def _spectral_normalize_torch(magnitudes): + output = _dynamic_range_compression_torch(magnitudes) + return output + +mel_basis = {} +hann_window = {} + +def mel_spectrogram( + y, + n_fft, + num_mels, + sampling_rate, + hop_size, + win_size, + fmin, + fmax, + center=False, + output_energy=False, +): + if torch.min(y) < -1.: + print('min value is ', torch.min(y)) + if torch.max(y) > 1.: + print('max value is ', torch.max(y)) + + global mel_basis, hann_window + if fmax not in mel_basis: + mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) + mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device) + hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) + + y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') + y = y.squeeze(1) + + spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)], + center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) + spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9)) + mel_spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec) + mel_spec = _spectral_normalize_torch(mel_spec) + if output_energy: + energy = torch.norm(spec, dim=1) + return mel_spec, energy + else: + return mel_spec diff --git a/utils/data_load.py b/utils/data_load.py new file mode 100644 index 0000000..37723cf --- /dev/null +++ b/utils/data_load.py @@ -0,0 +1,214 @@ +import random +import numpy as np +import torch +from utils.f0_utils import get_cont_lf0 +import resampy +from .audio_utils import MAX_WAV_VALUE, load_wav, mel_spectrogram +from librosa.util import normalize +import os + + +SAMPLE_RATE=16000 + +def read_fids(fid_list_f): + with open(fid_list_f, 'r') as f: + fids = [l.strip().split()[0] for l in f if l.strip()] + return fids + +class OneshotVcDataset(torch.utils.data.Dataset): + def __init__( + self, + meta_file: str, + vctk_ppg_dir: str, + libri_ppg_dir: str, + vctk_f0_dir: str, + libri_f0_dir: str, + vctk_wav_dir: str, + libri_wav_dir: str, + vctk_spk_dvec_dir: str, + libri_spk_dvec_dir: str, + min_max_norm_mel: bool = False, + mel_min: float = None, + mel_max: float = None, + ppg_file_ext: str = "ling_feat.npy", + f0_file_ext: str = "f0.npy", + wav_file_ext: str = "wav", + ): + self.fid_list = read_fids(meta_file) + self.vctk_ppg_dir = vctk_ppg_dir + self.libri_ppg_dir = libri_ppg_dir + self.vctk_f0_dir = vctk_f0_dir + self.libri_f0_dir = libri_f0_dir + self.vctk_wav_dir = vctk_wav_dir + self.libri_wav_dir = libri_wav_dir + self.vctk_spk_dvec_dir = vctk_spk_dvec_dir + self.libri_spk_dvec_dir = libri_spk_dvec_dir + + self.ppg_file_ext = ppg_file_ext + self.f0_file_ext = f0_file_ext + self.wav_file_ext = wav_file_ext + + self.min_max_norm_mel = min_max_norm_mel + if min_max_norm_mel: + print("[INFO] Min-Max normalize Melspec.") + assert mel_min is not None + assert mel_max is not None + self.mel_max = mel_max + self.mel_min = mel_min + + random.seed(1234) + random.shuffle(self.fid_list) + print(f'[INFO] Got {len(self.fid_list)} samples.') + + def __len__(self): + return len(self.fid_list) + + def get_spk_dvec(self, fid): + spk_name = fid + if spk_name.startswith("p"): + spk_dvec_path = f"{self.vctk_spk_dvec_dir}{os.sep}{spk_name}.npy" + else: + spk_dvec_path = f"{self.libri_spk_dvec_dir}{os.sep}{spk_name}.npy" + return torch.from_numpy(np.load(spk_dvec_path)) + + def 
compute_mel(self, wav_path): + audio, sr = load_wav(wav_path) + if sr != SAMPLE_RATE: + audio = resampy.resample(audio, sr, SAMPLE_RATE) + audio = audio / MAX_WAV_VALUE + audio = normalize(audio) * 0.95 + audio = torch.FloatTensor(audio).unsqueeze(0) + melspec = mel_spectrogram( + audio, + n_fft=1024, + num_mels=80, + sampling_rate=SAMPLE_RATE, + hop_size=160, + win_size=1024, + fmin=80, + fmax=8000, + ) + return melspec.squeeze(0).numpy().T + + def bin_level_min_max_norm(self, melspec): + # frequency bin level min-max normalization to [-4, 4] + mel = (melspec - self.mel_min) / (self.mel_max - self.mel_min) * 8.0 - 4.0 + return np.clip(mel, -4., 4.) + + def __getitem__(self, index): + fid = self.fid_list[index] + + # 1. Load features + if fid.startswith("p"): + # vctk + sub = fid.split("_")[0] + ppg = np.load(f"{self.vctk_ppg_dir}{os.sep}{fid}.{self.ppg_file_ext}") + f0 = np.load(f"{self.vctk_f0_dir}{os.sep}{fid}.{self.f0_file_ext}") + mel = self.compute_mel(f"{self.vctk_wav_dir}{os.sep}{sub}{os.sep}{fid}.{self.wav_file_ext}") + else: + # aidatatang + sub = fid[5:10] + ppg = np.load(f"{self.libri_ppg_dir}{os.sep}{fid}.{self.ppg_file_ext}") + f0 = np.load(f"{self.libri_f0_dir}{os.sep}{fid}.{self.f0_file_ext}") + mel = self.compute_mel(f"{self.libri_wav_dir}{os.sep}{sub}{os.sep}{fid}.{self.wav_file_ext}") + if self.min_max_norm_mel: + mel = self.bin_level_min_max_norm(mel) + + f0, ppg, mel = self._adjust_lengths(f0, ppg, mel, fid) + spk_dvec = self.get_spk_dvec(fid) + + # 2. Convert f0 to continuous log-f0 and u/v flags + uv, cont_lf0 = get_cont_lf0(f0, 10.0, False) + # cont_lf0 = (cont_lf0 - np.amin(cont_lf0)) / (np.amax(cont_lf0) - np.amin(cont_lf0)) + # cont_lf0 = self.utt_mvn(cont_lf0) + lf0_uv = np.concatenate([cont_lf0[:, np.newaxis], uv[:, np.newaxis]], axis=1) + + # uv, cont_f0 = convert_continuous_f0(f0) + # cont_f0 = (cont_f0 - np.amin(cont_f0)) / (np.amax(cont_f0) - np.amin(cont_f0)) + # lf0_uv = np.concatenate([cont_f0[:, np.newaxis], uv[:, np.newaxis]], axis=1) + + # 3. 
Convert numpy array to torch.tensor + ppg = torch.from_numpy(ppg) + lf0_uv = torch.from_numpy(lf0_uv) + mel = torch.from_numpy(mel) + + return (ppg, lf0_uv, mel, spk_dvec, fid) + + def check_lengths(self, f0, ppg, mel, fid): + LEN_THRESH = 10 + assert abs(len(ppg) - len(f0)) <= LEN_THRESH, \ + f"{abs(len(ppg) - len(f0))}: for file {fid}" + assert abs(len(mel) - len(f0)) <= LEN_THRESH, \ + f"{abs(len(mel) - len(f0))}: for file {fid}" + + def _adjust_lengths(self, f0, ppg, mel, fid): + self.check_lengths(f0, ppg, mel, fid) + min_len = min( + len(f0), + len(ppg), + len(mel), + ) + f0 = f0[:min_len] + ppg = ppg[:min_len] + mel = mel[:min_len] + return f0, ppg, mel + +class MultiSpkVcCollate(): + """Zero-pads model inputs and targets based on number of frames per step + """ + def __init__(self, n_frames_per_step=1, give_uttids=False, + f02ppg_length_ratio=1, use_spk_dvec=False): + self.n_frames_per_step = n_frames_per_step + self.give_uttids = give_uttids + self.f02ppg_length_ratio = f02ppg_length_ratio + self.use_spk_dvec = use_spk_dvec + + def __call__(self, batch): + batch_size = len(batch) + # Prepare different features + ppgs = [x[0] for x in batch] + lf0_uvs = [x[1] for x in batch] + mels = [x[2] for x in batch] + fids = [x[-1] for x in batch] + if len(batch[0]) == 5: + spk_ids = [x[3] for x in batch] + if self.use_spk_dvec: + # use d-vector + spk_ids = torch.stack(spk_ids).float() + else: + # use one-hot ids + spk_ids = torch.LongTensor(spk_ids) + # Pad features into chunk + ppg_lengths = [x.shape[0] for x in ppgs] + mel_lengths = [x.shape[0] for x in mels] + max_ppg_len = max(ppg_lengths) + max_mel_len = max(mel_lengths) + if max_mel_len % self.n_frames_per_step != 0: + max_mel_len += (self.n_frames_per_step - max_mel_len % self.n_frames_per_step) + ppg_dim = ppgs[0].shape[1] + mel_dim = mels[0].shape[1] + ppgs_padded = torch.FloatTensor(batch_size, max_ppg_len, ppg_dim).zero_() + mels_padded = torch.FloatTensor(batch_size, max_mel_len, mel_dim).zero_() + lf0_uvs_padded = torch.FloatTensor(batch_size, self.f02ppg_length_ratio * max_ppg_len, 2).zero_() + stop_tokens = torch.FloatTensor(batch_size, max_mel_len).zero_() + for i in range(batch_size): + cur_ppg_len = ppgs[i].shape[0] + cur_mel_len = mels[i].shape[0] + ppgs_padded[i, :cur_ppg_len, :] = ppgs[i] + lf0_uvs_padded[i, :self.f02ppg_length_ratio*cur_ppg_len, :] = lf0_uvs[i] + mels_padded[i, :cur_mel_len, :] = mels[i] + stop_tokens[i, cur_ppg_len-self.n_frames_per_step:] = 1 + if len(batch[0]) == 5: + ret_tup = (ppgs_padded, lf0_uvs_padded, mels_padded, torch.LongTensor(ppg_lengths), \ + torch.LongTensor(mel_lengths), spk_ids, stop_tokens) + if self.give_uttids: + return ret_tup + (fids, ) + else: + return ret_tup + else: + ret_tup = (ppgs_padded, lf0_uvs_padded, mels_padded, torch.LongTensor(ppg_lengths), \ + torch.LongTensor(mel_lengths), stop_tokens) + if self.give_uttids: + return ret_tup + (fids, ) + else: + return ret_tup diff --git a/utils/f0_utils.py b/utils/f0_utils.py new file mode 100644 index 0000000..6bc25a8 --- /dev/null +++ b/utils/f0_utils.py @@ -0,0 +1,124 @@ +import logging +import numpy as np +import pyworld +from scipy.interpolate import interp1d +from scipy.signal import firwin, get_window, lfilter + +def compute_mean_std(lf0): + nonzero_indices = np.nonzero(lf0) + mean = np.mean(lf0[nonzero_indices]) + std = np.std(lf0[nonzero_indices]) + return mean, std + + +def compute_f0(wav, sr=16000, frame_period=10.0): + """Compute f0 from wav using pyworld harvest algorithm.""" + wav = wav.astype(np.float64) + f0, _ = 
pyworld.harvest( + wav, sr, frame_period=frame_period, f0_floor=80.0, f0_ceil=600.0) + return f0.astype(np.float32) + +def f02lf0(f0): + lf0 = f0.copy() + nonzero_indices = np.nonzero(f0) + lf0[nonzero_indices] = np.log(f0[nonzero_indices]) + return lf0 + +def get_converted_lf0uv( + wav, + lf0_mean_trg, + lf0_std_trg, + convert=True, +): + f0_src = compute_f0(wav) + if not convert: + uv, cont_lf0 = get_cont_lf0(f0_src) + lf0_uv = np.concatenate([cont_lf0[:, np.newaxis], uv[:, np.newaxis]], axis=1) + return lf0_uv + + lf0_src = f02lf0(f0_src) + lf0_mean_src, lf0_std_src = compute_mean_std(lf0_src) + + lf0_vc = lf0_src.copy() + lf0_vc[lf0_src > 0.0] = (lf0_src[lf0_src > 0.0] - lf0_mean_src) / lf0_std_src * lf0_std_trg + lf0_mean_trg + f0_vc = lf0_vc.copy() + f0_vc[lf0_src > 0.0] = np.exp(lf0_vc[lf0_src > 0.0]) + + uv, cont_lf0_vc = get_cont_lf0(f0_vc) + lf0_uv = np.concatenate([cont_lf0_vc[:, np.newaxis], uv[:, np.newaxis]], axis=1) + return lf0_uv + +def low_pass_filter(x, fs, cutoff=70, padding=True): + """FUNCTION TO APPLY LOW PASS FILTER + + Args: + x (ndarray): Waveform sequence + fs (int): Sampling frequency + cutoff (float): Cutoff frequency of low pass filter + + Return: + (ndarray): Low pass filtered waveform sequence + """ + + nyquist = fs // 2 + norm_cutoff = cutoff / nyquist + + # low cut filter + numtaps = 255 + fil = firwin(numtaps, norm_cutoff) + x_pad = np.pad(x, (numtaps, numtaps), 'edge') + lpf_x = lfilter(fil, 1, x_pad) + lpf_x = lpf_x[numtaps + numtaps // 2: -numtaps // 2] + + return lpf_x + + +def convert_continuos_f0(f0): + """CONVERT F0 TO CONTINUOUS F0 + + Args: + f0 (ndarray): original f0 sequence with the shape (T) + + Return: + (ndarray): continuous f0 with the shape (T) + """ + # get uv information as binary + uv = np.float32(f0 != 0) + + # get start and end of f0 + if (f0 == 0).all(): + logging.warn("all of the f0 values are 0.") + return uv, f0 + start_f0 = f0[f0 != 0][0] + end_f0 = f0[f0 != 0][-1] + + # padding start and end of f0 sequence + start_idx = np.where(f0 == start_f0)[0][0] + end_idx = np.where(f0 == end_f0)[0][-1] + f0[:start_idx] = start_f0 + f0[end_idx:] = end_f0 + + # get non-zero frame index + nz_frames = np.where(f0 != 0)[0] + + # perform linear interpolation + f = interp1d(nz_frames, f0[nz_frames]) + cont_f0 = f(np.arange(0, f0.shape[0])) + + return uv, cont_f0 + + +def get_cont_lf0(f0, frame_period=10.0, lpf=False): + uv, cont_f0 = convert_continuos_f0(f0) + if lpf: + cont_f0_lpf = low_pass_filter(cont_f0, int(1.0 / (frame_period * 0.001)), cutoff=20) + cont_lf0_lpf = cont_f0_lpf.copy() + nonzero_indices = np.nonzero(cont_lf0_lpf) + cont_lf0_lpf[nonzero_indices] = np.log(cont_f0_lpf[nonzero_indices]) + # cont_lf0_lpf = np.log(cont_f0_lpf) + return uv, cont_lf0_lpf + else: + nonzero_indices = np.nonzero(cont_f0) + cont_lf0 = cont_f0.copy() + cont_lf0[cont_f0>0] = np.log(cont_f0[cont_f0>0]) + return uv, cont_lf0 diff --git a/utils/load_yaml.py b/utils/load_yaml.py new file mode 100644 index 0000000..5792ff4 --- /dev/null +++ b/utils/load_yaml.py @@ -0,0 +1,58 @@ +import yaml + + +def load_hparams(filename): + stream = open(filename, 'r') + docs = yaml.safe_load_all(stream) + hparams_dict = dict() + for doc in docs: + for k, v in doc.items(): + hparams_dict[k] = v + return hparams_dict + +def merge_dict(user, default): + if isinstance(user, dict) and isinstance(default, dict): + for k, v in default.items(): + if k not in user: + user[k] = v + else: + user[k] = merge_dict(user[k], v) + return user + +class Dotdict(dict): + """ + a dictionary that 
supports dot notation + as well as dictionary access notation + usage: d = DotDict() or d = DotDict({'val1':'first'}) + set attributes: d.val2 = 'second' or d['val2'] = 'second' + get attributes: d.val2 or d['val2'] + """ + __getattr__ = dict.__getitem__ + __setattr__ = dict.__setitem__ + __delattr__ = dict.__delitem__ + + def __init__(self, dct=None): + dct = dict() if not dct else dct + for key, value in dct.items(): + if hasattr(value, 'keys'): + value = Dotdict(value) + self[key] = value + +class HpsYaml(Dotdict): + def __init__(self, yaml_file): + super(Dotdict, self).__init__() + hps = load_hparams(yaml_file) + hp_dict = Dotdict(hps) + for k, v in hp_dict.items(): + setattr(self, k, v) + + __getattr__ = Dotdict.__getitem__ + __setattr__ = Dotdict.__setitem__ + __delattr__ = Dotdict.__delitem__ + + + + + + + diff --git a/utils/util.py b/utils/util.py new file mode 100644 index 0000000..5227538 --- /dev/null +++ b/utils/util.py @@ -0,0 +1,44 @@ +import matplotlib +matplotlib.use('Agg') +import time + +class Timer(): + ''' Timer for recording training time distribution. ''' + def __init__(self): + self.prev_t = time.time() + self.clear() + + def set(self): + self.prev_t = time.time() + + def cnt(self, mode): + self.time_table[mode] += time.time()-self.prev_t + self.set() + if mode == 'bw': + self.click += 1 + + def show(self): + total_time = sum(self.time_table.values()) + self.time_table['avg'] = total_time/self.click + self.time_table['rd'] = 100*self.time_table['rd']/total_time + self.time_table['fw'] = 100*self.time_table['fw']/total_time + self.time_table['bw'] = 100*self.time_table['bw']/total_time + msg = '{avg:.3f} sec/step (rd {rd:.1f}% | fw {fw:.1f}% | bw {bw:.1f}%)'.format( + **self.time_table) + self.clear() + return msg + + def clear(self): + self.time_table = {'rd': 0, 'fw': 0, 'bw': 0} + self.click = 0 + +# Reference : https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/e2e_asr.py#L168 + +def human_format(num): + magnitude = 0 + while num >= 1000: + magnitude += 1 + num /= 1000.0 + # add more suffixes if you need them + return '{:3.1f}{}'.format(num, [' ', 'K', 'M', 'G', 'T', 'P'][magnitude]) + diff --git a/vocoder/hifigan/inference.py b/vocoder/hifigan/inference.py index 0912726..1475146 100644 --- a/vocoder/hifigan/inference.py +++ b/vocoder/hifigan/inference.py @@ -3,14 +3,11 @@ from __future__ import absolute_import, division, print_function, unicode_litera import os import json import torch -from scipy.io.wavfile import write from vocoder.hifigan.env import AttrDict -from vocoder.hifigan.meldataset import mel_spectrogram, MAX_WAV_VALUE, load_wav from vocoder.hifigan.models import Generator -import soundfile as sf - generator = None # type: Generator +output_sample_rate = None _device = None @@ -22,16 +19,17 @@ def load_checkpoint(filepath, device): return checkpoint_dict -def load_model(weights_fpath, verbose=True): - global generator, _device +def load_model(weights_fpath, config_fpath="./vocoder/saved_models/24k/config.json", verbose=True): + global generator, _device, output_sample_rate if verbose: print("Building hifigan") - with open("./vocoder/hifigan/config_16k_.json") as f: + with open(config_fpath) as f: data = f.read() json_config = json.loads(data) h = AttrDict(json_config) + output_sample_rate = h.sampling_rate torch.manual_seed(h.seed) if torch.cuda.is_available(): @@ -66,5 +64,5 @@ def infer_waveform(mel, progress_callback=None): audio = y_g_hat.squeeze() audio = audio.cpu().numpy() - return audio + return audio, output_sample_rate 
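# --- Illustrative sketch (not part of the changes above) --------------------
# How the f0 helpers in utils/f0_utils.py are combined for conversion,
# mirroring run.py and Toolbox.convert(): the reference speaker's log-F0
# statistics are computed once, then the source pitch contour is re-scaled to
# those statistics and paired with a voiced/unvoiced flag. File paths are
# hypothetical.
import librosa
from utils.f0_utils import compute_f0, f02lf0, compute_mean_std, get_converted_lf0uv

ref_wav, _ = librosa.load("wavs/reference.wav", sr=16000)
src_wav, _ = librosa.load("wavs/source.wav", sr=16000)

ref_lf0_mean, ref_lf0_std = compute_mean_std(f02lf0(compute_f0(ref_wav)))
lf0_uv = get_converted_lf0uv(src_wav, ref_lf0_mean, ref_lf0_std, convert=True)
# lf0_uv: (T, 2) array of continuous log-F0 (col 0) and U/V flag (col 1),
# one row per 10 ms frame, ready to be passed to ppg2mel as logf0_uv.
# ----------------------------------------------------------------------------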
diff --git a/vocoder/hifigan/models.py b/vocoder/hifigan/models.py index 9caf382..c352e19 100644 --- a/vocoder/hifigan/models.py +++ b/vocoder/hifigan/models.py @@ -71,6 +71,24 @@ class ResBlock2(torch.nn.Module): for l in self.convs: remove_weight_norm(l) +class InterpolationBlock(torch.nn.Module): + def __init__(self, scale_factor, mode='nearest', align_corners=None, downsample=False): + super(InterpolationBlock, self).__init__() + self.downsample = downsample + self.scale_factor = scale_factor + self.mode = mode + self.align_corners = align_corners + + def forward(self, x): + outputs = torch.nn.functional.interpolate( + x, + size=x.shape[-1] * self.scale_factor \ + if not self.downsample else x.shape[-1] // self.scale_factor, + mode=self.mode, + align_corners=self.align_corners, + recompute_scale_factor=False + ) + return outputs class Generator(torch.nn.Module): def __init__(self, h): @@ -82,14 +100,27 @@ class Generator(torch.nn.Module): resblock = ResBlock1 if h.resblock == '1' else ResBlock2 self.ups = nn.ModuleList() - for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): -# self.ups.append(weight_norm( -# ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)), -# k, u, padding=(k-u)//2))) - self.ups.append(weight_norm(ConvTranspose1d(h.upsample_initial_channel//(2**i), - h.upsample_initial_channel//(2**(i+1)), - k, u, padding=(u//2 + u%2), output_padding=u%2))) - +# for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): +# # self.ups.append(weight_norm( +# # ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)), +# # k, u, padding=(k-u)//2))) + if h.sampling_rate == 24000: + for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): + self.ups.append( + torch.nn.Sequential( + InterpolationBlock(u), + weight_norm(torch.nn.Conv1d( + h.upsample_initial_channel//(2**i), + h.upsample_initial_channel//(2**(i+1)), + k, padding=(k-1)//2, + )) + ) + ) + else: + for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): + self.ups.append(weight_norm(ConvTranspose1d(h.upsample_initial_channel//(2**i), + h.upsample_initial_channel//(2**(i+1)), + k, u, padding=(u//2 + u%2), output_padding=u%2))) self.resblocks = nn.ModuleList() for i in range(len(self.ups)): ch = h.upsample_initial_channel//(2**(i+1)) @@ -121,7 +152,10 @@ class Generator(torch.nn.Module): def remove_weight_norm(self): print('Removing weight norm...') for l in self.ups: - remove_weight_norm(l) + if self.h.sampling_rate == 24000: + remove_weight_norm(l[-1]) + else: + remove_weight_norm(l) for l in self.resblocks: l.remove_weight_norm() remove_weight_norm(self.conv_pre) diff --git a/vocoder/wavernn/inference.py b/vocoder/wavernn/inference.py index 285ed6d..40cd305 100644 --- a/vocoder/wavernn/inference.py +++ b/vocoder/wavernn/inference.py @@ -61,4 +61,4 @@ def infer_waveform(mel, normalize=True, batched=True, target=8000, overlap=800, mel = mel / hp.mel_max_abs_value mel = torch.from_numpy(mel[None, ...]) wav = _model.generate(mel, batched, target, overlap, hp.mu_law, progress_callback) - return wav + return wav, hp.sample_rate diff --git a/web/__init__.py b/web/__init__.py index 4afc920..2e38817 100644 --- a/web/__init__.py +++ b/web/__init__.py @@ -107,14 +107,15 @@ def webApp(): embeds = [embed] * len(texts) specs = current_synt.synthesize_spectrograms(texts, embeds) spec = np.concatenate(specs, axis=1) + sample_rate = Synthesizer.sample_rate if "vocoder" in request.form and 
request.form["vocoder"] == "WaveRNN": - wav = rnn_vocoder.infer_waveform(spec) + wav, sample_rate = rnn_vocoder.infer_waveform(spec) else: - wav = gan_vocoder.infer_waveform(spec) + wav, sample_rate = gan_vocoder.infer_waveform(spec) # Return cooked wav out = io.BytesIO() - write(out, Synthesizer.sample_rate, wav.astype(np.float32)) + write(out, sample_rate, wav.astype(np.float32)) return Response(out, mimetype="audio/wav") @app.route('/', methods=['GET'])
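# --- Illustrative sketch (not part of the changes above) --------------------
# Small sketch of the HpsYaml helper from utils/load_yaml.py: the training /
# conversion YAML is exposed with both dict-style and dot-style access, which
# is why run.py can build the model via MelDecoderMOLv2(**model_config["model"])
# and train.py can hand the same object to its Solver. The config path below
# is hypothetical.
from utils.load_yaml import HpsYaml

config = HpsYaml("ppg2mel/saved_models/vc_config.yaml")
assert config["model"] is config.model      # Dotdict: item and attribute access alias
decoder_kwargs = dict(config.model)         # e.g. unpacked into MelDecoderMOLv2(**...)
# ----------------------------------------------------------------------------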