mirror of
https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00
Init ppg extractor and ppg2mel (#375)
* Init ppg extractor and ppg2mel * add preprocess and training * FIx known issues * Update __init__.py Allow to gen audio * Fix length issue * Fix bug of preparing fid * Fix sample issues * Add UI usage of PPG-vc
This commit is contained in:
parent
ad22997614
commit
b617a87ee4
9
.gitignore
vendored
9
.gitignore
vendored
|
@ -15,9 +15,8 @@
|
|||
*.toc
|
||||
*.wav
|
||||
*.sh
|
||||
synthesizer/saved_models/*
|
||||
vocoder/saved_models/*
|
||||
encoder/saved_models/*
|
||||
cp_hifigan/*
|
||||
!vocoder/saved_models/pretrained/*
|
||||
*/saved_models
|
||||
!vocoder/saved_models/pretrained/**
|
||||
!encoder/saved_models/pretrained.pt
|
||||
wavs
|
||||
log
|
18
.vscode/launch.json
vendored
18
.vscode/launch.json
vendored
|
@ -35,6 +35,14 @@
|
|||
"console": "integratedTerminal",
|
||||
"args": ["-d","..\\audiodata"]
|
||||
},
|
||||
{
|
||||
"name": "Python: Demo Box VC",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "demo_toolbox.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": ["-d","..\\audiodata","-vc"]
|
||||
},
|
||||
{
|
||||
"name": "Python: Synth Train",
|
||||
"type": "python",
|
||||
|
@ -43,5 +51,15 @@
|
|||
"console": "integratedTerminal",
|
||||
"args": ["my_run", "..\\"]
|
||||
},
|
||||
{
|
||||
"name": "Python: PPG Convert",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "run.py",
|
||||
"console": "integratedTerminal",
|
||||
"args": ["-c", ".\\ppg2mel\\saved_models\\seq2seq_mol_ppg2mel_vctk_libri_oneshotvc_r4_normMel_v2.yaml",
|
||||
"-m", ".\\ppg2mel\\saved_models\\best_loss_step_304000.pth", "--wav_dir", ".\\wavs\\input", "--ref_wav_path", ".\\wavs\\pkq.mp3", "-o", ".\\wavs\\output\\"
|
||||
]
|
||||
},
|
||||
]
|
||||
}
|
||||
|
|
|
@ -15,12 +15,18 @@ if __name__ == '__main__':
|
|||
parser.add_argument("-d", "--datasets_root", type=Path, help= \
|
||||
"Path to the directory containing your datasets. See toolbox/__init__.py for a list of "
|
||||
"supported datasets.", default=None)
|
||||
parser.add_argument("-vc", "--vc_mode", action="store_true",
|
||||
help="Voice Conversion Mode(PPG based)")
|
||||
parser.add_argument("-e", "--enc_models_dir", type=Path, default="encoder/saved_models",
|
||||
help="Directory containing saved encoder models")
|
||||
parser.add_argument("-s", "--syn_models_dir", type=Path, default="synthesizer/saved_models",
|
||||
help="Directory containing saved synthesizer models")
|
||||
parser.add_argument("-v", "--voc_models_dir", type=Path, default="vocoder/saved_models",
|
||||
help="Directory containing saved vocoder models")
|
||||
parser.add_argument("-ex", "--extractor_models_dir", type=Path, default="ppg_extractor/saved_models",
|
||||
help="Directory containing saved extrator models")
|
||||
parser.add_argument("-cv", "--convertor_models_dir", type=Path, default="ppg2mel/saved_models",
|
||||
help="Directory containing saved convert models")
|
||||
parser.add_argument("--cpu", action="store_true", help=\
|
||||
"If True, processing is done on CPU, even when a GPU is available.")
|
||||
parser.add_argument("--seed", type=int, default=None, help=\
|
||||
|
|
|
@ -34,7 +34,15 @@ def load_model(weights_fpath: Path, device=None):
|
|||
_model.load_state_dict(checkpoint["model_state"])
|
||||
_model.eval()
|
||||
print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"]))
|
||||
return _model
|
||||
|
||||
def set_model(model, device=None):
|
||||
global _model, _device
|
||||
_model = model
|
||||
if device is None:
|
||||
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
_device = device
|
||||
_model.to(device)
|
||||
|
||||
def is_loaded():
|
||||
return _model is not None
|
||||
|
@ -57,7 +65,7 @@ def embed_frames_batch(frames_batch):
|
|||
|
||||
|
||||
def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames,
|
||||
min_pad_coverage=0.75, overlap=0.5):
|
||||
min_pad_coverage=0.75, overlap=0.5, rate=None):
|
||||
"""
|
||||
Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain
|
||||
partial utterances of <partial_utterance_n_frames> each. Both the waveform and the mel
|
||||
|
@ -85,10 +93,19 @@ def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_fram
|
|||
assert 0 <= overlap < 1
|
||||
assert 0 < min_pad_coverage <= 1
|
||||
|
||||
if rate != None:
|
||||
samples_per_frame = int((sampling_rate * mel_window_step / 1000))
|
||||
n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
|
||||
frame_step = int(np.round((sampling_rate / rate) / samples_per_frame))
|
||||
else:
|
||||
samples_per_frame = int((sampling_rate * mel_window_step / 1000))
|
||||
n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
|
||||
frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)
|
||||
|
||||
assert 0 < frame_step, "The rate is too high"
|
||||
assert frame_step <= partials_n_frames, "The rate is too low, it should be %f at least" % \
|
||||
(sampling_rate / (samples_per_frame * partials_n_frames))
|
||||
|
||||
# Compute the slices
|
||||
wav_slices, mel_slices = [], []
|
||||
steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1)
|
||||
|
|
206
ppg2mel/__init__.py
Normal file
206
ppg2mel/__init__.py
Normal file
|
@ -0,0 +1,206 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright 2020 Songxiang Liu
|
||||
# Apache 2.0
|
||||
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .utils.abs_model import AbsMelDecoder
|
||||
from .rnn_decoder_mol import Decoder
|
||||
from .utils.cnn_postnet import Postnet
|
||||
from .utils.vc_utils import get_mask_from_lengths
|
||||
|
||||
from utils.load_yaml import HpsYaml
|
||||
|
||||
class MelDecoderMOLv2(AbsMelDecoder):
|
||||
"""Use an encoder to preprocess ppg."""
|
||||
def __init__(
|
||||
self,
|
||||
num_speakers: int,
|
||||
spk_embed_dim: int,
|
||||
bottle_neck_feature_dim: int,
|
||||
encoder_dim: int = 256,
|
||||
encoder_downsample_rates: List = [2, 2],
|
||||
attention_rnn_dim: int = 512,
|
||||
decoder_rnn_dim: int = 512,
|
||||
num_decoder_rnn_layer: int = 1,
|
||||
concat_context_to_last: bool = True,
|
||||
prenet_dims: List = [256, 128],
|
||||
num_mixtures: int = 5,
|
||||
frames_per_step: int = 2,
|
||||
mask_padding: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.mask_padding = mask_padding
|
||||
self.bottle_neck_feature_dim = bottle_neck_feature_dim
|
||||
self.num_mels = 80
|
||||
self.encoder_down_factor=np.cumprod(encoder_downsample_rates)[-1]
|
||||
self.frames_per_step = frames_per_step
|
||||
self.use_spk_dvec = True
|
||||
|
||||
input_dim = bottle_neck_feature_dim
|
||||
|
||||
# Downsampling convolution
|
||||
self.bnf_prenet = torch.nn.Sequential(
|
||||
torch.nn.Conv1d(input_dim, encoder_dim, kernel_size=1, bias=False),
|
||||
torch.nn.LeakyReLU(0.1),
|
||||
|
||||
torch.nn.InstanceNorm1d(encoder_dim, affine=False),
|
||||
torch.nn.Conv1d(
|
||||
encoder_dim, encoder_dim,
|
||||
kernel_size=2*encoder_downsample_rates[0],
|
||||
stride=encoder_downsample_rates[0],
|
||||
padding=encoder_downsample_rates[0]//2,
|
||||
),
|
||||
torch.nn.LeakyReLU(0.1),
|
||||
|
||||
torch.nn.InstanceNorm1d(encoder_dim, affine=False),
|
||||
torch.nn.Conv1d(
|
||||
encoder_dim, encoder_dim,
|
||||
kernel_size=2*encoder_downsample_rates[1],
|
||||
stride=encoder_downsample_rates[1],
|
||||
padding=encoder_downsample_rates[1]//2,
|
||||
),
|
||||
torch.nn.LeakyReLU(0.1),
|
||||
|
||||
torch.nn.InstanceNorm1d(encoder_dim, affine=False),
|
||||
)
|
||||
decoder_enc_dim = encoder_dim
|
||||
self.pitch_convs = torch.nn.Sequential(
|
||||
torch.nn.Conv1d(2, encoder_dim, kernel_size=1, bias=False),
|
||||
torch.nn.LeakyReLU(0.1),
|
||||
|
||||
torch.nn.InstanceNorm1d(encoder_dim, affine=False),
|
||||
torch.nn.Conv1d(
|
||||
encoder_dim, encoder_dim,
|
||||
kernel_size=2*encoder_downsample_rates[0],
|
||||
stride=encoder_downsample_rates[0],
|
||||
padding=encoder_downsample_rates[0]//2,
|
||||
),
|
||||
torch.nn.LeakyReLU(0.1),
|
||||
|
||||
torch.nn.InstanceNorm1d(encoder_dim, affine=False),
|
||||
torch.nn.Conv1d(
|
||||
encoder_dim, encoder_dim,
|
||||
kernel_size=2*encoder_downsample_rates[1],
|
||||
stride=encoder_downsample_rates[1],
|
||||
padding=encoder_downsample_rates[1]//2,
|
||||
),
|
||||
torch.nn.LeakyReLU(0.1),
|
||||
|
||||
torch.nn.InstanceNorm1d(encoder_dim, affine=False),
|
||||
)
|
||||
|
||||
self.reduce_proj = torch.nn.Linear(encoder_dim + spk_embed_dim, encoder_dim)
|
||||
|
||||
# Decoder
|
||||
self.decoder = Decoder(
|
||||
enc_dim=decoder_enc_dim,
|
||||
num_mels=self.num_mels,
|
||||
frames_per_step=frames_per_step,
|
||||
attention_rnn_dim=attention_rnn_dim,
|
||||
decoder_rnn_dim=decoder_rnn_dim,
|
||||
num_decoder_rnn_layer=num_decoder_rnn_layer,
|
||||
prenet_dims=prenet_dims,
|
||||
num_mixtures=num_mixtures,
|
||||
use_stop_tokens=True,
|
||||
concat_context_to_last=concat_context_to_last,
|
||||
encoder_down_factor=self.encoder_down_factor,
|
||||
)
|
||||
|
||||
# Mel-Spec Postnet: some residual CNN layers
|
||||
self.postnet = Postnet()
|
||||
|
||||
def parse_output(self, outputs, output_lengths=None):
|
||||
if self.mask_padding and output_lengths is not None:
|
||||
mask = ~get_mask_from_lengths(output_lengths, outputs[0].size(1))
|
||||
mask = mask.unsqueeze(2).expand(mask.size(0), mask.size(1), self.num_mels)
|
||||
outputs[0].data.masked_fill_(mask, 0.0)
|
||||
outputs[1].data.masked_fill_(mask, 0.0)
|
||||
return outputs
|
||||
|
||||
def forward(
|
||||
self,
|
||||
bottle_neck_features: torch.Tensor,
|
||||
feature_lengths: torch.Tensor,
|
||||
speech: torch.Tensor,
|
||||
speech_lengths: torch.Tensor,
|
||||
logf0_uv: torch.Tensor = None,
|
||||
spembs: torch.Tensor = None,
|
||||
output_att_ws: bool = False,
|
||||
):
|
||||
decoder_inputs = self.bnf_prenet(
|
||||
bottle_neck_features.transpose(1, 2)
|
||||
).transpose(1, 2)
|
||||
logf0_uv = self.pitch_convs(logf0_uv.transpose(1, 2)).transpose(1, 2)
|
||||
decoder_inputs = decoder_inputs + logf0_uv
|
||||
|
||||
assert spembs is not None
|
||||
spk_embeds = F.normalize(
|
||||
spembs).unsqueeze(1).expand(-1, decoder_inputs.size(1), -1)
|
||||
decoder_inputs = torch.cat([decoder_inputs, spk_embeds], dim=-1)
|
||||
decoder_inputs = self.reduce_proj(decoder_inputs)
|
||||
|
||||
# (B, num_mels, T_dec)
|
||||
T_dec = torch.div(feature_lengths, int(self.encoder_down_factor), rounding_mode='floor')
|
||||
mel_outputs, predicted_stop, alignments = self.decoder(
|
||||
decoder_inputs, speech, T_dec)
|
||||
## Post-processing
|
||||
mel_outputs_postnet = self.postnet(mel_outputs.transpose(1, 2)).transpose(1, 2)
|
||||
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
|
||||
if output_att_ws:
|
||||
return self.parse_output(
|
||||
[mel_outputs, mel_outputs_postnet, predicted_stop, alignments], speech_lengths)
|
||||
else:
|
||||
return self.parse_output(
|
||||
[mel_outputs, mel_outputs_postnet, predicted_stop], speech_lengths)
|
||||
|
||||
# return mel_outputs, mel_outputs_postnet
|
||||
|
||||
def inference(
|
||||
self,
|
||||
bottle_neck_features: torch.Tensor,
|
||||
logf0_uv: torch.Tensor = None,
|
||||
spembs: torch.Tensor = None,
|
||||
):
|
||||
decoder_inputs = self.bnf_prenet(bottle_neck_features.transpose(1, 2)).transpose(1, 2)
|
||||
logf0_uv = self.pitch_convs(logf0_uv.transpose(1, 2)).transpose(1, 2)
|
||||
decoder_inputs = decoder_inputs + logf0_uv
|
||||
|
||||
assert spembs is not None
|
||||
spk_embeds = F.normalize(
|
||||
spembs).unsqueeze(1).expand(-1, decoder_inputs.size(1), -1)
|
||||
bottle_neck_features = torch.cat([decoder_inputs, spk_embeds], dim=-1)
|
||||
bottle_neck_features = self.reduce_proj(bottle_neck_features)
|
||||
|
||||
## Decoder
|
||||
if bottle_neck_features.size(0) > 1:
|
||||
mel_outputs, alignments = self.decoder.inference_batched(bottle_neck_features)
|
||||
else:
|
||||
mel_outputs, alignments = self.decoder.inference(bottle_neck_features,)
|
||||
## Post-processing
|
||||
mel_outputs_postnet = self.postnet(mel_outputs.transpose(1, 2)).transpose(1, 2)
|
||||
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
|
||||
# outputs = mel_outputs_postnet[0]
|
||||
|
||||
return mel_outputs[0], mel_outputs_postnet[0], alignments[0]
|
||||
|
||||
def load_model(train_config, model_file, device=None):
|
||||
|
||||
if device is None:
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
model_config = HpsYaml(train_config)
|
||||
ppg2mel_model = MelDecoderMOLv2(
|
||||
**model_config["model"]
|
||||
).to(device)
|
||||
ckpt = torch.load(model_file, map_location=device)
|
||||
ppg2mel_model.load_state_dict(ckpt["model"])
|
||||
ppg2mel_model.eval()
|
||||
return ppg2mel_model
|
112
ppg2mel/preprocess.py
Normal file
112
ppg2mel/preprocess.py
Normal file
|
@ -0,0 +1,112 @@
|
|||
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from pathlib import Path
|
||||
import soundfile
|
||||
import resampy
|
||||
|
||||
from ppg_extractor import load_model
|
||||
import encoder.inference as Encoder
|
||||
from encoder.audio import preprocess_wav
|
||||
from encoder import audio
|
||||
from utils.f0_utils import compute_f0
|
||||
|
||||
from torch.multiprocessing import Pool, cpu_count
|
||||
from functools import partial
|
||||
|
||||
SAMPLE_RATE=16000
|
||||
|
||||
def _compute_bnf(
|
||||
wav: any,
|
||||
output_fpath: str,
|
||||
device: torch.device,
|
||||
ppg_model_local: any,
|
||||
):
|
||||
"""
|
||||
Compute CTC-Attention Seq2seq ASR encoder bottle-neck features (BNF).
|
||||
"""
|
||||
ppg_model_local.to(device)
|
||||
wav_tensor = torch.from_numpy(wav).float().to(device).unsqueeze(0)
|
||||
wav_length = torch.LongTensor([wav.shape[0]]).to(device)
|
||||
with torch.no_grad():
|
||||
bnf = ppg_model_local(wav_tensor, wav_length)
|
||||
bnf_npy = bnf.squeeze(0).cpu().numpy()
|
||||
np.save(output_fpath, bnf_npy, allow_pickle=False)
|
||||
return bnf_npy, len(bnf_npy)
|
||||
|
||||
def _compute_f0_from_wav(wav, output_fpath):
|
||||
"""Compute merged f0 values."""
|
||||
f0 = compute_f0(wav, SAMPLE_RATE)
|
||||
np.save(output_fpath, f0, allow_pickle=False)
|
||||
return f0, len(f0)
|
||||
|
||||
def _compute_spkEmbed(wav, output_fpath, encoder_model_local, device):
|
||||
Encoder.set_model(encoder_model_local)
|
||||
# Compute where to split the utterance into partials and pad if necessary
|
||||
wave_slices, mel_slices = Encoder.compute_partial_slices(len(wav), rate=1.3, min_pad_coverage=0.75)
|
||||
max_wave_length = wave_slices[-1].stop
|
||||
if max_wave_length >= len(wav):
|
||||
wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
|
||||
|
||||
# Split the utterance into partials
|
||||
frames = audio.wav_to_mel_spectrogram(wav)
|
||||
frames_batch = np.array([frames[s] for s in mel_slices])
|
||||
partial_embeds = Encoder.embed_frames_batch(frames_batch)
|
||||
|
||||
# Compute the utterance embedding from the partial embeddings
|
||||
raw_embed = np.mean(partial_embeds, axis=0)
|
||||
embed = raw_embed / np.linalg.norm(raw_embed, 2)
|
||||
|
||||
np.save(output_fpath, embed, allow_pickle=False)
|
||||
return embed, len(embed)
|
||||
|
||||
def preprocess_one(wav_path, out_dir, device, ppg_model_local, encoder_model_local):
|
||||
# wav = preprocess_wav(wav_path)
|
||||
# try:
|
||||
wav, sr = soundfile.read(wav_path)
|
||||
if len(wav) < sr:
|
||||
return None, sr, len(wav)
|
||||
if sr != SAMPLE_RATE:
|
||||
wav = resampy.resample(wav, sr, SAMPLE_RATE)
|
||||
sr = SAMPLE_RATE
|
||||
utt_id = os.path.basename(wav_path).rstrip(".wav")
|
||||
|
||||
_, length_bnf = _compute_bnf(output_fpath=f"{out_dir}/bnf/{utt_id}.ling_feat.npy", wav=wav, device=device, ppg_model_local=ppg_model_local)
|
||||
_, length_f0 = _compute_f0_from_wav(output_fpath=f"{out_dir}/f0/{utt_id}.f0.npy", wav=wav)
|
||||
_, length_embed = _compute_spkEmbed(output_fpath=f"{out_dir}/embed/{utt_id}.npy", device=device, encoder_model_local=encoder_model_local, wav=wav)
|
||||
|
||||
def preprocess_dataset(datasets_root, dataset, out_dir, n_processes, ppg_encoder_model_fpath, speaker_encoder_model):
|
||||
# Glob wav files
|
||||
wav_file_list = sorted(Path(f"{datasets_root}/{dataset}").glob("**/*.wav"))
|
||||
print(f"Globbed {len(wav_file_list)} wav files.")
|
||||
|
||||
out_dir.joinpath("bnf").mkdir(exist_ok=True, parents=True)
|
||||
out_dir.joinpath("f0").mkdir(exist_ok=True, parents=True)
|
||||
out_dir.joinpath("embed").mkdir(exist_ok=True, parents=True)
|
||||
ppg_model_local = load_model(ppg_encoder_model_fpath, "cpu")
|
||||
encoder_model_local = Encoder.load_model(speaker_encoder_model, "cpu")
|
||||
if n_processes is None:
|
||||
n_processes = cpu_count()
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
func = partial(preprocess_one, out_dir=out_dir, ppg_model_local=ppg_model_local, encoder_model_local=encoder_model_local, device=device)
|
||||
job = Pool(n_processes).imap(func, wav_file_list)
|
||||
list(tqdm(job, "Preprocessing", len(wav_file_list), unit="wav"))
|
||||
|
||||
# finish processing and mark
|
||||
t_fid_file = out_dir.joinpath("train_fidlist.txt").open("w", encoding="utf-8")
|
||||
d_fid_file = out_dir.joinpath("dev_fidlist.txt").open("w", encoding="utf-8")
|
||||
e_fid_file = out_dir.joinpath("eval_fidlist.txt").open("w", encoding="utf-8")
|
||||
for file in sorted(out_dir.joinpath("f0").glob("*.npy")):
|
||||
id = os.path.basename(file).split(".f0.npy")[0]
|
||||
if id.endswith("01"):
|
||||
d_fid_file.write(id + "\n")
|
||||
elif id.endswith("09"):
|
||||
e_fid_file.write(id + "\n")
|
||||
else:
|
||||
t_fid_file.write(id + "\n")
|
||||
t_fid_file.close()
|
||||
d_fid_file.close()
|
||||
e_fid_file.close()
|
374
ppg2mel/rnn_decoder_mol.py
Normal file
374
ppg2mel/rnn_decoder_mol.py
Normal file
|
@ -0,0 +1,374 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from .utils.mol_attention import MOLAttention
|
||||
from .utils.basic_layers import Linear
|
||||
from .utils.vc_utils import get_mask_from_lengths
|
||||
|
||||
|
||||
class DecoderPrenet(nn.Module):
|
||||
def __init__(self, in_dim, sizes):
|
||||
super().__init__()
|
||||
in_sizes = [in_dim] + sizes[:-1]
|
||||
self.layers = nn.ModuleList(
|
||||
[Linear(in_size, out_size, bias=False)
|
||||
for (in_size, out_size) in zip(in_sizes, sizes)])
|
||||
|
||||
def forward(self, x):
|
||||
for linear in self.layers:
|
||||
x = F.dropout(F.relu(linear(x)), p=0.5, training=True)
|
||||
return x
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
"""Mixture of Logistic (MoL) attention-based RNN Decoder."""
|
||||
def __init__(
|
||||
self,
|
||||
enc_dim,
|
||||
num_mels,
|
||||
frames_per_step,
|
||||
attention_rnn_dim,
|
||||
decoder_rnn_dim,
|
||||
prenet_dims,
|
||||
num_mixtures,
|
||||
encoder_down_factor=1,
|
||||
num_decoder_rnn_layer=1,
|
||||
use_stop_tokens=False,
|
||||
concat_context_to_last=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.enc_dim = enc_dim
|
||||
self.encoder_down_factor = encoder_down_factor
|
||||
self.num_mels = num_mels
|
||||
self.frames_per_step = frames_per_step
|
||||
self.attention_rnn_dim = attention_rnn_dim
|
||||
self.decoder_rnn_dim = decoder_rnn_dim
|
||||
self.prenet_dims = prenet_dims
|
||||
self.use_stop_tokens = use_stop_tokens
|
||||
self.num_decoder_rnn_layer = num_decoder_rnn_layer
|
||||
self.concat_context_to_last = concat_context_to_last
|
||||
|
||||
# Mel prenet
|
||||
self.prenet = DecoderPrenet(num_mels, prenet_dims)
|
||||
self.prenet_pitch = DecoderPrenet(num_mels, prenet_dims)
|
||||
|
||||
# Attention RNN
|
||||
self.attention_rnn = nn.LSTMCell(
|
||||
prenet_dims[-1] + enc_dim,
|
||||
attention_rnn_dim
|
||||
)
|
||||
|
||||
# Attention
|
||||
self.attention_layer = MOLAttention(
|
||||
attention_rnn_dim,
|
||||
r=frames_per_step/encoder_down_factor,
|
||||
M=num_mixtures,
|
||||
)
|
||||
|
||||
# Decoder RNN
|
||||
self.decoder_rnn_layers = nn.ModuleList()
|
||||
for i in range(num_decoder_rnn_layer):
|
||||
if i == 0:
|
||||
self.decoder_rnn_layers.append(
|
||||
nn.LSTMCell(
|
||||
enc_dim + attention_rnn_dim,
|
||||
decoder_rnn_dim))
|
||||
else:
|
||||
self.decoder_rnn_layers.append(
|
||||
nn.LSTMCell(
|
||||
decoder_rnn_dim,
|
||||
decoder_rnn_dim))
|
||||
# self.decoder_rnn = nn.LSTMCell(
|
||||
# 2 * enc_dim + attention_rnn_dim,
|
||||
# decoder_rnn_dim
|
||||
# )
|
||||
if concat_context_to_last:
|
||||
self.linear_projection = Linear(
|
||||
enc_dim + decoder_rnn_dim,
|
||||
num_mels * frames_per_step
|
||||
)
|
||||
else:
|
||||
self.linear_projection = Linear(
|
||||
decoder_rnn_dim,
|
||||
num_mels * frames_per_step
|
||||
)
|
||||
|
||||
|
||||
# Stop-token layer
|
||||
if self.use_stop_tokens:
|
||||
if concat_context_to_last:
|
||||
self.stop_layer = Linear(
|
||||
enc_dim + decoder_rnn_dim, 1, bias=True, w_init_gain="sigmoid"
|
||||
)
|
||||
else:
|
||||
self.stop_layer = Linear(
|
||||
decoder_rnn_dim, 1, bias=True, w_init_gain="sigmoid"
|
||||
)
|
||||
|
||||
|
||||
def get_go_frame(self, memory):
|
||||
B = memory.size(0)
|
||||
go_frame = torch.zeros((B, self.num_mels), dtype=torch.float,
|
||||
device=memory.device)
|
||||
return go_frame
|
||||
|
||||
def initialize_decoder_states(self, memory, mask):
|
||||
device = next(self.parameters()).device
|
||||
B = memory.size(0)
|
||||
|
||||
# attention rnn states
|
||||
self.attention_hidden = torch.zeros(
|
||||
(B, self.attention_rnn_dim), device=device)
|
||||
self.attention_cell = torch.zeros(
|
||||
(B, self.attention_rnn_dim), device=device)
|
||||
|
||||
# decoder rnn states
|
||||
self.decoder_hiddens = []
|
||||
self.decoder_cells = []
|
||||
for i in range(self.num_decoder_rnn_layer):
|
||||
self.decoder_hiddens.append(
|
||||
torch.zeros((B, self.decoder_rnn_dim),
|
||||
device=device)
|
||||
)
|
||||
self.decoder_cells.append(
|
||||
torch.zeros((B, self.decoder_rnn_dim),
|
||||
device=device)
|
||||
)
|
||||
# self.decoder_hidden = torch.zeros(
|
||||
# (B, self.decoder_rnn_dim), device=device)
|
||||
# self.decoder_cell = torch.zeros(
|
||||
# (B, self.decoder_rnn_dim), device=device)
|
||||
|
||||
self.attention_context = torch.zeros(
|
||||
(B, self.enc_dim), device=device)
|
||||
|
||||
self.memory = memory
|
||||
# self.processed_memory = self.attention_layer.memory_layer(memory)
|
||||
self.mask = mask
|
||||
|
||||
def parse_decoder_inputs(self, decoder_inputs):
|
||||
"""Prepare decoder inputs, i.e. gt mel
|
||||
Args:
|
||||
decoder_inputs:(B, T_out, n_mel_channels) inputs used for teacher-forced training.
|
||||
"""
|
||||
decoder_inputs = decoder_inputs.reshape(
|
||||
decoder_inputs.size(0),
|
||||
int(decoder_inputs.size(1)/self.frames_per_step), -1)
|
||||
# (B, T_out//r, r*num_mels) -> (T_out//r, B, r*num_mels)
|
||||
decoder_inputs = decoder_inputs.transpose(0, 1)
|
||||
# (T_out//r, B, num_mels)
|
||||
decoder_inputs = decoder_inputs[:,:,-self.num_mels:]
|
||||
return decoder_inputs
|
||||
|
||||
def parse_decoder_outputs(self, mel_outputs, alignments, stop_outputs):
|
||||
""" Prepares decoder outputs for output
|
||||
Args:
|
||||
mel_outputs:
|
||||
alignments:
|
||||
"""
|
||||
# (T_out//r, B, T_enc) -> (B, T_out//r, T_enc)
|
||||
alignments = torch.stack(alignments).transpose(0, 1)
|
||||
# (T_out//r, B) -> (B, T_out//r)
|
||||
if stop_outputs is not None:
|
||||
if alignments.size(0) == 1:
|
||||
stop_outputs = torch.stack(stop_outputs).unsqueeze(0)
|
||||
else:
|
||||
stop_outputs = torch.stack(stop_outputs).transpose(0, 1)
|
||||
stop_outputs = stop_outputs.contiguous()
|
||||
# (T_out//r, B, num_mels*r) -> (B, T_out//r, num_mels*r)
|
||||
mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous()
|
||||
# decouple frames per step
|
||||
# (B, T_out, num_mels)
|
||||
mel_outputs = mel_outputs.view(
|
||||
mel_outputs.size(0), -1, self.num_mels)
|
||||
return mel_outputs, alignments, stop_outputs
|
||||
|
||||
def attend(self, decoder_input):
|
||||
cell_input = torch.cat((decoder_input, self.attention_context), -1)
|
||||
self.attention_hidden, self.attention_cell = self.attention_rnn(
|
||||
cell_input, (self.attention_hidden, self.attention_cell))
|
||||
self.attention_context, attention_weights = self.attention_layer(
|
||||
self.attention_hidden, self.memory, None, self.mask)
|
||||
|
||||
decoder_rnn_input = torch.cat(
|
||||
(self.attention_hidden, self.attention_context), -1)
|
||||
|
||||
return decoder_rnn_input, self.attention_context, attention_weights
|
||||
|
||||
def decode(self, decoder_input):
|
||||
for i in range(self.num_decoder_rnn_layer):
|
||||
if i == 0:
|
||||
self.decoder_hiddens[i], self.decoder_cells[i] = self.decoder_rnn_layers[i](
|
||||
decoder_input, (self.decoder_hiddens[i], self.decoder_cells[i]))
|
||||
else:
|
||||
self.decoder_hiddens[i], self.decoder_cells[i] = self.decoder_rnn_layers[i](
|
||||
self.decoder_hiddens[i-1], (self.decoder_hiddens[i], self.decoder_cells[i]))
|
||||
return self.decoder_hiddens[-1]
|
||||
|
||||
def forward(self, memory, mel_inputs, memory_lengths):
|
||||
""" Decoder forward pass for training
|
||||
Args:
|
||||
memory: (B, T_enc, enc_dim) Encoder outputs
|
||||
decoder_inputs: (B, T, num_mels) Decoder inputs for teacher forcing.
|
||||
memory_lengths: (B, ) Encoder output lengths for attention masking.
|
||||
Returns:
|
||||
mel_outputs: (B, T, num_mels) mel outputs from the decoder
|
||||
alignments: (B, T//r, T_enc) attention weights.
|
||||
"""
|
||||
# [1, B, num_mels]
|
||||
go_frame = self.get_go_frame(memory).unsqueeze(0)
|
||||
# [T//r, B, num_mels]
|
||||
mel_inputs = self.parse_decoder_inputs(mel_inputs)
|
||||
# [T//r + 1, B, num_mels]
|
||||
mel_inputs = torch.cat((go_frame, mel_inputs), dim=0)
|
||||
# [T//r + 1, B, prenet_dim]
|
||||
decoder_inputs = self.prenet(mel_inputs)
|
||||
# decoder_inputs_pitch = self.prenet_pitch(decoder_inputs__)
|
||||
|
||||
self.initialize_decoder_states(
|
||||
memory, mask=~get_mask_from_lengths(memory_lengths),
|
||||
)
|
||||
|
||||
self.attention_layer.init_states(memory)
|
||||
# self.attention_layer_pitch.init_states(memory_pitch)
|
||||
|
||||
mel_outputs, alignments = [], []
|
||||
if self.use_stop_tokens:
|
||||
stop_outputs = []
|
||||
else:
|
||||
stop_outputs = None
|
||||
while len(mel_outputs) < decoder_inputs.size(0) - 1:
|
||||
decoder_input = decoder_inputs[len(mel_outputs)]
|
||||
# decoder_input_pitch = decoder_inputs_pitch[len(mel_outputs)]
|
||||
|
||||
decoder_rnn_input, context, attention_weights = self.attend(decoder_input)
|
||||
|
||||
decoder_rnn_output = self.decode(decoder_rnn_input)
|
||||
if self.concat_context_to_last:
|
||||
decoder_rnn_output = torch.cat(
|
||||
(decoder_rnn_output, context), dim=1)
|
||||
|
||||
mel_output = self.linear_projection(decoder_rnn_output)
|
||||
if self.use_stop_tokens:
|
||||
stop_output = self.stop_layer(decoder_rnn_output)
|
||||
stop_outputs += [stop_output.squeeze()]
|
||||
mel_outputs += [mel_output.squeeze(1)] #? perhaps don't need squeeze
|
||||
alignments += [attention_weights]
|
||||
# alignments_pitch += [attention_weights_pitch]
|
||||
|
||||
mel_outputs, alignments, stop_outputs = self.parse_decoder_outputs(
|
||||
mel_outputs, alignments, stop_outputs)
|
||||
if stop_outputs is None:
|
||||
return mel_outputs, alignments
|
||||
else:
|
||||
return mel_outputs, stop_outputs, alignments
|
||||
|
||||
def inference(self, memory, stop_threshold=0.5):
|
||||
""" Decoder inference
|
||||
Args:
|
||||
memory: (1, T_enc, D_enc) Encoder outputs
|
||||
Returns:
|
||||
mel_outputs: mel outputs from the decoder
|
||||
alignments: sequence of attention weights from the decoder
|
||||
"""
|
||||
# [1, num_mels]
|
||||
decoder_input = self.get_go_frame(memory)
|
||||
|
||||
self.initialize_decoder_states(memory, mask=None)
|
||||
|
||||
self.attention_layer.init_states(memory)
|
||||
|
||||
mel_outputs, alignments = [], []
|
||||
# NOTE(sx): heuristic
|
||||
max_decoder_step = memory.size(1)*self.encoder_down_factor//self.frames_per_step
|
||||
min_decoder_step = memory.size(1)*self.encoder_down_factor // self.frames_per_step - 5
|
||||
while True:
|
||||
decoder_input = self.prenet(decoder_input)
|
||||
|
||||
decoder_input_final, context, alignment = self.attend(decoder_input)
|
||||
|
||||
#mel_output, stop_output, alignment = self.decode(decoder_input)
|
||||
decoder_rnn_output = self.decode(decoder_input_final)
|
||||
if self.concat_context_to_last:
|
||||
decoder_rnn_output = torch.cat(
|
||||
(decoder_rnn_output, context), dim=1)
|
||||
|
||||
mel_output = self.linear_projection(decoder_rnn_output)
|
||||
stop_output = self.stop_layer(decoder_rnn_output)
|
||||
|
||||
mel_outputs += [mel_output.squeeze(1)]
|
||||
alignments += [alignment]
|
||||
|
||||
if torch.sigmoid(stop_output.data) > stop_threshold and len(mel_outputs) >= min_decoder_step:
|
||||
break
|
||||
if len(mel_outputs) >= max_decoder_step:
|
||||
# print("Warning! Decoding steps reaches max decoder steps.")
|
||||
break
|
||||
|
||||
decoder_input = mel_output[:,-self.num_mels:]
|
||||
|
||||
|
||||
mel_outputs, alignments, _ = self.parse_decoder_outputs(
|
||||
mel_outputs, alignments, None)
|
||||
|
||||
return mel_outputs, alignments
|
||||
|
||||
def inference_batched(self, memory, stop_threshold=0.5):
|
||||
""" Decoder inference
|
||||
Args:
|
||||
memory: (B, T_enc, D_enc) Encoder outputs
|
||||
Returns:
|
||||
mel_outputs: mel outputs from the decoder
|
||||
alignments: sequence of attention weights from the decoder
|
||||
"""
|
||||
# [1, num_mels]
|
||||
decoder_input = self.get_go_frame(memory)
|
||||
|
||||
self.initialize_decoder_states(memory, mask=None)
|
||||
|
||||
self.attention_layer.init_states(memory)
|
||||
|
||||
mel_outputs, alignments = [], []
|
||||
stop_outputs = []
|
||||
# NOTE(sx): heuristic
|
||||
max_decoder_step = memory.size(1)*self.encoder_down_factor//self.frames_per_step
|
||||
min_decoder_step = memory.size(1)*self.encoder_down_factor // self.frames_per_step - 5
|
||||
while True:
|
||||
decoder_input = self.prenet(decoder_input)
|
||||
|
||||
decoder_input_final, context, alignment = self.attend(decoder_input)
|
||||
|
||||
#mel_output, stop_output, alignment = self.decode(decoder_input)
|
||||
decoder_rnn_output = self.decode(decoder_input_final)
|
||||
if self.concat_context_to_last:
|
||||
decoder_rnn_output = torch.cat(
|
||||
(decoder_rnn_output, context), dim=1)
|
||||
|
||||
mel_output = self.linear_projection(decoder_rnn_output)
|
||||
# (B, 1)
|
||||
stop_output = self.stop_layer(decoder_rnn_output)
|
||||
stop_outputs += [stop_output.squeeze()]
|
||||
# stop_outputs.append(stop_output)
|
||||
|
||||
mel_outputs += [mel_output.squeeze(1)]
|
||||
alignments += [alignment]
|
||||
# print(stop_output.shape)
|
||||
if torch.all(torch.sigmoid(stop_output.squeeze().data) > stop_threshold) \
|
||||
and len(mel_outputs) >= min_decoder_step:
|
||||
break
|
||||
if len(mel_outputs) >= max_decoder_step:
|
||||
# print("Warning! Decoding steps reaches max decoder steps.")
|
||||
break
|
||||
|
||||
decoder_input = mel_output[:,-self.num_mels:]
|
||||
|
||||
|
||||
mel_outputs, alignments, stop_outputs = self.parse_decoder_outputs(
|
||||
mel_outputs, alignments, stop_outputs)
|
||||
mel_outputs_stacked = []
|
||||
for mel, stop_logit in zip(mel_outputs, stop_outputs):
|
||||
idx = np.argwhere(torch.sigmoid(stop_logit.cpu()) > stop_threshold)[0][0].item()
|
||||
mel_outputs_stacked.append(mel[:idx,:])
|
||||
mel_outputs = torch.cat(mel_outputs_stacked, dim=0).unsqueeze(0)
|
||||
return mel_outputs, alignments
|
67
ppg2mel/train.py
Normal file
67
ppg2mel/train.py
Normal file
|
@ -0,0 +1,67 @@
|
|||
import sys
|
||||
import torch
|
||||
import argparse
|
||||
import numpy as np
|
||||
from utils.load_yaml import HpsYaml
|
||||
from ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver
|
||||
|
||||
# For reproducibility, comment these may speed up training
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
def main():
|
||||
# Arguments
|
||||
parser = argparse.ArgumentParser(description=
|
||||
'Training PPG2Mel VC model.')
|
||||
parser.add_argument('--config', type=str,
|
||||
help='Path to experiment config, e.g., config/vc.yaml')
|
||||
parser.add_argument('--name', default=None, type=str, help='Name for logging.')
|
||||
parser.add_argument('--logdir', default='log/', type=str,
|
||||
help='Logging path.', required=False)
|
||||
parser.add_argument('--ckpdir', default='ckpt/', type=str,
|
||||
help='Checkpoint path.', required=False)
|
||||
parser.add_argument('--outdir', default='result/', type=str,
|
||||
help='Decode output path.', required=False)
|
||||
parser.add_argument('--load', default=None, type=str,
|
||||
help='Load pre-trained model (for training only)', required=False)
|
||||
parser.add_argument('--warm_start', action='store_true',
|
||||
help='Load model weights only, ignore specified layers.')
|
||||
parser.add_argument('--seed', default=0, type=int,
|
||||
help='Random seed for reproducable results.', required=False)
|
||||
parser.add_argument('--njobs', default=8, type=int,
|
||||
help='Number of threads for dataloader/decoding.', required=False)
|
||||
parser.add_argument('--cpu', action='store_true', help='Disable GPU training.')
|
||||
parser.add_argument('--no-pin', action='store_true',
|
||||
help='Disable pin-memory for dataloader')
|
||||
parser.add_argument('--test', action='store_true', help='Test the model.')
|
||||
parser.add_argument('--no-msg', action='store_true', help='Hide all messages.')
|
||||
parser.add_argument('--finetune', action='store_true', help='Finetune model')
|
||||
parser.add_argument('--oneshotvc', action='store_true', help='Oneshot VC model')
|
||||
parser.add_argument('--bilstm', action='store_true', help='BiLSTM VC model')
|
||||
parser.add_argument('--lsa', action='store_true', help='Use location-sensitive attention (LSA)')
|
||||
|
||||
###
|
||||
|
||||
paras = parser.parse_args()
|
||||
setattr(paras, 'gpu', not paras.cpu)
|
||||
setattr(paras, 'pin_memory', not paras.no_pin)
|
||||
setattr(paras, 'verbose', not paras.no_msg)
|
||||
# Make the config dict dot visitable
|
||||
config = HpsYaml(paras.config)
|
||||
|
||||
np.random.seed(paras.seed)
|
||||
torch.manual_seed(paras.seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(paras.seed)
|
||||
|
||||
print(">>> OneShot VC training ...")
|
||||
mode = "train"
|
||||
solver = Solver(config, paras, mode)
|
||||
solver.load_data()
|
||||
solver.set_model()
|
||||
solver.exec()
|
||||
print(">>> Oneshot VC train finished!")
|
||||
sys.exit(0)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
1
ppg2mel/train/__init__.py
Normal file
1
ppg2mel/train/__init__.py
Normal file
|
@ -0,0 +1 @@
|
|||
#
|
50
ppg2mel/train/loss.py
Normal file
50
ppg2mel/train/loss.py
Normal file
|
@ -0,0 +1,50 @@
|
|||
from typing import Dict
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..utils.nets_utils import make_pad_mask
|
||||
|
||||
|
||||
class MaskedMSELoss(nn.Module):
|
||||
def __init__(self, frames_per_step):
|
||||
super().__init__()
|
||||
self.frames_per_step = frames_per_step
|
||||
self.mel_loss_criterion = nn.MSELoss(reduction='none')
|
||||
# self.loss = nn.MSELoss()
|
||||
self.stop_loss_criterion = nn.BCEWithLogitsLoss(reduction='none')
|
||||
|
||||
def get_mask(self, lengths, max_len=None):
|
||||
# lengths: [B,]
|
||||
if max_len is None:
|
||||
max_len = torch.max(lengths)
|
||||
batch_size = lengths.size(0)
|
||||
seq_range = torch.arange(0, max_len).long()
|
||||
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len).to(lengths.device)
|
||||
seq_length_expand = lengths.unsqueeze(1).expand_as(seq_range_expand)
|
||||
return (seq_range_expand < seq_length_expand).float()
|
||||
|
||||
def forward(self, mel_pred, mel_pred_postnet, mel_trg, lengths,
|
||||
stop_target, stop_pred):
|
||||
## process stop_target
|
||||
B = stop_target.size(0)
|
||||
stop_target = stop_target.reshape(B, -1, self.frames_per_step)[:, :, 0]
|
||||
stop_lengths = torch.ceil(lengths.float() / self.frames_per_step).long()
|
||||
stop_mask = self.get_mask(stop_lengths, int(mel_trg.size(1)/self.frames_per_step))
|
||||
|
||||
mel_trg.requires_grad = False
|
||||
# (B, T, 1)
|
||||
mel_mask = self.get_mask(lengths, mel_trg.size(1)).unsqueeze(-1)
|
||||
# (B, T, D)
|
||||
mel_mask = mel_mask.expand_as(mel_trg)
|
||||
mel_loss_pre = (self.mel_loss_criterion(mel_pred, mel_trg) * mel_mask).sum() / mel_mask.sum()
|
||||
mel_loss_post = (self.mel_loss_criterion(mel_pred_postnet, mel_trg) * mel_mask).sum() / mel_mask.sum()
|
||||
|
||||
mel_loss = mel_loss_pre + mel_loss_post
|
||||
|
||||
# stop token loss
|
||||
stop_loss = torch.sum(self.stop_loss_criterion(stop_pred, stop_target) * stop_mask) / stop_mask.sum()
|
||||
|
||||
return mel_loss, stop_loss
|
45
ppg2mel/train/optim.py
Normal file
45
ppg2mel/train/optim.py
Normal file
|
@ -0,0 +1,45 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
class Optimizer():
|
||||
def __init__(self, parameters, optimizer, lr, eps, lr_scheduler,
|
||||
**kwargs):
|
||||
|
||||
# Setup torch optimizer
|
||||
self.opt_type = optimizer
|
||||
self.init_lr = lr
|
||||
self.sch_type = lr_scheduler
|
||||
opt = getattr(torch.optim, optimizer)
|
||||
if lr_scheduler == 'warmup':
|
||||
warmup_step = 4000.0
|
||||
init_lr = lr
|
||||
self.lr_scheduler = lambda step: init_lr * warmup_step ** 0.5 * \
|
||||
np.minimum((step+1)*warmup_step**-1.5, (step+1)**-0.5)
|
||||
self.opt = opt(parameters, lr=1.0)
|
||||
else:
|
||||
self.lr_scheduler = None
|
||||
self.opt = opt(parameters, lr=lr, eps=eps) # ToDo: 1e-8 better?
|
||||
|
||||
def get_opt_state_dict(self):
|
||||
return self.opt.state_dict()
|
||||
|
||||
def load_opt_state_dict(self, state_dict):
|
||||
self.opt.load_state_dict(state_dict)
|
||||
|
||||
def pre_step(self, step):
|
||||
if self.lr_scheduler is not None:
|
||||
cur_lr = self.lr_scheduler(step)
|
||||
for param_group in self.opt.param_groups:
|
||||
param_group['lr'] = cur_lr
|
||||
else:
|
||||
cur_lr = self.init_lr
|
||||
self.opt.zero_grad()
|
||||
return cur_lr
|
||||
|
||||
def step(self):
|
||||
self.opt.step()
|
||||
|
||||
def create_msg(self):
|
||||
return ['Optim.Info.| Algo. = {}\t| Lr = {}\t (schedule = {})'
|
||||
.format(self.opt_type, self.init_lr, self.sch_type)]
|
10
ppg2mel/train/option.py
Normal file
10
ppg2mel/train/option.py
Normal file
|
@ -0,0 +1,10 @@
|
|||
# Default parameters which will be imported by solver
|
||||
default_hparas = {
|
||||
'GRAD_CLIP': 5.0, # Grad. clip threshold
|
||||
'PROGRESS_STEP': 100, # Std. output refresh freq.
|
||||
# Decode steps for objective validation (step = ratio*input_txt_len)
|
||||
'DEV_STEP_RATIO': 1.2,
|
||||
# Number of examples (alignment/text) to show in tensorboard
|
||||
'DEV_N_EXAMPLE': 4,
|
||||
'TB_FLUSH_FREQ': 180 # Update frequency of tensorboard (secs)
|
||||
}
|
216
ppg2mel/train/solver.py
Normal file
216
ppg2mel/train/solver.py
Normal file
|
@ -0,0 +1,216 @@
|
|||
import os
|
||||
import sys
|
||||
import abc
|
||||
import math
|
||||
import yaml
|
||||
import torch
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
|
||||
from .option import default_hparas
|
||||
from utils.util import human_format, Timer
|
||||
from utils.load_yaml import HpsYaml
|
||||
|
||||
|
||||
class BaseSolver():
|
||||
'''
|
||||
Prototype Solver for all kinds of tasks
|
||||
Arguments
|
||||
config - yaml-styled config
|
||||
paras - argparse outcome
|
||||
mode - "train"/"test"
|
||||
'''
|
||||
|
||||
def __init__(self, config, paras, mode="train"):
|
||||
# General Settings
|
||||
self.config = config # load from yaml file
|
||||
self.paras = paras # command line args
|
||||
self.mode = mode # 'train' or 'test'
|
||||
for k, v in default_hparas.items():
|
||||
setattr(self, k, v)
|
||||
self.device = torch.device('cuda') if self.paras.gpu and torch.cuda.is_available() \
|
||||
else torch.device('cpu')
|
||||
|
||||
# Name experiment
|
||||
self.exp_name = paras.name
|
||||
if self.exp_name is None:
|
||||
if 'exp_name' in self.config:
|
||||
self.exp_name = self.config.exp_name
|
||||
else:
|
||||
# By default, exp is named after config file
|
||||
self.exp_name = paras.config.split('/')[-1].replace('.yaml', '')
|
||||
if mode == 'train':
|
||||
self.exp_name += '_seed{}'.format(paras.seed)
|
||||
|
||||
|
||||
if mode == 'train':
|
||||
# Filepath setup
|
||||
os.makedirs(paras.ckpdir, exist_ok=True)
|
||||
self.ckpdir = os.path.join(paras.ckpdir, self.exp_name)
|
||||
os.makedirs(self.ckpdir, exist_ok=True)
|
||||
|
||||
# Logger settings
|
||||
self.logdir = os.path.join(paras.logdir, self.exp_name)
|
||||
self.log = SummaryWriter(
|
||||
self.logdir, flush_secs=self.TB_FLUSH_FREQ)
|
||||
self.timer = Timer()
|
||||
|
||||
# Hyper-parameters
|
||||
self.step = 0
|
||||
self.valid_step = config.hparas.valid_step
|
||||
self.max_step = config.hparas.max_step
|
||||
|
||||
self.verbose('Exp. name : {}'.format(self.exp_name))
|
||||
self.verbose('Loading data... large corpus may took a while.')
|
||||
|
||||
# elif mode == 'test':
|
||||
# # Output path
|
||||
# os.makedirs(paras.outdir, exist_ok=True)
|
||||
# self.ckpdir = os.path.join(paras.outdir, self.exp_name)
|
||||
|
||||
# Load training config to get acoustic feat and build model
|
||||
# self.src_config = HpsYaml(config.src.config)
|
||||
# self.paras.load = config.src.ckpt
|
||||
|
||||
# self.verbose('Evaluating result of tr. config @ {}'.format(
|
||||
# config.src.config))
|
||||
|
||||
def backward(self, loss):
|
||||
'''
|
||||
Standard backward step with self.timer and debugger
|
||||
Arguments
|
||||
loss - the loss to perform loss.backward()
|
||||
'''
|
||||
self.timer.set()
|
||||
loss.backward()
|
||||
grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
self.model.parameters(), self.GRAD_CLIP)
|
||||
if math.isnan(grad_norm):
|
||||
self.verbose('Error : grad norm is NaN @ step '+str(self.step))
|
||||
else:
|
||||
self.optimizer.step()
|
||||
self.timer.cnt('bw')
|
||||
return grad_norm
|
||||
|
||||
def load_ckpt(self):
|
||||
''' Load ckpt if --load option is specified '''
|
||||
if self.paras.load is not None:
|
||||
if self.paras.warm_start:
|
||||
self.verbose(f"Warm starting model from checkpoint {self.paras.load}.")
|
||||
ckpt = torch.load(
|
||||
self.paras.load, map_location=self.device if self.mode == 'train'
|
||||
else 'cpu')
|
||||
model_dict = ckpt['model']
|
||||
if len(self.config.model.ignore_layers) > 0:
|
||||
model_dict = {k:v for k, v in model_dict.items()
|
||||
if k not in self.config.model.ignore_layers}
|
||||
dummy_dict = self.model.state_dict()
|
||||
dummy_dict.update(model_dict)
|
||||
model_dict = dummy_dict
|
||||
self.model.load_state_dict(model_dict)
|
||||
else:
|
||||
# Load weights
|
||||
ckpt = torch.load(
|
||||
self.paras.load, map_location=self.device if self.mode == 'train'
|
||||
else 'cpu')
|
||||
self.model.load_state_dict(ckpt['model'])
|
||||
|
||||
# Load task-dependent items
|
||||
if self.mode == 'train':
|
||||
self.step = ckpt['global_step']
|
||||
self.optimizer.load_opt_state_dict(ckpt['optimizer'])
|
||||
self.verbose('Load ckpt from {}, restarting at step {}'.format(
|
||||
self.paras.load, self.step))
|
||||
else:
|
||||
for k, v in ckpt.items():
|
||||
if type(v) is float:
|
||||
metric, score = k, v
|
||||
self.model.eval()
|
||||
self.verbose('Evaluation target = {} (recorded {} = {:.2f} %)'.format(
|
||||
self.paras.load, metric, score))
|
||||
|
||||
def verbose(self, msg):
|
||||
''' Verbose function for print information to stdout'''
|
||||
if self.paras.verbose:
|
||||
if type(msg) == list:
|
||||
for m in msg:
|
||||
print('[INFO]', m.ljust(100))
|
||||
else:
|
||||
print('[INFO]', msg.ljust(100))
|
||||
|
||||
def progress(self, msg):
|
||||
''' Verbose function for updating progress on stdout (do not include newline) '''
|
||||
if self.paras.verbose:
|
||||
sys.stdout.write("\033[K") # Clear line
|
||||
print('[{}] {}'.format(human_format(self.step), msg), end='\r')
|
||||
|
||||
def write_log(self, log_name, log_dict):
|
||||
'''
|
||||
Write log to TensorBoard
|
||||
log_name - <str> Name of tensorboard variable
|
||||
log_value - <dict>/<array> Value of variable (e.g. dict of losses), passed if value = None
|
||||
'''
|
||||
if type(log_dict) is dict:
|
||||
log_dict = {key: val for key, val in log_dict.items() if (
|
||||
val is not None and not math.isnan(val))}
|
||||
if log_dict is None:
|
||||
pass
|
||||
elif len(log_dict) > 0:
|
||||
if 'align' in log_name or 'spec' in log_name:
|
||||
img, form = log_dict
|
||||
self.log.add_image(
|
||||
log_name, img, global_step=self.step, dataformats=form)
|
||||
elif 'text' in log_name or 'hyp' in log_name:
|
||||
self.log.add_text(log_name, log_dict, self.step)
|
||||
else:
|
||||
self.log.add_scalars(log_name, log_dict, self.step)
|
||||
|
||||
def save_checkpoint(self, f_name, metric, score, show_msg=True):
|
||||
''''
|
||||
Ckpt saver
|
||||
f_name - <str> the name of ckpt file (w/o prefix) to store, overwrite if existed
|
||||
score - <float> The value of metric used to evaluate model
|
||||
'''
|
||||
ckpt_path = os.path.join(self.ckpdir, f_name)
|
||||
full_dict = {
|
||||
"model": self.model.state_dict(),
|
||||
"optimizer": self.optimizer.get_opt_state_dict(),
|
||||
"global_step": self.step,
|
||||
metric: score
|
||||
}
|
||||
|
||||
torch.save(full_dict, ckpt_path)
|
||||
if show_msg:
|
||||
self.verbose("Saved checkpoint (step = {}, {} = {:.2f}) and status @ {}".
|
||||
format(human_format(self.step), metric, score, ckpt_path))
|
||||
|
||||
|
||||
# ----------------------------------- Abtract Methods ------------------------------------------ #
|
||||
@abc.abstractmethod
|
||||
def load_data(self):
|
||||
'''
|
||||
Called by main to load all data
|
||||
After this call, data related attributes should be setup (e.g. self.tr_set, self.dev_set)
|
||||
No return value
|
||||
'''
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def set_model(self):
|
||||
'''
|
||||
Called by main to set models
|
||||
After this call, model related attributes should be setup (e.g. self.l2_loss)
|
||||
The followings MUST be setup
|
||||
- self.model (torch.nn.Module)
|
||||
- self.optimizer (src.Optimizer),
|
||||
init. w/ self.optimizer = src.Optimizer(self.model.parameters(),**self.config['hparas'])
|
||||
Loading pre-trained model should also be performed here
|
||||
No return value
|
||||
'''
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def exec(self):
|
||||
'''
|
||||
Called by main to execute training/inference
|
||||
'''
|
||||
raise NotImplementedError
|
288
ppg2mel/train/train_linglf02mel_seq2seq_oneshotvc.py
Normal file
288
ppg2mel/train/train_linglf02mel_seq2seq_oneshotvc.py
Normal file
|
@ -0,0 +1,288 @@
|
|||
import os, sys
|
||||
# sys.path.append('/home/shaunxliu/projects/nnsp')
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.ticker import MaxNLocator
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
import numpy as np
|
||||
from .solver import BaseSolver
|
||||
from utils.data_load import OneshotVcDataset, MultiSpkVcCollate
|
||||
# from src.rnn_ppg2mel import BiRnnPpg2MelModel
|
||||
# from src.mel_decoder_mol_encAddlf0 import MelDecoderMOL
|
||||
from .loss import MaskedMSELoss
|
||||
from .optim import Optimizer
|
||||
from utils.util import human_format
|
||||
from ppg2mel import MelDecoderMOLv2
|
||||
|
||||
|
||||
class Solver(BaseSolver):
|
||||
"""Customized Solver."""
|
||||
def __init__(self, config, paras, mode):
|
||||
super().__init__(config, paras, mode)
|
||||
self.num_att_plots = 5
|
||||
self.att_ws_dir = f"{self.logdir}/att_ws"
|
||||
os.makedirs(self.att_ws_dir, exist_ok=True)
|
||||
self.best_loss = np.inf
|
||||
|
||||
def fetch_data(self, data):
|
||||
"""Move data to device"""
|
||||
data = [i.to(self.device) for i in data]
|
||||
return data
|
||||
|
||||
def load_data(self):
|
||||
""" Load data for training/validation/plotting."""
|
||||
train_dataset = OneshotVcDataset(
|
||||
meta_file=self.config.data.train_fid_list,
|
||||
vctk_ppg_dir=self.config.data.vctk_ppg_dir,
|
||||
libri_ppg_dir=self.config.data.libri_ppg_dir,
|
||||
vctk_f0_dir=self.config.data.vctk_f0_dir,
|
||||
libri_f0_dir=self.config.data.libri_f0_dir,
|
||||
vctk_wav_dir=self.config.data.vctk_wav_dir,
|
||||
libri_wav_dir=self.config.data.libri_wav_dir,
|
||||
vctk_spk_dvec_dir=self.config.data.vctk_spk_dvec_dir,
|
||||
libri_spk_dvec_dir=self.config.data.libri_spk_dvec_dir,
|
||||
ppg_file_ext=self.config.data.ppg_file_ext,
|
||||
min_max_norm_mel=self.config.data.min_max_norm_mel,
|
||||
mel_min=self.config.data.mel_min,
|
||||
mel_max=self.config.data.mel_max,
|
||||
)
|
||||
dev_dataset = OneshotVcDataset(
|
||||
meta_file=self.config.data.dev_fid_list,
|
||||
vctk_ppg_dir=self.config.data.vctk_ppg_dir,
|
||||
libri_ppg_dir=self.config.data.libri_ppg_dir,
|
||||
vctk_f0_dir=self.config.data.vctk_f0_dir,
|
||||
libri_f0_dir=self.config.data.libri_f0_dir,
|
||||
vctk_wav_dir=self.config.data.vctk_wav_dir,
|
||||
libri_wav_dir=self.config.data.libri_wav_dir,
|
||||
vctk_spk_dvec_dir=self.config.data.vctk_spk_dvec_dir,
|
||||
libri_spk_dvec_dir=self.config.data.libri_spk_dvec_dir,
|
||||
ppg_file_ext=self.config.data.ppg_file_ext,
|
||||
min_max_norm_mel=self.config.data.min_max_norm_mel,
|
||||
mel_min=self.config.data.mel_min,
|
||||
mel_max=self.config.data.mel_max,
|
||||
)
|
||||
self.train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
num_workers=self.paras.njobs,
|
||||
shuffle=True,
|
||||
batch_size=self.config.hparas.batch_size,
|
||||
pin_memory=False,
|
||||
drop_last=True,
|
||||
collate_fn=MultiSpkVcCollate(self.config.model.frames_per_step,
|
||||
use_spk_dvec=True),
|
||||
)
|
||||
self.dev_dataloader = DataLoader(
|
||||
dev_dataset,
|
||||
num_workers=self.paras.njobs,
|
||||
shuffle=False,
|
||||
batch_size=self.config.hparas.batch_size,
|
||||
pin_memory=False,
|
||||
drop_last=False,
|
||||
collate_fn=MultiSpkVcCollate(self.config.model.frames_per_step,
|
||||
use_spk_dvec=True),
|
||||
)
|
||||
self.plot_dataloader = DataLoader(
|
||||
dev_dataset,
|
||||
num_workers=self.paras.njobs,
|
||||
shuffle=False,
|
||||
batch_size=1,
|
||||
pin_memory=False,
|
||||
drop_last=False,
|
||||
collate_fn=MultiSpkVcCollate(self.config.model.frames_per_step,
|
||||
use_spk_dvec=True,
|
||||
give_uttids=True),
|
||||
)
|
||||
msg = "Have prepared training set and dev set."
|
||||
self.verbose(msg)
|
||||
|
||||
def load_pretrained_params(self):
|
||||
print("Load pretrained model from: ", self.config.data.pretrain_model_file)
|
||||
ignore_layer_prefixes = ["speaker_embedding_table"]
|
||||
pretrain_model_file = self.config.data.pretrain_model_file
|
||||
pretrain_ckpt = torch.load(
|
||||
pretrain_model_file, map_location=self.device
|
||||
)["model"]
|
||||
model_dict = self.model.state_dict()
|
||||
print(self.model)
|
||||
|
||||
# 1. filter out unnecessrary keys
|
||||
for prefix in ignore_layer_prefixes:
|
||||
pretrain_ckpt = {k : v
|
||||
for k, v in pretrain_ckpt.items() if not k.startswith(prefix)
|
||||
}
|
||||
# 2. overwrite entries in the existing state dict
|
||||
model_dict.update(pretrain_ckpt)
|
||||
|
||||
# 3. load the new state dict
|
||||
self.model.load_state_dict(model_dict)
|
||||
|
||||
def set_model(self):
|
||||
"""Setup model and optimizer"""
|
||||
# Model
|
||||
print("[INFO] Model name: ", self.config["model_name"])
|
||||
self.model = MelDecoderMOLv2(
|
||||
**self.config["model"]
|
||||
).to(self.device)
|
||||
# self.load_pretrained_params()
|
||||
|
||||
# model_params = [{'params': self.model.spk_embedding.weight}]
|
||||
model_params = [{'params': self.model.parameters()}]
|
||||
|
||||
# Loss criterion
|
||||
self.loss_criterion = MaskedMSELoss(self.config.model.frames_per_step)
|
||||
|
||||
# Optimizer
|
||||
self.optimizer = Optimizer(model_params, **self.config["hparas"])
|
||||
self.verbose(self.optimizer.create_msg())
|
||||
|
||||
# Automatically load pre-trained model if self.paras.load is given
|
||||
self.load_ckpt()
|
||||
|
||||
def exec(self):
|
||||
self.verbose("Total training steps {}.".format(
|
||||
human_format(self.max_step)))
|
||||
|
||||
mel_loss = None
|
||||
n_epochs = 0
|
||||
# Set as current time
|
||||
self.timer.set()
|
||||
|
||||
while self.step < self.max_step:
|
||||
for data in self.train_dataloader:
|
||||
# Pre-step: updata lr_rate and do zero_grad
|
||||
lr_rate = self.optimizer.pre_step(self.step)
|
||||
total_loss = 0
|
||||
# data to device
|
||||
ppgs, lf0_uvs, mels, in_lengths, \
|
||||
out_lengths, spk_ids, stop_tokens = self.fetch_data(data)
|
||||
self.timer.cnt("rd")
|
||||
mel_outputs, mel_outputs_postnet, predicted_stop = self.model(
|
||||
ppgs,
|
||||
in_lengths,
|
||||
mels,
|
||||
out_lengths,
|
||||
lf0_uvs,
|
||||
spk_ids
|
||||
)
|
||||
mel_loss, stop_loss = self.loss_criterion(
|
||||
mel_outputs,
|
||||
mel_outputs_postnet,
|
||||
mels,
|
||||
out_lengths,
|
||||
stop_tokens,
|
||||
predicted_stop
|
||||
)
|
||||
loss = mel_loss + stop_loss
|
||||
|
||||
self.timer.cnt("fw")
|
||||
|
||||
# Back-prop
|
||||
grad_norm = self.backward(loss)
|
||||
self.step += 1
|
||||
|
||||
# Logger
|
||||
if (self.step == 1) or (self.step % self.PROGRESS_STEP == 0):
|
||||
self.progress("Tr|loss:{:.4f},mel-loss:{:.4f},stop-loss:{:.4f}|Grad.Norm-{:.2f}|{}"
|
||||
.format(loss.cpu().item(), mel_loss.cpu().item(),
|
||||
stop_loss.cpu().item(), grad_norm, self.timer.show()))
|
||||
self.write_log('loss', {'tr/loss': loss,
|
||||
'tr/mel-loss': mel_loss,
|
||||
'tr/stop-loss': stop_loss})
|
||||
|
||||
# Validation
|
||||
if (self.step == 1) or (self.step % self.valid_step == 0):
|
||||
self.validate()
|
||||
|
||||
# End of step
|
||||
# https://github.com/pytorch/pytorch/issues/13246#issuecomment-529185354
|
||||
torch.cuda.empty_cache()
|
||||
self.timer.set()
|
||||
if self.step > self.max_step:
|
||||
break
|
||||
n_epochs += 1
|
||||
self.log.close()
|
||||
|
||||
def validate(self):
|
||||
self.model.eval()
|
||||
dev_loss, dev_mel_loss, dev_stop_loss = 0.0, 0.0, 0.0
|
||||
|
||||
for i, data in enumerate(self.dev_dataloader):
|
||||
self.progress('Valid step - {}/{}'.format(i+1, len(self.dev_dataloader)))
|
||||
# Fetch data
|
||||
ppgs, lf0_uvs, mels, in_lengths, \
|
||||
out_lengths, spk_ids, stop_tokens = self.fetch_data(data)
|
||||
with torch.no_grad():
|
||||
mel_outputs, mel_outputs_postnet, predicted_stop = self.model(
|
||||
ppgs,
|
||||
in_lengths,
|
||||
mels,
|
||||
out_lengths,
|
||||
lf0_uvs,
|
||||
spk_ids
|
||||
)
|
||||
mel_loss, stop_loss = self.loss_criterion(
|
||||
mel_outputs,
|
||||
mel_outputs_postnet,
|
||||
mels,
|
||||
out_lengths,
|
||||
stop_tokens,
|
||||
predicted_stop
|
||||
)
|
||||
loss = mel_loss + stop_loss
|
||||
|
||||
dev_loss += loss.cpu().item()
|
||||
dev_mel_loss += mel_loss.cpu().item()
|
||||
dev_stop_loss += stop_loss.cpu().item()
|
||||
|
||||
dev_loss = dev_loss / (i + 1)
|
||||
dev_mel_loss = dev_mel_loss / (i + 1)
|
||||
dev_stop_loss = dev_stop_loss / (i + 1)
|
||||
self.save_checkpoint(f'step_{self.step}.pth', 'loss', dev_loss, show_msg=False)
|
||||
if dev_loss < self.best_loss:
|
||||
self.best_loss = dev_loss
|
||||
self.save_checkpoint(f'best_loss_step_{self.step}.pth', 'loss', dev_loss)
|
||||
self.write_log('loss', {'dv/loss': dev_loss,
|
||||
'dv/mel-loss': dev_mel_loss,
|
||||
'dv/stop-loss': dev_stop_loss})
|
||||
|
||||
# plot attention
|
||||
for i, data in enumerate(self.plot_dataloader):
|
||||
if i == self.num_att_plots:
|
||||
break
|
||||
# Fetch data
|
||||
ppgs, lf0_uvs, mels, in_lengths, \
|
||||
out_lengths, spk_ids, stop_tokens = self.fetch_data(data[:-1])
|
||||
fid = data[-1][0]
|
||||
with torch.no_grad():
|
||||
_, _, _, att_ws = self.model(
|
||||
ppgs,
|
||||
in_lengths,
|
||||
mels,
|
||||
out_lengths,
|
||||
lf0_uvs,
|
||||
spk_ids,
|
||||
output_att_ws=True
|
||||
)
|
||||
att_ws = att_ws.squeeze(0).cpu().numpy()
|
||||
att_ws = att_ws[None]
|
||||
w, h = plt.figaspect(1.0 / len(att_ws))
|
||||
fig = plt.Figure(figsize=(w * 1.3, h * 1.3))
|
||||
axes = fig.subplots(1, len(att_ws))
|
||||
if len(att_ws) == 1:
|
||||
axes = [axes]
|
||||
|
||||
for ax, aw in zip(axes, att_ws):
|
||||
ax.imshow(aw.astype(np.float32), aspect="auto")
|
||||
ax.set_title(f"{fid}")
|
||||
ax.set_xlabel("Input")
|
||||
ax.set_ylabel("Output")
|
||||
ax.xaxis.set_major_locator(MaxNLocator(integer=True))
|
||||
ax.yaxis.set_major_locator(MaxNLocator(integer=True))
|
||||
fig_name = f"{self.att_ws_dir}/{fid}_step{self.step}.png"
|
||||
fig.savefig(fig_name)
|
||||
|
||||
# Resume training
|
||||
self.model.train()
|
||||
|
23
ppg2mel/utils/abs_model.py
Normal file
23
ppg2mel/utils/abs_model.py
Normal file
|
@ -0,0 +1,23 @@
|
|||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
|
||||
import torch
|
||||
|
||||
class AbsMelDecoder(torch.nn.Module, ABC):
|
||||
"""The abstract PPG-based voice conversion class
|
||||
This "model" is one of mediator objects for "Task" class.
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def forward(
|
||||
self,
|
||||
bottle_neck_features: torch.Tensor,
|
||||
feature_lengths: torch.Tensor,
|
||||
speech: torch.Tensor,
|
||||
speech_lengths: torch.Tensor,
|
||||
logf0_uv: torch.Tensor = None,
|
||||
spembs: torch.Tensor = None,
|
||||
styleembs: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
79
ppg2mel/utils/basic_layers.py
Normal file
79
ppg2mel/utils/basic_layers.py
Normal file
|
@ -0,0 +1,79 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch.autograd import Function
|
||||
|
||||
def tile(x, count, dim=0):
|
||||
"""
|
||||
Tiles x on dimension dim count times.
|
||||
"""
|
||||
perm = list(range(len(x.size())))
|
||||
if dim != 0:
|
||||
perm[0], perm[dim] = perm[dim], perm[0]
|
||||
x = x.permute(perm).contiguous()
|
||||
out_size = list(x.size())
|
||||
out_size[0] *= count
|
||||
batch = x.size(0)
|
||||
x = x.view(batch, -1) \
|
||||
.transpose(0, 1) \
|
||||
.repeat(count, 1) \
|
||||
.transpose(0, 1) \
|
||||
.contiguous() \
|
||||
.view(*out_size)
|
||||
if dim != 0:
|
||||
x = x.permute(perm).contiguous()
|
||||
return x
|
||||
|
||||
class Linear(torch.nn.Module):
|
||||
def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
|
||||
super(Linear, self).__init__()
|
||||
self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
|
||||
|
||||
torch.nn.init.xavier_uniform_(
|
||||
self.linear_layer.weight,
|
||||
gain=torch.nn.init.calculate_gain(w_init_gain))
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear_layer(x)
|
||||
|
||||
class Conv1d(torch.nn.Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
|
||||
padding=None, dilation=1, bias=True, w_init_gain='linear', param=None):
|
||||
super(Conv1d, self).__init__()
|
||||
if padding is None:
|
||||
assert(kernel_size % 2 == 1)
|
||||
padding = int(dilation * (kernel_size - 1)/2)
|
||||
|
||||
self.conv = torch.nn.Conv1d(in_channels, out_channels,
|
||||
kernel_size=kernel_size, stride=stride,
|
||||
padding=padding, dilation=dilation,
|
||||
bias=bias)
|
||||
torch.nn.init.xavier_uniform_(
|
||||
self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param))
|
||||
|
||||
def forward(self, x):
|
||||
# x: BxDxT
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
|
||||
def tile(x, count, dim=0):
|
||||
"""
|
||||
Tiles x on dimension dim count times.
|
||||
"""
|
||||
perm = list(range(len(x.size())))
|
||||
if dim != 0:
|
||||
perm[0], perm[dim] = perm[dim], perm[0]
|
||||
x = x.permute(perm).contiguous()
|
||||
out_size = list(x.size())
|
||||
out_size[0] *= count
|
||||
batch = x.size(0)
|
||||
x = x.view(batch, -1) \
|
||||
.transpose(0, 1) \
|
||||
.repeat(count, 1) \
|
||||
.transpose(0, 1) \
|
||||
.contiguous() \
|
||||
.view(*out_size)
|
||||
if dim != 0:
|
||||
x = x.permute(perm).contiguous()
|
||||
return x
|
52
ppg2mel/utils/cnn_postnet.py
Normal file
52
ppg2mel/utils/cnn_postnet.py
Normal file
|
@ -0,0 +1,52 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from .basic_layers import Linear, Conv1d
|
||||
|
||||
|
||||
class Postnet(nn.Module):
|
||||
"""Postnet
|
||||
- Five 1-d convolution with 512 channels and kernel size 5
|
||||
"""
|
||||
def __init__(self, num_mels=80,
|
||||
num_layers=5,
|
||||
hidden_dim=512,
|
||||
kernel_size=5):
|
||||
super(Postnet, self).__init__()
|
||||
self.convolutions = nn.ModuleList()
|
||||
|
||||
self.convolutions.append(
|
||||
nn.Sequential(
|
||||
Conv1d(
|
||||
num_mels, hidden_dim,
|
||||
kernel_size=kernel_size, stride=1,
|
||||
padding=int((kernel_size - 1) / 2),
|
||||
dilation=1, w_init_gain='tanh'),
|
||||
nn.BatchNorm1d(hidden_dim)))
|
||||
|
||||
for i in range(1, num_layers - 1):
|
||||
self.convolutions.append(
|
||||
nn.Sequential(
|
||||
Conv1d(
|
||||
hidden_dim,
|
||||
hidden_dim,
|
||||
kernel_size=kernel_size, stride=1,
|
||||
padding=int((kernel_size - 1) / 2),
|
||||
dilation=1, w_init_gain='tanh'),
|
||||
nn.BatchNorm1d(hidden_dim)))
|
||||
|
||||
self.convolutions.append(
|
||||
nn.Sequential(
|
||||
Conv1d(
|
||||
hidden_dim, num_mels,
|
||||
kernel_size=kernel_size, stride=1,
|
||||
padding=int((kernel_size - 1) / 2),
|
||||
dilation=1, w_init_gain='linear'),
|
||||
nn.BatchNorm1d(num_mels)))
|
||||
|
||||
def forward(self, x):
|
||||
# x: (B, num_mels, T_dec)
|
||||
for i in range(len(self.convolutions) - 1):
|
||||
x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training)
|
||||
x = F.dropout(self.convolutions[-1](x), 0.5, self.training)
|
||||
return x
|
123
ppg2mel/utils/mol_attention.py
Normal file
123
ppg2mel/utils/mol_attention.py
Normal file
|
@ -0,0 +1,123 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class MOLAttention(nn.Module):
|
||||
""" Discretized Mixture of Logistic (MOL) attention.
|
||||
C.f. Section 5 of "MelNet: A Generative Model for Audio in the Frequency Domain" and
|
||||
GMMv2b model in "Location-relative attention mechanisms for robust long-form speech synthesis".
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
query_dim,
|
||||
r=1,
|
||||
M=5,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
query_dim: attention_rnn_dim.
|
||||
M: number of mixtures.
|
||||
"""
|
||||
super().__init__()
|
||||
if r < 1:
|
||||
self.r = float(r)
|
||||
else:
|
||||
self.r = int(r)
|
||||
self.M = M
|
||||
self.score_mask_value = 0.0 # -float("inf")
|
||||
self.eps = 1e-5
|
||||
# Position arrary for encoder time steps
|
||||
self.J = None
|
||||
# Query layer: [w, sigma,]
|
||||
self.query_layer = torch.nn.Sequential(
|
||||
nn.Linear(query_dim, 256, bias=True),
|
||||
nn.ReLU(),
|
||||
nn.Linear(256, 3*M, bias=True)
|
||||
)
|
||||
self.mu_prev = None
|
||||
self.initialize_bias()
|
||||
|
||||
def initialize_bias(self):
|
||||
"""Initialize sigma and Delta."""
|
||||
# sigma
|
||||
torch.nn.init.constant_(self.query_layer[2].bias[self.M:2*self.M], 1.0)
|
||||
# Delta: softplus(1.8545) = 2.0; softplus(3.9815) = 4.0; softplus(0.5413) = 1.0
|
||||
# softplus(-0.432) = 0.5003
|
||||
if self.r == 2:
|
||||
torch.nn.init.constant_(self.query_layer[2].bias[2*self.M:3*self.M], 1.8545)
|
||||
elif self.r == 4:
|
||||
torch.nn.init.constant_(self.query_layer[2].bias[2*self.M:3*self.M], 3.9815)
|
||||
elif self.r == 1:
|
||||
torch.nn.init.constant_(self.query_layer[2].bias[2*self.M:3*self.M], 0.5413)
|
||||
else:
|
||||
torch.nn.init.constant_(self.query_layer[2].bias[2*self.M:3*self.M], -0.432)
|
||||
|
||||
|
||||
def init_states(self, memory):
|
||||
"""Initialize mu_prev and J.
|
||||
This function should be called by the decoder before decoding one batch.
|
||||
Args:
|
||||
memory: (B, T, D_enc) encoder output.
|
||||
"""
|
||||
B, T_enc, _ = memory.size()
|
||||
device = memory.device
|
||||
self.J = torch.arange(0, T_enc + 2.0).to(device) + 0.5 # NOTE: for discretize usage
|
||||
# self.J = memory.new_tensor(np.arange(T_enc), dtype=torch.float)
|
||||
self.mu_prev = torch.zeros(B, self.M).to(device)
|
||||
|
||||
def forward(self, att_rnn_h, memory, memory_pitch=None, mask=None):
|
||||
"""
|
||||
att_rnn_h: attetion rnn hidden state.
|
||||
memory: encoder outputs (B, T_enc, D).
|
||||
mask: binary mask for padded data (B, T_enc).
|
||||
"""
|
||||
# [B, 3M]
|
||||
mixture_params = self.query_layer(att_rnn_h)
|
||||
|
||||
# [B, M]
|
||||
w_hat = mixture_params[:, :self.M]
|
||||
sigma_hat = mixture_params[:, self.M:2*self.M]
|
||||
Delta_hat = mixture_params[:, 2*self.M:3*self.M]
|
||||
|
||||
# print("w_hat: ", w_hat)
|
||||
# print("sigma_hat: ", sigma_hat)
|
||||
# print("Delta_hat: ", Delta_hat)
|
||||
|
||||
# Dropout to de-correlate attention heads
|
||||
w_hat = F.dropout(w_hat, p=0.5, training=self.training) # NOTE(sx): needed?
|
||||
|
||||
# Mixture parameters
|
||||
w = torch.softmax(w_hat, dim=-1) + self.eps
|
||||
sigma = F.softplus(sigma_hat) + self.eps
|
||||
Delta = F.softplus(Delta_hat)
|
||||
mu_cur = self.mu_prev + Delta
|
||||
# print("w:", w)
|
||||
j = self.J[:memory.size(1) + 1]
|
||||
|
||||
# Attention weights
|
||||
# CDF of logistic distribution
|
||||
phi_t = w.unsqueeze(-1) * (1 / (1 + torch.sigmoid(
|
||||
(mu_cur.unsqueeze(-1) - j) / sigma.unsqueeze(-1))))
|
||||
# print("phi_t:", phi_t)
|
||||
|
||||
# Discretize attention weights
|
||||
# (B, T_enc + 1)
|
||||
alpha_t = torch.sum(phi_t, dim=1)
|
||||
alpha_t = alpha_t[:, 1:] - alpha_t[:, :-1]
|
||||
alpha_t[alpha_t == 0] = self.eps
|
||||
# print("alpha_t: ", alpha_t.size())
|
||||
# Apply masking
|
||||
if mask is not None:
|
||||
alpha_t.data.masked_fill_(mask, self.score_mask_value)
|
||||
|
||||
context = torch.bmm(alpha_t.unsqueeze(1), memory).squeeze(1)
|
||||
if memory_pitch is not None:
|
||||
context_pitch = torch.bmm(alpha_t.unsqueeze(1), memory_pitch).squeeze(1)
|
||||
|
||||
self.mu_prev = mu_cur
|
||||
|
||||
if memory_pitch is not None:
|
||||
return context, context_pitch, alpha_t
|
||||
return context, alpha_t
|
||||
|
451
ppg2mel/utils/nets_utils.py
Normal file
451
ppg2mel/utils/nets_utils.py
Normal file
|
@ -0,0 +1,451 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""Network related utility tools."""
|
||||
|
||||
import logging
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def to_device(m, x):
|
||||
"""Send tensor into the device of the module.
|
||||
|
||||
Args:
|
||||
m (torch.nn.Module): Torch module.
|
||||
x (Tensor): Torch tensor.
|
||||
|
||||
Returns:
|
||||
Tensor: Torch tensor located in the same place as torch module.
|
||||
|
||||
"""
|
||||
assert isinstance(m, torch.nn.Module)
|
||||
device = next(m.parameters()).device
|
||||
return x.to(device)
|
||||
|
||||
|
||||
def pad_list(xs, pad_value):
|
||||
"""Perform padding for the list of tensors.
|
||||
|
||||
Args:
|
||||
xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
|
||||
pad_value (float): Value for padding.
|
||||
|
||||
Returns:
|
||||
Tensor: Padded tensor (B, Tmax, `*`).
|
||||
|
||||
Examples:
|
||||
>>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
|
||||
>>> x
|
||||
[tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
|
||||
>>> pad_list(x, 0)
|
||||
tensor([[1., 1., 1., 1.],
|
||||
[1., 1., 0., 0.],
|
||||
[1., 0., 0., 0.]])
|
||||
|
||||
"""
|
||||
n_batch = len(xs)
|
||||
max_len = max(x.size(0) for x in xs)
|
||||
pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
|
||||
|
||||
for i in range(n_batch):
|
||||
pad[i, :xs[i].size(0)] = xs[i]
|
||||
|
||||
return pad
|
||||
|
||||
|
||||
def make_pad_mask(lengths, xs=None, length_dim=-1):
|
||||
"""Make mask tensor containing indices of padded part.
|
||||
|
||||
Args:
|
||||
lengths (LongTensor or List): Batch of lengths (B,).
|
||||
xs (Tensor, optional): The reference tensor. If set, masks will be the same shape as this tensor.
|
||||
length_dim (int, optional): Dimension indicator of the above tensor. See the example.
|
||||
|
||||
Returns:
|
||||
Tensor: Mask tensor containing indices of padded part.
|
||||
dtype=torch.uint8 in PyTorch 1.2-
|
||||
dtype=torch.bool in PyTorch 1.2+ (including 1.2)
|
||||
|
||||
Examples:
|
||||
With only lengths.
|
||||
|
||||
>>> lengths = [5, 3, 2]
|
||||
>>> make_non_pad_mask(lengths)
|
||||
masks = [[0, 0, 0, 0 ,0],
|
||||
[0, 0, 0, 1, 1],
|
||||
[0, 0, 1, 1, 1]]
|
||||
|
||||
With the reference tensor.
|
||||
|
||||
>>> xs = torch.zeros((3, 2, 4))
|
||||
>>> make_pad_mask(lengths, xs)
|
||||
tensor([[[0, 0, 0, 0],
|
||||
[0, 0, 0, 0]],
|
||||
[[0, 0, 0, 1],
|
||||
[0, 0, 0, 1]],
|
||||
[[0, 0, 1, 1],
|
||||
[0, 0, 1, 1]]], dtype=torch.uint8)
|
||||
>>> xs = torch.zeros((3, 2, 6))
|
||||
>>> make_pad_mask(lengths, xs)
|
||||
tensor([[[0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 1]],
|
||||
[[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 1, 1, 1]],
|
||||
[[0, 0, 1, 1, 1, 1],
|
||||
[0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
|
||||
|
||||
With the reference tensor and dimension indicator.
|
||||
|
||||
>>> xs = torch.zeros((3, 6, 6))
|
||||
>>> make_pad_mask(lengths, xs, 1)
|
||||
tensor([[[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 1, 1, 1]],
|
||||
[[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1]],
|
||||
[[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
|
||||
>>> make_pad_mask(lengths, xs, 2)
|
||||
tensor([[[0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 1]],
|
||||
[[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 1, 1, 1]],
|
||||
[[0, 0, 1, 1, 1, 1],
|
||||
[0, 0, 1, 1, 1, 1],
|
||||
[0, 0, 1, 1, 1, 1],
|
||||
[0, 0, 1, 1, 1, 1],
|
||||
[0, 0, 1, 1, 1, 1],
|
||||
[0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
|
||||
|
||||
"""
|
||||
if length_dim == 0:
|
||||
raise ValueError('length_dim cannot be 0: {}'.format(length_dim))
|
||||
|
||||
if not isinstance(lengths, list):
|
||||
lengths = lengths.tolist()
|
||||
bs = int(len(lengths))
|
||||
if xs is None:
|
||||
maxlen = int(max(lengths))
|
||||
else:
|
||||
maxlen = xs.size(length_dim)
|
||||
|
||||
seq_range = torch.arange(0, maxlen, dtype=torch.int64)
|
||||
seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
|
||||
seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
|
||||
mask = seq_range_expand >= seq_length_expand
|
||||
|
||||
if xs is not None:
|
||||
assert xs.size(0) == bs, (xs.size(0), bs)
|
||||
|
||||
if length_dim < 0:
|
||||
length_dim = xs.dim() + length_dim
|
||||
# ind = (:, None, ..., None, :, , None, ..., None)
|
||||
ind = tuple(slice(None) if i in (0, length_dim) else None
|
||||
for i in range(xs.dim()))
|
||||
mask = mask[ind].expand_as(xs).to(xs.device)
|
||||
return mask
|
||||
|
||||
|
||||
def make_non_pad_mask(lengths, xs=None, length_dim=-1):
|
||||
"""Make mask tensor containing indices of non-padded part.
|
||||
|
||||
Args:
|
||||
lengths (LongTensor or List): Batch of lengths (B,).
|
||||
xs (Tensor, optional): The reference tensor. If set, masks will be the same shape as this tensor.
|
||||
length_dim (int, optional): Dimension indicator of the above tensor. See the example.
|
||||
|
||||
Returns:
|
||||
ByteTensor: mask tensor containing indices of padded part.
|
||||
dtype=torch.uint8 in PyTorch 1.2-
|
||||
dtype=torch.bool in PyTorch 1.2+ (including 1.2)
|
||||
|
||||
Examples:
|
||||
With only lengths.
|
||||
|
||||
>>> lengths = [5, 3, 2]
|
||||
>>> make_non_pad_mask(lengths)
|
||||
masks = [[1, 1, 1, 1 ,1],
|
||||
[1, 1, 1, 0, 0],
|
||||
[1, 1, 0, 0, 0]]
|
||||
|
||||
With the reference tensor.
|
||||
|
||||
>>> xs = torch.zeros((3, 2, 4))
|
||||
>>> make_non_pad_mask(lengths, xs)
|
||||
tensor([[[1, 1, 1, 1],
|
||||
[1, 1, 1, 1]],
|
||||
[[1, 1, 1, 0],
|
||||
[1, 1, 1, 0]],
|
||||
[[1, 1, 0, 0],
|
||||
[1, 1, 0, 0]]], dtype=torch.uint8)
|
||||
>>> xs = torch.zeros((3, 2, 6))
|
||||
>>> make_non_pad_mask(lengths, xs)
|
||||
tensor([[[1, 1, 1, 1, 1, 0],
|
||||
[1, 1, 1, 1, 1, 0]],
|
||||
[[1, 1, 1, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0]],
|
||||
[[1, 1, 0, 0, 0, 0],
|
||||
[1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
|
||||
|
||||
With the reference tensor and dimension indicator.
|
||||
|
||||
>>> xs = torch.zeros((3, 6, 6))
|
||||
>>> make_non_pad_mask(lengths, xs, 1)
|
||||
tensor([[[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[0, 0, 0, 0, 0, 0]],
|
||||
[[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0]],
|
||||
[[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
|
||||
>>> make_non_pad_mask(lengths, xs, 2)
|
||||
tensor([[[1, 1, 1, 1, 1, 0],
|
||||
[1, 1, 1, 1, 1, 0],
|
||||
[1, 1, 1, 1, 1, 0],
|
||||
[1, 1, 1, 1, 1, 0],
|
||||
[1, 1, 1, 1, 1, 0],
|
||||
[1, 1, 1, 1, 1, 0]],
|
||||
[[1, 1, 1, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0]],
|
||||
[[1, 1, 0, 0, 0, 0],
|
||||
[1, 1, 0, 0, 0, 0],
|
||||
[1, 1, 0, 0, 0, 0],
|
||||
[1, 1, 0, 0, 0, 0],
|
||||
[1, 1, 0, 0, 0, 0],
|
||||
[1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
|
||||
|
||||
"""
|
||||
return ~make_pad_mask(lengths, xs, length_dim)
|
||||
|
||||
|
||||
def mask_by_length(xs, lengths, fill=0):
|
||||
"""Mask tensor according to length.
|
||||
|
||||
Args:
|
||||
xs (Tensor): Batch of input tensor (B, `*`).
|
||||
lengths (LongTensor or List): Batch of lengths (B,).
|
||||
fill (int or float): Value to fill masked part.
|
||||
|
||||
Returns:
|
||||
Tensor: Batch of masked input tensor (B, `*`).
|
||||
|
||||
Examples:
|
||||
>>> x = torch.arange(5).repeat(3, 1) + 1
|
||||
>>> x
|
||||
tensor([[1, 2, 3, 4, 5],
|
||||
[1, 2, 3, 4, 5],
|
||||
[1, 2, 3, 4, 5]])
|
||||
>>> lengths = [5, 3, 2]
|
||||
>>> mask_by_length(x, lengths)
|
||||
tensor([[1, 2, 3, 4, 5],
|
||||
[1, 2, 3, 0, 0],
|
||||
[1, 2, 0, 0, 0]])
|
||||
|
||||
"""
|
||||
assert xs.size(0) == len(lengths)
|
||||
ret = xs.data.new(*xs.size()).fill_(fill)
|
||||
for i, l in enumerate(lengths):
|
||||
ret[i, :l] = xs[i, :l]
|
||||
return ret
|
||||
|
||||
|
||||
def th_accuracy(pad_outputs, pad_targets, ignore_label):
|
||||
"""Calculate accuracy.
|
||||
|
||||
Args:
|
||||
pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
|
||||
pad_targets (LongTensor): Target label tensors (B, Lmax, D).
|
||||
ignore_label (int): Ignore label id.
|
||||
|
||||
Returns:
|
||||
float: Accuracy value (0.0 - 1.0).
|
||||
|
||||
"""
|
||||
pad_pred = pad_outputs.view(
|
||||
pad_targets.size(0),
|
||||
pad_targets.size(1),
|
||||
pad_outputs.size(1)).argmax(2)
|
||||
mask = pad_targets != ignore_label
|
||||
numerator = torch.sum(pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
|
||||
denominator = torch.sum(mask)
|
||||
return float(numerator) / float(denominator)
|
||||
|
||||
|
||||
def to_torch_tensor(x):
|
||||
"""Change to torch.Tensor or ComplexTensor from numpy.ndarray.
|
||||
|
||||
Args:
|
||||
x: Inputs. It should be one of numpy.ndarray, Tensor, ComplexTensor, and dict.
|
||||
|
||||
Returns:
|
||||
Tensor or ComplexTensor: Type converted inputs.
|
||||
|
||||
Examples:
|
||||
>>> xs = np.ones(3, dtype=np.float32)
|
||||
>>> xs = to_torch_tensor(xs)
|
||||
tensor([1., 1., 1.])
|
||||
>>> xs = torch.ones(3, 4, 5)
|
||||
>>> assert to_torch_tensor(xs) is xs
|
||||
>>> xs = {'real': xs, 'imag': xs}
|
||||
>>> to_torch_tensor(xs)
|
||||
ComplexTensor(
|
||||
Real:
|
||||
tensor([1., 1., 1.])
|
||||
Imag;
|
||||
tensor([1., 1., 1.])
|
||||
)
|
||||
|
||||
"""
|
||||
# If numpy, change to torch tensor
|
||||
if isinstance(x, np.ndarray):
|
||||
if x.dtype.kind == 'c':
|
||||
# Dynamically importing because torch_complex requires python3
|
||||
from torch_complex.tensor import ComplexTensor
|
||||
return ComplexTensor(x)
|
||||
else:
|
||||
return torch.from_numpy(x)
|
||||
|
||||
# If {'real': ..., 'imag': ...}, convert to ComplexTensor
|
||||
elif isinstance(x, dict):
|
||||
# Dynamically importing because torch_complex requires python3
|
||||
from torch_complex.tensor import ComplexTensor
|
||||
|
||||
if 'real' not in x or 'imag' not in x:
|
||||
raise ValueError("has 'real' and 'imag' keys: {}".format(list(x)))
|
||||
# Relative importing because of using python3 syntax
|
||||
return ComplexTensor(x['real'], x['imag'])
|
||||
|
||||
# If torch.Tensor, as it is
|
||||
elif isinstance(x, torch.Tensor):
|
||||
return x
|
||||
|
||||
else:
|
||||
error = ("x must be numpy.ndarray, torch.Tensor or a dict like "
|
||||
"{{'real': torch.Tensor, 'imag': torch.Tensor}}, "
|
||||
"but got {}".format(type(x)))
|
||||
try:
|
||||
from torch_complex.tensor import ComplexTensor
|
||||
except Exception:
|
||||
# If PY2
|
||||
raise ValueError(error)
|
||||
else:
|
||||
# If PY3
|
||||
if isinstance(x, ComplexTensor):
|
||||
return x
|
||||
else:
|
||||
raise ValueError(error)
|
||||
|
||||
|
||||
def get_subsample(train_args, mode, arch):
|
||||
"""Parse the subsampling factors from the training args for the specified `mode` and `arch`.
|
||||
|
||||
Args:
|
||||
train_args: argument Namespace containing options.
|
||||
mode: one of ('asr', 'mt', 'st')
|
||||
arch: one of ('rnn', 'rnn-t', 'rnn_mix', 'rnn_mulenc', 'transformer')
|
||||
|
||||
Returns:
|
||||
np.ndarray / List[np.ndarray]: subsampling factors.
|
||||
"""
|
||||
if arch == 'transformer':
|
||||
return np.array([1])
|
||||
|
||||
elif mode == 'mt' and arch == 'rnn':
|
||||
# +1 means input (+1) and layers outputs (train_args.elayer)
|
||||
subsample = np.ones(train_args.elayers + 1, dtype=np.int)
|
||||
logging.warning('Subsampling is not performed for machine translation.')
|
||||
logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
|
||||
return subsample
|
||||
|
||||
elif (mode == 'asr' and arch in ('rnn', 'rnn-t')) or \
|
||||
(mode == 'mt' and arch == 'rnn') or \
|
||||
(mode == 'st' and arch == 'rnn'):
|
||||
subsample = np.ones(train_args.elayers + 1, dtype=np.int)
|
||||
if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
|
||||
ss = train_args.subsample.split("_")
|
||||
for j in range(min(train_args.elayers + 1, len(ss))):
|
||||
subsample[j] = int(ss[j])
|
||||
else:
|
||||
logging.warning(
|
||||
'Subsampling is not performed for vgg*. It is performed in max pooling layers at CNN.')
|
||||
logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
|
||||
return subsample
|
||||
|
||||
elif mode == 'asr' and arch == 'rnn_mix':
|
||||
subsample = np.ones(train_args.elayers_sd + train_args.elayers + 1, dtype=np.int)
|
||||
if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
|
||||
ss = train_args.subsample.split("_")
|
||||
for j in range(min(train_args.elayers_sd + train_args.elayers + 1, len(ss))):
|
||||
subsample[j] = int(ss[j])
|
||||
else:
|
||||
logging.warning(
|
||||
'Subsampling is not performed for vgg*. It is performed in max pooling layers at CNN.')
|
||||
logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
|
||||
return subsample
|
||||
|
||||
elif mode == 'asr' and arch == 'rnn_mulenc':
|
||||
subsample_list = []
|
||||
for idx in range(train_args.num_encs):
|
||||
subsample = np.ones(train_args.elayers[idx] + 1, dtype=np.int)
|
||||
if train_args.etype[idx].endswith("p") and not train_args.etype[idx].startswith("vgg"):
|
||||
ss = train_args.subsample[idx].split("_")
|
||||
for j in range(min(train_args.elayers[idx] + 1, len(ss))):
|
||||
subsample[j] = int(ss[j])
|
||||
else:
|
||||
logging.warning(
|
||||
'Encoder %d: Subsampling is not performed for vgg*. '
|
||||
'It is performed in max pooling layers at CNN.', idx + 1)
|
||||
logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
|
||||
subsample_list.append(subsample)
|
||||
return subsample_list
|
||||
|
||||
else:
|
||||
raise ValueError('Invalid options: mode={}, arch={}'.format(mode, arch))
|
||||
|
||||
|
||||
def rename_state_dict(old_prefix: str, new_prefix: str, state_dict: Dict[str, torch.Tensor]):
|
||||
"""Replace keys of old prefix with new prefix in state dict."""
|
||||
# need this list not to break the dict iterator
|
||||
old_keys = [k for k in state_dict if k.startswith(old_prefix)]
|
||||
if len(old_keys) > 0:
|
||||
logging.warning(f'Rename: {old_prefix} -> {new_prefix}')
|
||||
for k in old_keys:
|
||||
v = state_dict.pop(k)
|
||||
new_k = k.replace(old_prefix, new_prefix)
|
||||
state_dict[new_k] = v
|
22
ppg2mel/utils/vc_utils.py
Normal file
22
ppg2mel/utils/vc_utils.py
Normal file
|
@ -0,0 +1,22 @@
|
|||
import torch
|
||||
|
||||
|
||||
def gcd(a, b):
|
||||
"""Greatest common divisor."""
|
||||
a, b = (a, b) if a >=b else (b, a)
|
||||
if a%b == 0:
|
||||
return b
|
||||
else :
|
||||
return gcd(b, a%b)
|
||||
|
||||
def lcm(a, b):
|
||||
"""Least common multiple"""
|
||||
return a * b // gcd(a, b)
|
||||
|
||||
def get_mask_from_lengths(lengths, max_len=None):
|
||||
if max_len is None:
|
||||
max_len = torch.max(lengths).item()
|
||||
ids = torch.arange(0, max_len, out=torch.cuda.LongTensor(max_len))
|
||||
mask = (ids < lengths.unsqueeze(1)).bool()
|
||||
return mask
|
||||
|
102
ppg_extractor/__init__.py
Normal file
102
ppg_extractor/__init__.py
Normal file
|
@ -0,0 +1,102 @@
|
|||
import argparse
|
||||
import torch
|
||||
from pathlib import Path
|
||||
import yaml
|
||||
|
||||
from .frontend import DefaultFrontend
|
||||
from .utterance_mvn import UtteranceMVN
|
||||
from .encoder.conformer_encoder import ConformerEncoder
|
||||
|
||||
_model = None # type: PPGModel
|
||||
_device = None
|
||||
|
||||
class PPGModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
frontend,
|
||||
normalizer,
|
||||
encoder,
|
||||
):
|
||||
super().__init__()
|
||||
self.frontend = frontend
|
||||
self.normalize = normalizer
|
||||
self.encoder = encoder
|
||||
|
||||
def forward(self, speech, speech_lengths):
|
||||
"""
|
||||
|
||||
Args:
|
||||
speech (tensor): (B, L)
|
||||
speech_lengths (tensor): (B, )
|
||||
|
||||
Returns:
|
||||
bottle_neck_feats (tensor): (B, L//hop_size, 144)
|
||||
|
||||
"""
|
||||
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
|
||||
feats, feats_lengths = self.normalize(feats, feats_lengths)
|
||||
encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
|
||||
return encoder_out
|
||||
|
||||
def _extract_feats(
|
||||
self, speech: torch.Tensor, speech_lengths: torch.Tensor
|
||||
):
|
||||
assert speech_lengths.dim() == 1, speech_lengths.shape
|
||||
|
||||
# for data-parallel
|
||||
speech = speech[:, : speech_lengths.max()]
|
||||
|
||||
if self.frontend is not None:
|
||||
# Frontend
|
||||
# e.g. STFT and Feature extract
|
||||
# data_loader may send time-domain signal in this case
|
||||
# speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
|
||||
feats, feats_lengths = self.frontend(speech, speech_lengths)
|
||||
else:
|
||||
# No frontend and no feature extract
|
||||
feats, feats_lengths = speech, speech_lengths
|
||||
return feats, feats_lengths
|
||||
|
||||
def extract_from_wav(self, src_wav):
|
||||
src_wav_tensor = torch.from_numpy(src_wav).unsqueeze(0).float().to(_device)
|
||||
src_wav_lengths = torch.LongTensor([len(src_wav)]).to(_device)
|
||||
return self(src_wav_tensor, src_wav_lengths)
|
||||
|
||||
|
||||
def build_model(args):
|
||||
normalizer = UtteranceMVN(**args.normalize_conf)
|
||||
frontend = DefaultFrontend(**args.frontend_conf)
|
||||
encoder = ConformerEncoder(input_size=80, **args.encoder_conf)
|
||||
model = PPGModel(frontend, normalizer, encoder)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def load_model(model_file, device=None):
|
||||
global _model, _device
|
||||
|
||||
if device is None:
|
||||
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
else:
|
||||
_device = device
|
||||
# search a config file
|
||||
model_config_fpaths = list(model_file.parent.rglob("*.yaml"))
|
||||
config_file = model_config_fpaths[0]
|
||||
with config_file.open("r", encoding="utf-8") as f:
|
||||
args = yaml.safe_load(f)
|
||||
|
||||
args = argparse.Namespace(**args)
|
||||
|
||||
model = build_model(args)
|
||||
model_state_dict = model.state_dict()
|
||||
|
||||
ckpt_state_dict = torch.load(model_file, map_location=_device)
|
||||
ckpt_state_dict = {k:v for k,v in ckpt_state_dict.items() if 'encoder' in k}
|
||||
|
||||
model_state_dict.update(ckpt_state_dict)
|
||||
model.load_state_dict(model_state_dict)
|
||||
|
||||
_model = model.eval().to(_device)
|
||||
return _model
|
||||
|
||||
|
398
ppg_extractor/e2e_asr_common.py
Normal file
398
ppg_extractor/e2e_asr_common.py
Normal file
|
@ -0,0 +1,398 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Common functions for ASR."""
|
||||
|
||||
import argparse
|
||||
import editdistance
|
||||
import json
|
||||
import logging
|
||||
import numpy as np
|
||||
import six
|
||||
import sys
|
||||
|
||||
from itertools import groupby
|
||||
|
||||
|
||||
def end_detect(ended_hyps, i, M=3, D_end=np.log(1 * np.exp(-10))):
|
||||
"""End detection.
|
||||
|
||||
desribed in Eq. (50) of S. Watanabe et al
|
||||
"Hybrid CTC/Attention Architecture for End-to-End Speech Recognition"
|
||||
|
||||
:param ended_hyps:
|
||||
:param i:
|
||||
:param M:
|
||||
:param D_end:
|
||||
:return:
|
||||
"""
|
||||
if len(ended_hyps) == 0:
|
||||
return False
|
||||
count = 0
|
||||
best_hyp = sorted(ended_hyps, key=lambda x: x['score'], reverse=True)[0]
|
||||
for m in six.moves.range(M):
|
||||
# get ended_hyps with their length is i - m
|
||||
hyp_length = i - m
|
||||
hyps_same_length = [x for x in ended_hyps if len(x['yseq']) == hyp_length]
|
||||
if len(hyps_same_length) > 0:
|
||||
best_hyp_same_length = sorted(hyps_same_length, key=lambda x: x['score'], reverse=True)[0]
|
||||
if best_hyp_same_length['score'] - best_hyp['score'] < D_end:
|
||||
count += 1
|
||||
|
||||
if count == M:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
# TODO(takaaki-hori): add different smoothing methods
|
||||
def label_smoothing_dist(odim, lsm_type, transcript=None, blank=0):
|
||||
"""Obtain label distribution for loss smoothing.
|
||||
|
||||
:param odim:
|
||||
:param lsm_type:
|
||||
:param blank:
|
||||
:param transcript:
|
||||
:return:
|
||||
"""
|
||||
if transcript is not None:
|
||||
with open(transcript, 'rb') as f:
|
||||
trans_json = json.load(f)['utts']
|
||||
|
||||
if lsm_type == 'unigram':
|
||||
assert transcript is not None, 'transcript is required for %s label smoothing' % lsm_type
|
||||
labelcount = np.zeros(odim)
|
||||
for k, v in trans_json.items():
|
||||
ids = np.array([int(n) for n in v['output'][0]['tokenid'].split()])
|
||||
# to avoid an error when there is no text in an uttrance
|
||||
if len(ids) > 0:
|
||||
labelcount[ids] += 1
|
||||
labelcount[odim - 1] = len(transcript) # count <eos>
|
||||
labelcount[labelcount == 0] = 1 # flooring
|
||||
labelcount[blank] = 0 # remove counts for blank
|
||||
labeldist = labelcount.astype(np.float32) / np.sum(labelcount)
|
||||
else:
|
||||
logging.error(
|
||||
"Error: unexpected label smoothing type: %s" % lsm_type)
|
||||
sys.exit()
|
||||
|
||||
return labeldist
|
||||
|
||||
|
||||
def get_vgg2l_odim(idim, in_channel=3, out_channel=128, downsample=True):
|
||||
"""Return the output size of the VGG frontend.
|
||||
|
||||
:param in_channel: input channel size
|
||||
:param out_channel: output channel size
|
||||
:return: output size
|
||||
:rtype int
|
||||
"""
|
||||
idim = idim / in_channel
|
||||
if downsample:
|
||||
idim = np.ceil(np.array(idim, dtype=np.float32) / 2) # 1st max pooling
|
||||
idim = np.ceil(np.array(idim, dtype=np.float32) / 2) # 2nd max pooling
|
||||
return int(idim) * out_channel # numer of channels
|
||||
|
||||
|
||||
class ErrorCalculator(object):
|
||||
"""Calculate CER and WER for E2E_ASR and CTC models during training.
|
||||
|
||||
:param y_hats: numpy array with predicted text
|
||||
:param y_pads: numpy array with true (target) text
|
||||
:param char_list:
|
||||
:param sym_space:
|
||||
:param sym_blank:
|
||||
:return:
|
||||
"""
|
||||
|
||||
def __init__(self, char_list, sym_space, sym_blank, report_cer=False, report_wer=False,
|
||||
trans_type="char"):
|
||||
"""Construct an ErrorCalculator object."""
|
||||
super(ErrorCalculator, self).__init__()
|
||||
|
||||
self.report_cer = report_cer
|
||||
self.report_wer = report_wer
|
||||
self.trans_type = trans_type
|
||||
self.char_list = char_list
|
||||
self.space = sym_space
|
||||
self.blank = sym_blank
|
||||
self.idx_blank = self.char_list.index(self.blank)
|
||||
if self.space in self.char_list:
|
||||
self.idx_space = self.char_list.index(self.space)
|
||||
else:
|
||||
self.idx_space = None
|
||||
|
||||
def __call__(self, ys_hat, ys_pad, is_ctc=False):
|
||||
"""Calculate sentence-level WER/CER score.
|
||||
|
||||
:param torch.Tensor ys_hat: prediction (batch, seqlen)
|
||||
:param torch.Tensor ys_pad: reference (batch, seqlen)
|
||||
:param bool is_ctc: calculate CER score for CTC
|
||||
:return: sentence-level WER score
|
||||
:rtype float
|
||||
:return: sentence-level CER score
|
||||
:rtype float
|
||||
"""
|
||||
cer, wer = None, None
|
||||
if is_ctc:
|
||||
return self.calculate_cer_ctc(ys_hat, ys_pad)
|
||||
elif not self.report_cer and not self.report_wer:
|
||||
return cer, wer
|
||||
|
||||
seqs_hat, seqs_true = self.convert_to_char(ys_hat, ys_pad)
|
||||
if self.report_cer:
|
||||
cer = self.calculate_cer(seqs_hat, seqs_true)
|
||||
|
||||
if self.report_wer:
|
||||
wer = self.calculate_wer(seqs_hat, seqs_true)
|
||||
return cer, wer
|
||||
|
||||
def calculate_cer_ctc(self, ys_hat, ys_pad):
|
||||
"""Calculate sentence-level CER score for CTC.
|
||||
|
||||
:param torch.Tensor ys_hat: prediction (batch, seqlen)
|
||||
:param torch.Tensor ys_pad: reference (batch, seqlen)
|
||||
:return: average sentence-level CER score
|
||||
:rtype float
|
||||
"""
|
||||
cers, char_ref_lens = [], []
|
||||
for i, y in enumerate(ys_hat):
|
||||
y_hat = [x[0] for x in groupby(y)]
|
||||
y_true = ys_pad[i]
|
||||
seq_hat, seq_true = [], []
|
||||
for idx in y_hat:
|
||||
idx = int(idx)
|
||||
if idx != -1 and idx != self.idx_blank and idx != self.idx_space:
|
||||
seq_hat.append(self.char_list[int(idx)])
|
||||
|
||||
for idx in y_true:
|
||||
idx = int(idx)
|
||||
if idx != -1 and idx != self.idx_blank and idx != self.idx_space:
|
||||
seq_true.append(self.char_list[int(idx)])
|
||||
if self.trans_type == "char":
|
||||
hyp_chars = "".join(seq_hat)
|
||||
ref_chars = "".join(seq_true)
|
||||
else:
|
||||
hyp_chars = " ".join(seq_hat)
|
||||
ref_chars = " ".join(seq_true)
|
||||
|
||||
if len(ref_chars) > 0:
|
||||
cers.append(editdistance.eval(hyp_chars, ref_chars))
|
||||
char_ref_lens.append(len(ref_chars))
|
||||
|
||||
cer_ctc = float(sum(cers)) / sum(char_ref_lens) if cers else None
|
||||
return cer_ctc
|
||||
|
||||
def convert_to_char(self, ys_hat, ys_pad):
|
||||
"""Convert index to character.
|
||||
|
||||
:param torch.Tensor seqs_hat: prediction (batch, seqlen)
|
||||
:param torch.Tensor seqs_true: reference (batch, seqlen)
|
||||
:return: token list of prediction
|
||||
:rtype list
|
||||
:return: token list of reference
|
||||
:rtype list
|
||||
"""
|
||||
seqs_hat, seqs_true = [], []
|
||||
for i, y_hat in enumerate(ys_hat):
|
||||
y_true = ys_pad[i]
|
||||
eos_true = np.where(y_true == -1)[0]
|
||||
eos_true = eos_true[0] if len(eos_true) > 0 else len(y_true)
|
||||
# To avoid wrong higher WER than the one obtained from the decoding
|
||||
# eos from y_true is used to mark the eos in y_hat
|
||||
# because of that y_hats has not padded outs with -1.
|
||||
seq_hat = [self.char_list[int(idx)] for idx in y_hat[:eos_true]]
|
||||
seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]
|
||||
# seq_hat_text = "".join(seq_hat).replace(self.space, ' ')
|
||||
seq_hat_text = " ".join(seq_hat).replace(self.space, ' ')
|
||||
seq_hat_text = seq_hat_text.replace(self.blank, '')
|
||||
# seq_true_text = "".join(seq_true).replace(self.space, ' ')
|
||||
seq_true_text = " ".join(seq_true).replace(self.space, ' ')
|
||||
seqs_hat.append(seq_hat_text)
|
||||
seqs_true.append(seq_true_text)
|
||||
return seqs_hat, seqs_true
|
||||
|
||||
def calculate_cer(self, seqs_hat, seqs_true):
|
||||
"""Calculate sentence-level CER score.
|
||||
|
||||
:param list seqs_hat: prediction
|
||||
:param list seqs_true: reference
|
||||
:return: average sentence-level CER score
|
||||
:rtype float
|
||||
"""
|
||||
char_eds, char_ref_lens = [], []
|
||||
for i, seq_hat_text in enumerate(seqs_hat):
|
||||
seq_true_text = seqs_true[i]
|
||||
hyp_chars = seq_hat_text.replace(' ', '')
|
||||
ref_chars = seq_true_text.replace(' ', '')
|
||||
char_eds.append(editdistance.eval(hyp_chars, ref_chars))
|
||||
char_ref_lens.append(len(ref_chars))
|
||||
return float(sum(char_eds)) / sum(char_ref_lens)
|
||||
|
||||
def calculate_wer(self, seqs_hat, seqs_true):
|
||||
"""Calculate sentence-level WER score.
|
||||
|
||||
:param list seqs_hat: prediction
|
||||
:param list seqs_true: reference
|
||||
:return: average sentence-level WER score
|
||||
:rtype float
|
||||
"""
|
||||
word_eds, word_ref_lens = [], []
|
||||
for i, seq_hat_text in enumerate(seqs_hat):
|
||||
seq_true_text = seqs_true[i]
|
||||
hyp_words = seq_hat_text.split()
|
||||
ref_words = seq_true_text.split()
|
||||
word_eds.append(editdistance.eval(hyp_words, ref_words))
|
||||
word_ref_lens.append(len(ref_words))
|
||||
return float(sum(word_eds)) / sum(word_ref_lens)
|
||||
|
||||
|
||||
class ErrorCalculatorTrans(object):
|
||||
"""Calculate CER and WER for transducer models.
|
||||
|
||||
Args:
|
||||
decoder (nn.Module): decoder module
|
||||
args (Namespace): argument Namespace containing options
|
||||
report_cer (boolean): compute CER option
|
||||
report_wer (boolean): compute WER option
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, decoder, args, report_cer=False, report_wer=False):
|
||||
"""Construct an ErrorCalculator object for transducer model."""
|
||||
super(ErrorCalculatorTrans, self).__init__()
|
||||
|
||||
self.dec = decoder
|
||||
|
||||
recog_args = {'beam_size': args.beam_size,
|
||||
'nbest': args.nbest,
|
||||
'space': args.sym_space,
|
||||
'score_norm_transducer': args.score_norm_transducer}
|
||||
|
||||
self.recog_args = argparse.Namespace(**recog_args)
|
||||
|
||||
self.char_list = args.char_list
|
||||
self.space = args.sym_space
|
||||
self.blank = args.sym_blank
|
||||
|
||||
self.report_cer = args.report_cer
|
||||
self.report_wer = args.report_wer
|
||||
|
||||
def __call__(self, hs_pad, ys_pad):
|
||||
"""Calculate sentence-level WER/CER score for transducer models.
|
||||
|
||||
Args:
|
||||
hs_pad (torch.Tensor): batch of padded input sequence (batch, T, D)
|
||||
ys_pad (torch.Tensor): reference (batch, seqlen)
|
||||
|
||||
Returns:
|
||||
(float): sentence-level CER score
|
||||
(float): sentence-level WER score
|
||||
|
||||
"""
|
||||
cer, wer = None, None
|
||||
|
||||
if not self.report_cer and not self.report_wer:
|
||||
return cer, wer
|
||||
|
||||
batchsize = int(hs_pad.size(0))
|
||||
batch_nbest = []
|
||||
|
||||
for b in six.moves.range(batchsize):
|
||||
if self.recog_args.beam_size == 1:
|
||||
nbest_hyps = self.dec.recognize(hs_pad[b], self.recog_args)
|
||||
else:
|
||||
nbest_hyps = self.dec.recognize_beam(hs_pad[b], self.recog_args)
|
||||
batch_nbest.append(nbest_hyps)
|
||||
|
||||
ys_hat = [nbest_hyp[0]['yseq'][1:] for nbest_hyp in batch_nbest]
|
||||
|
||||
seqs_hat, seqs_true = self.convert_to_char(ys_hat, ys_pad.cpu())
|
||||
|
||||
if self.report_cer:
|
||||
cer = self.calculate_cer(seqs_hat, seqs_true)
|
||||
|
||||
if self.report_wer:
|
||||
wer = self.calculate_wer(seqs_hat, seqs_true)
|
||||
|
||||
return cer, wer
|
||||
|
||||
def convert_to_char(self, ys_hat, ys_pad):
|
||||
"""Convert index to character.
|
||||
|
||||
Args:
|
||||
ys_hat (torch.Tensor): prediction (batch, seqlen)
|
||||
ys_pad (torch.Tensor): reference (batch, seqlen)
|
||||
|
||||
Returns:
|
||||
(list): token list of prediction
|
||||
(list): token list of reference
|
||||
|
||||
"""
|
||||
seqs_hat, seqs_true = [], []
|
||||
|
||||
for i, y_hat in enumerate(ys_hat):
|
||||
y_true = ys_pad[i]
|
||||
|
||||
eos_true = np.where(y_true == -1)[0]
|
||||
eos_true = eos_true[0] if len(eos_true) > 0 else len(y_true)
|
||||
|
||||
seq_hat = [self.char_list[int(idx)] for idx in y_hat[:eos_true]]
|
||||
seq_true = [self.char_list[int(idx)] for idx in y_true if int(idx) != -1]
|
||||
|
||||
seq_hat_text = "".join(seq_hat).replace(self.space, ' ')
|
||||
seq_hat_text = seq_hat_text.replace(self.blank, '')
|
||||
seq_true_text = "".join(seq_true).replace(self.space, ' ')
|
||||
|
||||
seqs_hat.append(seq_hat_text)
|
||||
seqs_true.append(seq_true_text)
|
||||
|
||||
return seqs_hat, seqs_true
|
||||
|
||||
def calculate_cer(self, seqs_hat, seqs_true):
|
||||
"""Calculate sentence-level CER score for transducer model.
|
||||
|
||||
Args:
|
||||
seqs_hat (torch.Tensor): prediction (batch, seqlen)
|
||||
seqs_true (torch.Tensor): reference (batch, seqlen)
|
||||
|
||||
Returns:
|
||||
(float): average sentence-level CER score
|
||||
|
||||
"""
|
||||
char_eds, char_ref_lens = [], []
|
||||
|
||||
for i, seq_hat_text in enumerate(seqs_hat):
|
||||
seq_true_text = seqs_true[i]
|
||||
hyp_chars = seq_hat_text.replace(' ', '')
|
||||
ref_chars = seq_true_text.replace(' ', '')
|
||||
|
||||
char_eds.append(editdistance.eval(hyp_chars, ref_chars))
|
||||
char_ref_lens.append(len(ref_chars))
|
||||
|
||||
return float(sum(char_eds)) / sum(char_ref_lens)
|
||||
|
||||
def calculate_wer(self, seqs_hat, seqs_true):
|
||||
"""Calculate sentence-level WER score for transducer model.
|
||||
|
||||
Args:
|
||||
seqs_hat (torch.Tensor): prediction (batch, seqlen)
|
||||
seqs_true (torch.Tensor): reference (batch, seqlen)
|
||||
|
||||
Returns:
|
||||
(float): average sentence-level WER score
|
||||
|
||||
"""
|
||||
word_eds, word_ref_lens = [], []
|
||||
|
||||
for i, seq_hat_text in enumerate(seqs_hat):
|
||||
seq_true_text = seqs_true[i]
|
||||
hyp_words = seq_hat_text.split()
|
||||
ref_words = seq_true_text.split()
|
||||
|
||||
word_eds.append(editdistance.eval(hyp_words, ref_words))
|
||||
word_ref_lens.append(len(ref_words))
|
||||
|
||||
return float(sum(word_eds)) / sum(word_ref_lens)
|
0
ppg_extractor/encoder/__init__.py
Normal file
0
ppg_extractor/encoder/__init__.py
Normal file
183
ppg_extractor/encoder/attention.py
Normal file
183
ppg_extractor/encoder/attention.py
Normal file
|
@ -0,0 +1,183 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2019 Shigeki Karita
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Multi-Head Attention layer definition."""
|
||||
|
||||
import math
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class MultiHeadedAttention(nn.Module):
|
||||
"""Multi-Head Attention layer.
|
||||
|
||||
:param int n_head: the number of head s
|
||||
:param int n_feat: the number of features
|
||||
:param float dropout_rate: dropout rate
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, n_head, n_feat, dropout_rate):
|
||||
"""Construct an MultiHeadedAttention object."""
|
||||
super(MultiHeadedAttention, self).__init__()
|
||||
assert n_feat % n_head == 0
|
||||
# We assume d_v always equals d_k
|
||||
self.d_k = n_feat // n_head
|
||||
self.h = n_head
|
||||
self.linear_q = nn.Linear(n_feat, n_feat)
|
||||
self.linear_k = nn.Linear(n_feat, n_feat)
|
||||
self.linear_v = nn.Linear(n_feat, n_feat)
|
||||
self.linear_out = nn.Linear(n_feat, n_feat)
|
||||
self.attn = None
|
||||
self.dropout = nn.Dropout(p=dropout_rate)
|
||||
|
||||
def forward_qkv(self, query, key, value):
|
||||
"""Transform query, key and value.
|
||||
|
||||
:param torch.Tensor query: (batch, time1, size)
|
||||
:param torch.Tensor key: (batch, time2, size)
|
||||
:param torch.Tensor value: (batch, time2, size)
|
||||
:return torch.Tensor transformed query, key and value
|
||||
|
||||
"""
|
||||
n_batch = query.size(0)
|
||||
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
|
||||
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
|
||||
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
|
||||
q = q.transpose(1, 2) # (batch, head, time1, d_k)
|
||||
k = k.transpose(1, 2) # (batch, head, time2, d_k)
|
||||
v = v.transpose(1, 2) # (batch, head, time2, d_k)
|
||||
|
||||
return q, k, v
|
||||
|
||||
def forward_attention(self, value, scores, mask):
|
||||
"""Compute attention context vector.
|
||||
|
||||
:param torch.Tensor value: (batch, head, time2, size)
|
||||
:param torch.Tensor scores: (batch, head, time1, time2)
|
||||
:param torch.Tensor mask: (batch, 1, time2) or (batch, time1, time2)
|
||||
:return torch.Tensor transformed `value` (batch, time1, d_model)
|
||||
weighted by the attention score (batch, time1, time2)
|
||||
|
||||
"""
|
||||
n_batch = value.size(0)
|
||||
if mask is not None:
|
||||
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
|
||||
min_value = float(
|
||||
numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
|
||||
)
|
||||
scores = scores.masked_fill(mask, min_value)
|
||||
self.attn = torch.softmax(scores, dim=-1).masked_fill(
|
||||
mask, 0.0
|
||||
) # (batch, head, time1, time2)
|
||||
else:
|
||||
self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
|
||||
|
||||
p_attn = self.dropout(self.attn)
|
||||
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
|
||||
x = (
|
||||
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
|
||||
) # (batch, time1, d_model)
|
||||
|
||||
return self.linear_out(x) # (batch, time1, d_model)
|
||||
|
||||
def forward(self, query, key, value, mask):
|
||||
"""Compute 'Scaled Dot Product Attention'.
|
||||
|
||||
:param torch.Tensor query: (batch, time1, size)
|
||||
:param torch.Tensor key: (batch, time2, size)
|
||||
:param torch.Tensor value: (batch, time2, size)
|
||||
:param torch.Tensor mask: (batch, 1, time2) or (batch, time1, time2)
|
||||
:param torch.nn.Dropout dropout:
|
||||
:return torch.Tensor: attention output (batch, time1, d_model)
|
||||
"""
|
||||
q, k, v = self.forward_qkv(query, key, value)
|
||||
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
|
||||
return self.forward_attention(v, scores, mask)
|
||||
|
||||
|
||||
class RelPositionMultiHeadedAttention(MultiHeadedAttention):
|
||||
"""Multi-Head Attention layer with relative position encoding.
|
||||
|
||||
Paper: https://arxiv.org/abs/1901.02860
|
||||
|
||||
:param int n_head: the number of head s
|
||||
:param int n_feat: the number of features
|
||||
:param float dropout_rate: dropout rate
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, n_head, n_feat, dropout_rate):
|
||||
"""Construct an RelPositionMultiHeadedAttention object."""
|
||||
super().__init__(n_head, n_feat, dropout_rate)
|
||||
# linear transformation for positional ecoding
|
||||
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
|
||||
# these two learnable bias are used in matrix c and matrix d
|
||||
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
||||
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
||||
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
||||
torch.nn.init.xavier_uniform_(self.pos_bias_u)
|
||||
torch.nn.init.xavier_uniform_(self.pos_bias_v)
|
||||
|
||||
def rel_shift(self, x, zero_triu=False):
|
||||
"""Compute relative positinal encoding.
|
||||
|
||||
:param torch.Tensor x: (batch, time, size)
|
||||
:param bool zero_triu: return the lower triangular part of the matrix
|
||||
"""
|
||||
zero_pad = torch.zeros((*x.size()[:3], 1), device=x.device, dtype=x.dtype)
|
||||
x_padded = torch.cat([zero_pad, x], dim=-1)
|
||||
|
||||
x_padded = x_padded.view(*x.size()[:2], x.size(3) + 1, x.size(2))
|
||||
x = x_padded[:, :, 1:].view_as(x)
|
||||
|
||||
if zero_triu:
|
||||
ones = torch.ones((x.size(2), x.size(3)))
|
||||
x = x * torch.tril(ones, x.size(3) - x.size(2))[None, None, :, :]
|
||||
|
||||
return x
|
||||
|
||||
def forward(self, query, key, value, pos_emb, mask):
|
||||
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
|
||||
|
||||
:param torch.Tensor query: (batch, time1, size)
|
||||
:param torch.Tensor key: (batch, time2, size)
|
||||
:param torch.Tensor value: (batch, time2, size)
|
||||
:param torch.Tensor pos_emb: (batch, time1, size)
|
||||
:param torch.Tensor mask: (batch, time1, time2)
|
||||
:param torch.nn.Dropout dropout:
|
||||
:return torch.Tensor: attention output (batch, time1, d_model)
|
||||
"""
|
||||
q, k, v = self.forward_qkv(query, key, value)
|
||||
q = q.transpose(1, 2) # (batch, time1, head, d_k)
|
||||
|
||||
n_batch_pos = pos_emb.size(0)
|
||||
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
|
||||
p = p.transpose(1, 2) # (batch, head, time1, d_k)
|
||||
|
||||
# (batch, head, time1, d_k)
|
||||
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
|
||||
# (batch, head, time1, d_k)
|
||||
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
|
||||
|
||||
# compute attention score
|
||||
# first compute matrix a and matrix c
|
||||
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
||||
# (batch, head, time1, time2)
|
||||
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
|
||||
|
||||
# compute matrix b and matrix d
|
||||
# (batch, head, time1, time2)
|
||||
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
|
||||
matrix_bd = self.rel_shift(matrix_bd)
|
||||
|
||||
scores = (matrix_ac + matrix_bd) / math.sqrt(
|
||||
self.d_k
|
||||
) # (batch, head, time1, time2)
|
||||
|
||||
return self.forward_attention(v, scores, mask)
|
262
ppg_extractor/encoder/conformer_encoder.py
Normal file
262
ppg_extractor/encoder/conformer_encoder.py
Normal file
|
@ -0,0 +1,262 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2019 Shigeki Karita
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Encoder definition."""
|
||||
|
||||
import logging
|
||||
import torch
|
||||
from typing import Callable
|
||||
from typing import Collection
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
|
||||
from .convolution import ConvolutionModule
|
||||
from .encoder_layer import EncoderLayer
|
||||
from ..nets_utils import get_activation, make_pad_mask
|
||||
from .vgg import VGG2L
|
||||
from .attention import MultiHeadedAttention, RelPositionMultiHeadedAttention
|
||||
from .embedding import PositionalEncoding, ScaledPositionalEncoding, RelPositionalEncoding
|
||||
from .layer_norm import LayerNorm
|
||||
from .multi_layer_conv import Conv1dLinear, MultiLayeredConv1d
|
||||
from .positionwise_feed_forward import PositionwiseFeedForward
|
||||
from .repeat import repeat
|
||||
from .subsampling import Conv2dNoSubsampling, Conv2dSubsampling
|
||||
|
||||
|
||||
class ConformerEncoder(torch.nn.Module):
|
||||
"""Conformer encoder module.
|
||||
|
||||
:param int idim: input dim
|
||||
:param int attention_dim: dimention of attention
|
||||
:param int attention_heads: the number of heads of multi head attention
|
||||
:param int linear_units: the number of units of position-wise feed forward
|
||||
:param int num_blocks: the number of decoder blocks
|
||||
:param float dropout_rate: dropout rate
|
||||
:param float attention_dropout_rate: dropout rate in attention
|
||||
:param float positional_dropout_rate: dropout rate after adding positional encoding
|
||||
:param str or torch.nn.Module input_layer: input layer type
|
||||
:param bool normalize_before: whether to use layer_norm before the first block
|
||||
:param bool concat_after: whether to concat attention layer's input and output
|
||||
if True, additional linear will be applied.
|
||||
i.e. x -> x + linear(concat(x, att(x)))
|
||||
if False, no additional linear will be applied. i.e. x -> x + att(x)
|
||||
:param str positionwise_layer_type: linear of conv1d
|
||||
:param int positionwise_conv_kernel_size: kernel size of positionwise conv1d layer
|
||||
:param str encoder_pos_enc_layer_type: encoder positional encoding layer type
|
||||
:param str encoder_attn_layer_type: encoder attention layer type
|
||||
:param str activation_type: encoder activation function type
|
||||
:param bool macaron_style: whether to use macaron style for positionwise layer
|
||||
:param bool use_cnn_module: whether to use convolution module
|
||||
:param int cnn_module_kernel: kernerl size of convolution module
|
||||
:param int padding_idx: padding_idx for input_layer=embed
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size,
|
||||
attention_dim=256,
|
||||
attention_heads=4,
|
||||
linear_units=2048,
|
||||
num_blocks=6,
|
||||
dropout_rate=0.1,
|
||||
positional_dropout_rate=0.1,
|
||||
attention_dropout_rate=0.0,
|
||||
input_layer="conv2d",
|
||||
normalize_before=True,
|
||||
concat_after=False,
|
||||
positionwise_layer_type="linear",
|
||||
positionwise_conv_kernel_size=1,
|
||||
macaron_style=False,
|
||||
pos_enc_layer_type="abs_pos",
|
||||
selfattention_layer_type="selfattn",
|
||||
activation_type="swish",
|
||||
use_cnn_module=False,
|
||||
cnn_module_kernel=31,
|
||||
padding_idx=-1,
|
||||
no_subsample=False,
|
||||
subsample_by_2=False,
|
||||
):
|
||||
"""Construct an Encoder object."""
|
||||
super().__init__()
|
||||
|
||||
self._output_size = attention_dim
|
||||
idim = input_size
|
||||
|
||||
activation = get_activation(activation_type)
|
||||
if pos_enc_layer_type == "abs_pos":
|
||||
pos_enc_class = PositionalEncoding
|
||||
elif pos_enc_layer_type == "scaled_abs_pos":
|
||||
pos_enc_class = ScaledPositionalEncoding
|
||||
elif pos_enc_layer_type == "rel_pos":
|
||||
assert selfattention_layer_type == "rel_selfattn"
|
||||
pos_enc_class = RelPositionalEncoding
|
||||
else:
|
||||
raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
|
||||
|
||||
if input_layer == "linear":
|
||||
self.embed = torch.nn.Sequential(
|
||||
torch.nn.Linear(idim, attention_dim),
|
||||
torch.nn.LayerNorm(attention_dim),
|
||||
torch.nn.Dropout(dropout_rate),
|
||||
pos_enc_class(attention_dim, positional_dropout_rate),
|
||||
)
|
||||
elif input_layer == "conv2d":
|
||||
logging.info("Encoder input layer type: conv2d")
|
||||
if no_subsample:
|
||||
self.embed = Conv2dNoSubsampling(
|
||||
idim,
|
||||
attention_dim,
|
||||
dropout_rate,
|
||||
pos_enc_class(attention_dim, positional_dropout_rate),
|
||||
)
|
||||
else:
|
||||
self.embed = Conv2dSubsampling(
|
||||
idim,
|
||||
attention_dim,
|
||||
dropout_rate,
|
||||
pos_enc_class(attention_dim, positional_dropout_rate),
|
||||
subsample_by_2, # NOTE(Sx): added by songxiang
|
||||
)
|
||||
elif input_layer == "vgg2l":
|
||||
self.embed = VGG2L(idim, attention_dim)
|
||||
elif input_layer == "embed":
|
||||
self.embed = torch.nn.Sequential(
|
||||
torch.nn.Embedding(idim, attention_dim, padding_idx=padding_idx),
|
||||
pos_enc_class(attention_dim, positional_dropout_rate),
|
||||
)
|
||||
elif isinstance(input_layer, torch.nn.Module):
|
||||
self.embed = torch.nn.Sequential(
|
||||
input_layer,
|
||||
pos_enc_class(attention_dim, positional_dropout_rate),
|
||||
)
|
||||
elif input_layer is None:
|
||||
self.embed = torch.nn.Sequential(
|
||||
pos_enc_class(attention_dim, positional_dropout_rate)
|
||||
)
|
||||
else:
|
||||
raise ValueError("unknown input_layer: " + input_layer)
|
||||
self.normalize_before = normalize_before
|
||||
if positionwise_layer_type == "linear":
|
||||
positionwise_layer = PositionwiseFeedForward
|
||||
positionwise_layer_args = (
|
||||
attention_dim,
|
||||
linear_units,
|
||||
dropout_rate,
|
||||
activation,
|
||||
)
|
||||
elif positionwise_layer_type == "conv1d":
|
||||
positionwise_layer = MultiLayeredConv1d
|
||||
positionwise_layer_args = (
|
||||
attention_dim,
|
||||
linear_units,
|
||||
positionwise_conv_kernel_size,
|
||||
dropout_rate,
|
||||
)
|
||||
elif positionwise_layer_type == "conv1d-linear":
|
||||
positionwise_layer = Conv1dLinear
|
||||
positionwise_layer_args = (
|
||||
attention_dim,
|
||||
linear_units,
|
||||
positionwise_conv_kernel_size,
|
||||
dropout_rate,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Support only linear or conv1d.")
|
||||
|
||||
if selfattention_layer_type == "selfattn":
|
||||
logging.info("encoder self-attention layer type = self-attention")
|
||||
encoder_selfattn_layer = MultiHeadedAttention
|
||||
encoder_selfattn_layer_args = (
|
||||
attention_heads,
|
||||
attention_dim,
|
||||
attention_dropout_rate,
|
||||
)
|
||||
elif selfattention_layer_type == "rel_selfattn":
|
||||
assert pos_enc_layer_type == "rel_pos"
|
||||
encoder_selfattn_layer = RelPositionMultiHeadedAttention
|
||||
encoder_selfattn_layer_args = (
|
||||
attention_heads,
|
||||
attention_dim,
|
||||
attention_dropout_rate,
|
||||
)
|
||||
else:
|
||||
raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type)
|
||||
|
||||
convolution_layer = ConvolutionModule
|
||||
convolution_layer_args = (attention_dim, cnn_module_kernel, activation)
|
||||
|
||||
self.encoders = repeat(
|
||||
num_blocks,
|
||||
lambda lnum: EncoderLayer(
|
||||
attention_dim,
|
||||
encoder_selfattn_layer(*encoder_selfattn_layer_args),
|
||||
positionwise_layer(*positionwise_layer_args),
|
||||
positionwise_layer(*positionwise_layer_args) if macaron_style else None,
|
||||
convolution_layer(*convolution_layer_args) if use_cnn_module else None,
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
),
|
||||
)
|
||||
if self.normalize_before:
|
||||
self.after_norm = LayerNorm(attention_dim)
|
||||
|
||||
def output_size(self) -> int:
|
||||
return self._output_size
|
||||
|
||||
def forward(
|
||||
self,
|
||||
xs_pad: torch.Tensor,
|
||||
ilens: torch.Tensor,
|
||||
prev_states: torch.Tensor = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""
|
||||
Args:
|
||||
xs_pad: input tensor (B, L, D)
|
||||
ilens: input lengths (B)
|
||||
prev_states: Not to be used now.
|
||||
Returns:
|
||||
Position embedded tensor and mask
|
||||
"""
|
||||
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
|
||||
|
||||
if isinstance(self.embed, (Conv2dSubsampling, Conv2dNoSubsampling, VGG2L)):
|
||||
# print(xs_pad.shape)
|
||||
xs_pad, masks = self.embed(xs_pad, masks)
|
||||
# print(xs_pad[0].size())
|
||||
else:
|
||||
xs_pad = self.embed(xs_pad)
|
||||
xs_pad, masks = self.encoders(xs_pad, masks)
|
||||
if isinstance(xs_pad, tuple):
|
||||
xs_pad = xs_pad[0]
|
||||
|
||||
if self.normalize_before:
|
||||
xs_pad = self.after_norm(xs_pad)
|
||||
olens = masks.squeeze(1).sum(1)
|
||||
return xs_pad, olens, None
|
||||
|
||||
# def forward(self, xs, masks):
|
||||
# """Encode input sequence.
|
||||
|
||||
# :param torch.Tensor xs: input tensor
|
||||
# :param torch.Tensor masks: input mask
|
||||
# :return: position embedded tensor and mask
|
||||
# :rtype Tuple[torch.Tensor, torch.Tensor]:
|
||||
# """
|
||||
# if isinstance(self.embed, (Conv2dSubsampling, VGG2L)):
|
||||
# xs, masks = self.embed(xs, masks)
|
||||
# else:
|
||||
# xs = self.embed(xs)
|
||||
|
||||
# xs, masks = self.encoders(xs, masks)
|
||||
# if isinstance(xs, tuple):
|
||||
# xs = xs[0]
|
||||
|
||||
# if self.normalize_before:
|
||||
# xs = self.after_norm(xs)
|
||||
# return xs, masks
|
74
ppg_extractor/encoder/convolution.py
Normal file
74
ppg_extractor/encoder/convolution.py
Normal file
|
@ -0,0 +1,74 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
|
||||
# Northwestern Polytechnical University (Pengcheng Guo)
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""ConvolutionModule definition."""
|
||||
|
||||
from torch import nn
|
||||
|
||||
|
||||
class ConvolutionModule(nn.Module):
|
||||
"""ConvolutionModule in Conformer model.
|
||||
|
||||
:param int channels: channels of cnn
|
||||
:param int kernel_size: kernerl size of cnn
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
|
||||
"""Construct an ConvolutionModule object."""
|
||||
super(ConvolutionModule, self).__init__()
|
||||
# kernerl_size should be a odd number for 'SAME' padding
|
||||
assert (kernel_size - 1) % 2 == 0
|
||||
|
||||
self.pointwise_conv1 = nn.Conv1d(
|
||||
channels,
|
||||
2 * channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=bias,
|
||||
)
|
||||
self.depthwise_conv = nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
groups=channels,
|
||||
bias=bias,
|
||||
)
|
||||
self.norm = nn.BatchNorm1d(channels)
|
||||
self.pointwise_conv2 = nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=bias,
|
||||
)
|
||||
self.activation = activation
|
||||
|
||||
def forward(self, x):
|
||||
"""Compute convolution module.
|
||||
|
||||
:param torch.Tensor x: (batch, time, size)
|
||||
:return torch.Tensor: convoluted `value` (batch, time, d_model)
|
||||
"""
|
||||
# exchange the temporal dimension and the feature dimension
|
||||
x = x.transpose(1, 2)
|
||||
|
||||
# GLU mechanism
|
||||
x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
|
||||
x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
|
||||
|
||||
# 1D Depthwise Conv
|
||||
x = self.depthwise_conv(x)
|
||||
x = self.activation(self.norm(x))
|
||||
|
||||
x = self.pointwise_conv2(x)
|
||||
|
||||
return x.transpose(1, 2)
|
166
ppg_extractor/encoder/embedding.py
Normal file
166
ppg_extractor/encoder/embedding.py
Normal file
|
@ -0,0 +1,166 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2019 Shigeki Karita
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Positonal Encoding Module."""
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def _pre_hook(
|
||||
state_dict,
|
||||
prefix,
|
||||
local_metadata,
|
||||
strict,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
error_msgs,
|
||||
):
|
||||
"""Perform pre-hook in load_state_dict for backward compatibility.
|
||||
|
||||
Note:
|
||||
We saved self.pe until v.0.5.2 but we have omitted it later.
|
||||
Therefore, we remove the item "pe" from `state_dict` for backward compatibility.
|
||||
|
||||
"""
|
||||
k = prefix + "pe"
|
||||
if k in state_dict:
|
||||
state_dict.pop(k)
|
||||
|
||||
|
||||
class PositionalEncoding(torch.nn.Module):
|
||||
"""Positional encoding.
|
||||
|
||||
:param int d_model: embedding dim
|
||||
:param float dropout_rate: dropout rate
|
||||
:param int max_len: maximum input length
|
||||
:param reverse: whether to reverse the input position
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
|
||||
"""Construct an PositionalEncoding object."""
|
||||
super(PositionalEncoding, self).__init__()
|
||||
self.d_model = d_model
|
||||
self.reverse = reverse
|
||||
self.xscale = math.sqrt(self.d_model)
|
||||
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
||||
self.pe = None
|
||||
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
||||
self._register_load_state_dict_pre_hook(_pre_hook)
|
||||
|
||||
def extend_pe(self, x):
|
||||
"""Reset the positional encodings."""
|
||||
if self.pe is not None:
|
||||
if self.pe.size(1) >= x.size(1):
|
||||
if self.pe.dtype != x.dtype or self.pe.device != x.device:
|
||||
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
||||
return
|
||||
pe = torch.zeros(x.size(1), self.d_model)
|
||||
if self.reverse:
|
||||
position = torch.arange(
|
||||
x.size(1) - 1, -1, -1.0, dtype=torch.float32
|
||||
).unsqueeze(1)
|
||||
else:
|
||||
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
||||
div_term = torch.exp(
|
||||
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
||||
* -(math.log(10000.0) / self.d_model)
|
||||
)
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
pe = pe.unsqueeze(0)
|
||||
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
"""Add positional encoding.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input. Its shape is (batch, time, ...)
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
|
||||
|
||||
"""
|
||||
self.extend_pe(x)
|
||||
x = x * self.xscale + self.pe[:, : x.size(1)]
|
||||
return self.dropout(x)
|
||||
|
||||
|
||||
class ScaledPositionalEncoding(PositionalEncoding):
|
||||
"""Scaled positional encoding module.
|
||||
|
||||
See also: Sec. 3.2 https://arxiv.org/pdf/1809.08895.pdf
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, d_model, dropout_rate, max_len=5000):
|
||||
"""Initialize class.
|
||||
|
||||
:param int d_model: embedding dim
|
||||
:param float dropout_rate: dropout rate
|
||||
:param int max_len: maximum input length
|
||||
|
||||
"""
|
||||
super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
|
||||
self.alpha = torch.nn.Parameter(torch.tensor(1.0))
|
||||
|
||||
def reset_parameters(self):
|
||||
"""Reset parameters."""
|
||||
self.alpha.data = torch.tensor(1.0)
|
||||
|
||||
def forward(self, x):
|
||||
"""Add positional encoding.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input. Its shape is (batch, time, ...)
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
|
||||
|
||||
"""
|
||||
self.extend_pe(x)
|
||||
x = x + self.alpha * self.pe[:, : x.size(1)]
|
||||
return self.dropout(x)
|
||||
|
||||
|
||||
class RelPositionalEncoding(PositionalEncoding):
|
||||
"""Relitive positional encoding module.
|
||||
|
||||
See : Appendix B in https://arxiv.org/abs/1901.02860
|
||||
|
||||
:param int d_model: embedding dim
|
||||
:param float dropout_rate: dropout rate
|
||||
:param int max_len: maximum input length
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, d_model, dropout_rate, max_len=5000):
|
||||
"""Initialize class.
|
||||
|
||||
:param int d_model: embedding dim
|
||||
:param float dropout_rate: dropout rate
|
||||
:param int max_len: maximum input length
|
||||
|
||||
"""
|
||||
super().__init__(d_model, dropout_rate, max_len, reverse=True)
|
||||
|
||||
def forward(self, x):
|
||||
"""Compute positional encoding.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input. Its shape is (batch, time, ...)
|
||||
|
||||
Returns:
|
||||
torch.Tensor: x. Its shape is (batch, time, ...)
|
||||
torch.Tensor: pos_emb. Its shape is (1, time, ...)
|
||||
|
||||
"""
|
||||
self.extend_pe(x)
|
||||
x = x * self.xscale
|
||||
pos_emb = self.pe[:, : x.size(1)]
|
||||
return self.dropout(x), self.dropout(pos_emb)
|
217
ppg_extractor/encoder/encoder.py
Normal file
217
ppg_extractor/encoder/encoder.py
Normal file
|
@ -0,0 +1,217 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2019 Shigeki Karita
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Encoder definition."""
|
||||
|
||||
import logging
|
||||
import torch
|
||||
|
||||
from espnet.nets.pytorch_backend.conformer.convolution import ConvolutionModule
|
||||
from espnet.nets.pytorch_backend.conformer.encoder_layer import EncoderLayer
|
||||
from espnet.nets.pytorch_backend.nets_utils import get_activation
|
||||
from espnet.nets.pytorch_backend.transducer.vgg import VGG2L
|
||||
from espnet.nets.pytorch_backend.transformer.attention import (
|
||||
MultiHeadedAttention, # noqa: H301
|
||||
RelPositionMultiHeadedAttention, # noqa: H301
|
||||
)
|
||||
from espnet.nets.pytorch_backend.transformer.embedding import (
|
||||
PositionalEncoding, # noqa: H301
|
||||
ScaledPositionalEncoding, # noqa: H301
|
||||
RelPositionalEncoding, # noqa: H301
|
||||
)
|
||||
from espnet.nets.pytorch_backend.transformer.layer_norm import LayerNorm
|
||||
from espnet.nets.pytorch_backend.transformer.multi_layer_conv import Conv1dLinear
|
||||
from espnet.nets.pytorch_backend.transformer.multi_layer_conv import MultiLayeredConv1d
|
||||
from espnet.nets.pytorch_backend.transformer.positionwise_feed_forward import (
|
||||
PositionwiseFeedForward, # noqa: H301
|
||||
)
|
||||
from espnet.nets.pytorch_backend.transformer.repeat import repeat
|
||||
from espnet.nets.pytorch_backend.transformer.subsampling import Conv2dSubsampling
|
||||
|
||||
|
||||
class Encoder(torch.nn.Module):
|
||||
"""Conformer encoder module.
|
||||
|
||||
:param int idim: input dim
|
||||
:param int attention_dim: dimention of attention
|
||||
:param int attention_heads: the number of heads of multi head attention
|
||||
:param int linear_units: the number of units of position-wise feed forward
|
||||
:param int num_blocks: the number of decoder blocks
|
||||
:param float dropout_rate: dropout rate
|
||||
:param float attention_dropout_rate: dropout rate in attention
|
||||
:param float positional_dropout_rate: dropout rate after adding positional encoding
|
||||
:param str or torch.nn.Module input_layer: input layer type
|
||||
:param bool normalize_before: whether to use layer_norm before the first block
|
||||
:param bool concat_after: whether to concat attention layer's input and output
|
||||
if True, additional linear will be applied.
|
||||
i.e. x -> x + linear(concat(x, att(x)))
|
||||
if False, no additional linear will be applied. i.e. x -> x + att(x)
|
||||
:param str positionwise_layer_type: linear of conv1d
|
||||
:param int positionwise_conv_kernel_size: kernel size of positionwise conv1d layer
|
||||
:param str encoder_pos_enc_layer_type: encoder positional encoding layer type
|
||||
:param str encoder_attn_layer_type: encoder attention layer type
|
||||
:param str activation_type: encoder activation function type
|
||||
:param bool macaron_style: whether to use macaron style for positionwise layer
|
||||
:param bool use_cnn_module: whether to use convolution module
|
||||
:param int cnn_module_kernel: kernerl size of convolution module
|
||||
:param int padding_idx: padding_idx for input_layer=embed
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
idim,
|
||||
attention_dim=256,
|
||||
attention_heads=4,
|
||||
linear_units=2048,
|
||||
num_blocks=6,
|
||||
dropout_rate=0.1,
|
||||
positional_dropout_rate=0.1,
|
||||
attention_dropout_rate=0.0,
|
||||
input_layer="conv2d",
|
||||
normalize_before=True,
|
||||
concat_after=False,
|
||||
positionwise_layer_type="linear",
|
||||
positionwise_conv_kernel_size=1,
|
||||
macaron_style=False,
|
||||
pos_enc_layer_type="abs_pos",
|
||||
selfattention_layer_type="selfattn",
|
||||
activation_type="swish",
|
||||
use_cnn_module=False,
|
||||
cnn_module_kernel=31,
|
||||
padding_idx=-1,
|
||||
):
|
||||
"""Construct an Encoder object."""
|
||||
super(Encoder, self).__init__()
|
||||
|
||||
activation = get_activation(activation_type)
|
||||
if pos_enc_layer_type == "abs_pos":
|
||||
pos_enc_class = PositionalEncoding
|
||||
elif pos_enc_layer_type == "scaled_abs_pos":
|
||||
pos_enc_class = ScaledPositionalEncoding
|
||||
elif pos_enc_layer_type == "rel_pos":
|
||||
assert selfattention_layer_type == "rel_selfattn"
|
||||
pos_enc_class = RelPositionalEncoding
|
||||
else:
|
||||
raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
|
||||
|
||||
if input_layer == "linear":
|
||||
self.embed = torch.nn.Sequential(
|
||||
torch.nn.Linear(idim, attention_dim),
|
||||
torch.nn.LayerNorm(attention_dim),
|
||||
torch.nn.Dropout(dropout_rate),
|
||||
pos_enc_class(attention_dim, positional_dropout_rate),
|
||||
)
|
||||
elif input_layer == "conv2d":
|
||||
self.embed = Conv2dSubsampling(
|
||||
idim,
|
||||
attention_dim,
|
||||
dropout_rate,
|
||||
pos_enc_class(attention_dim, positional_dropout_rate),
|
||||
)
|
||||
elif input_layer == "vgg2l":
|
||||
self.embed = VGG2L(idim, attention_dim)
|
||||
elif input_layer == "embed":
|
||||
self.embed = torch.nn.Sequential(
|
||||
torch.nn.Embedding(idim, attention_dim, padding_idx=padding_idx),
|
||||
pos_enc_class(attention_dim, positional_dropout_rate),
|
||||
)
|
||||
elif isinstance(input_layer, torch.nn.Module):
|
||||
self.embed = torch.nn.Sequential(
|
||||
input_layer,
|
||||
pos_enc_class(attention_dim, positional_dropout_rate),
|
||||
)
|
||||
elif input_layer is None:
|
||||
self.embed = torch.nn.Sequential(
|
||||
pos_enc_class(attention_dim, positional_dropout_rate)
|
||||
)
|
||||
else:
|
||||
raise ValueError("unknown input_layer: " + input_layer)
|
||||
self.normalize_before = normalize_before
|
||||
if positionwise_layer_type == "linear":
|
||||
positionwise_layer = PositionwiseFeedForward
|
||||
positionwise_layer_args = (
|
||||
attention_dim,
|
||||
linear_units,
|
||||
dropout_rate,
|
||||
activation,
|
||||
)
|
||||
elif positionwise_layer_type == "conv1d":
|
||||
positionwise_layer = MultiLayeredConv1d
|
||||
positionwise_layer_args = (
|
||||
attention_dim,
|
||||
linear_units,
|
||||
positionwise_conv_kernel_size,
|
||||
dropout_rate,
|
||||
)
|
||||
elif positionwise_layer_type == "conv1d-linear":
|
||||
positionwise_layer = Conv1dLinear
|
||||
positionwise_layer_args = (
|
||||
attention_dim,
|
||||
linear_units,
|
||||
positionwise_conv_kernel_size,
|
||||
dropout_rate,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Support only linear or conv1d.")
|
||||
|
||||
if selfattention_layer_type == "selfattn":
|
||||
logging.info("encoder self-attention layer type = self-attention")
|
||||
encoder_selfattn_layer = MultiHeadedAttention
|
||||
encoder_selfattn_layer_args = (
|
||||
attention_heads,
|
||||
attention_dim,
|
||||
attention_dropout_rate,
|
||||
)
|
||||
elif selfattention_layer_type == "rel_selfattn":
|
||||
assert pos_enc_layer_type == "rel_pos"
|
||||
encoder_selfattn_layer = RelPositionMultiHeadedAttention
|
||||
encoder_selfattn_layer_args = (
|
||||
attention_heads,
|
||||
attention_dim,
|
||||
attention_dropout_rate,
|
||||
)
|
||||
else:
|
||||
raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type)
|
||||
|
||||
convolution_layer = ConvolutionModule
|
||||
convolution_layer_args = (attention_dim, cnn_module_kernel, activation)
|
||||
|
||||
self.encoders = repeat(
|
||||
num_blocks,
|
||||
lambda lnum: EncoderLayer(
|
||||
attention_dim,
|
||||
encoder_selfattn_layer(*encoder_selfattn_layer_args),
|
||||
positionwise_layer(*positionwise_layer_args),
|
||||
positionwise_layer(*positionwise_layer_args) if macaron_style else None,
|
||||
convolution_layer(*convolution_layer_args) if use_cnn_module else None,
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
concat_after,
|
||||
),
|
||||
)
|
||||
if self.normalize_before:
|
||||
self.after_norm = LayerNorm(attention_dim)
|
||||
|
||||
def forward(self, xs, masks):
|
||||
"""Encode input sequence.
|
||||
|
||||
:param torch.Tensor xs: input tensor
|
||||
:param torch.Tensor masks: input mask
|
||||
:return: position embedded tensor and mask
|
||||
:rtype Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
if isinstance(self.embed, (Conv2dSubsampling, VGG2L)):
|
||||
xs, masks = self.embed(xs, masks)
|
||||
else:
|
||||
xs = self.embed(xs)
|
||||
|
||||
xs, masks = self.encoders(xs, masks)
|
||||
if isinstance(xs, tuple):
|
||||
xs = xs[0]
|
||||
|
||||
if self.normalize_before:
|
||||
xs = self.after_norm(xs)
|
||||
return xs, masks
|
152
ppg_extractor/encoder/encoder_layer.py
Normal file
152
ppg_extractor/encoder/encoder_layer.py
Normal file
|
@ -0,0 +1,152 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Encoder self-attention layer definition."""
|
||||
|
||||
import torch
|
||||
|
||||
from torch import nn
|
||||
|
||||
from .layer_norm import LayerNorm
|
||||
|
||||
|
||||
class EncoderLayer(nn.Module):
|
||||
"""Encoder layer module.
|
||||
|
||||
:param int size: input dim
|
||||
:param espnet.nets.pytorch_backend.transformer.attention.
|
||||
MultiHeadedAttention self_attn: self attention module
|
||||
RelPositionMultiHeadedAttention self_attn: self attention module
|
||||
:param espnet.nets.pytorch_backend.transformer.positionwise_feed_forward.
|
||||
PositionwiseFeedForward feed_forward:
|
||||
feed forward module
|
||||
:param espnet.nets.pytorch_backend.transformer.positionwise_feed_forward
|
||||
for macaron style
|
||||
PositionwiseFeedForward feed_forward:
|
||||
feed forward module
|
||||
:param espnet.nets.pytorch_backend.conformer.convolution.
|
||||
ConvolutionModule feed_foreard:
|
||||
feed forward module
|
||||
:param float dropout_rate: dropout rate
|
||||
:param bool normalize_before: whether to use layer_norm before the first block
|
||||
:param bool concat_after: whether to concat attention layer's input and output
|
||||
if True, additional linear will be applied.
|
||||
i.e. x -> x + linear(concat(x, att(x)))
|
||||
if False, no additional linear will be applied. i.e. x -> x + att(x)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size,
|
||||
self_attn,
|
||||
feed_forward,
|
||||
feed_forward_macaron,
|
||||
conv_module,
|
||||
dropout_rate,
|
||||
normalize_before=True,
|
||||
concat_after=False,
|
||||
):
|
||||
"""Construct an EncoderLayer object."""
|
||||
super(EncoderLayer, self).__init__()
|
||||
self.self_attn = self_attn
|
||||
self.feed_forward = feed_forward
|
||||
self.feed_forward_macaron = feed_forward_macaron
|
||||
self.conv_module = conv_module
|
||||
self.norm_ff = LayerNorm(size) # for the FNN module
|
||||
self.norm_mha = LayerNorm(size) # for the MHA module
|
||||
if feed_forward_macaron is not None:
|
||||
self.norm_ff_macaron = LayerNorm(size)
|
||||
self.ff_scale = 0.5
|
||||
else:
|
||||
self.ff_scale = 1.0
|
||||
if self.conv_module is not None:
|
||||
self.norm_conv = LayerNorm(size) # for the CNN module
|
||||
self.norm_final = LayerNorm(size) # for the final output of the block
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
self.size = size
|
||||
self.normalize_before = normalize_before
|
||||
self.concat_after = concat_after
|
||||
if self.concat_after:
|
||||
self.concat_linear = nn.Linear(size + size, size)
|
||||
|
||||
def forward(self, x_input, mask, cache=None):
|
||||
"""Compute encoded features.
|
||||
|
||||
:param torch.Tensor x_input: encoded source features, w/o pos_emb
|
||||
tuple((batch, max_time_in, size), (1, max_time_in, size))
|
||||
or (batch, max_time_in, size)
|
||||
:param torch.Tensor mask: mask for x (batch, max_time_in)
|
||||
:param torch.Tensor cache: cache for x (batch, max_time_in - 1, size)
|
||||
:rtype: Tuple[torch.Tensor, torch.Tensor]
|
||||
"""
|
||||
if isinstance(x_input, tuple):
|
||||
x, pos_emb = x_input[0], x_input[1]
|
||||
else:
|
||||
x, pos_emb = x_input, None
|
||||
|
||||
# whether to use macaron style
|
||||
if self.feed_forward_macaron is not None:
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.norm_ff_macaron(x)
|
||||
x = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(x))
|
||||
if not self.normalize_before:
|
||||
x = self.norm_ff_macaron(x)
|
||||
|
||||
# multi-headed self-attention module
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.norm_mha(x)
|
||||
|
||||
if cache is None:
|
||||
x_q = x
|
||||
else:
|
||||
assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
|
||||
x_q = x[:, -1:, :]
|
||||
residual = residual[:, -1:, :]
|
||||
mask = None if mask is None else mask[:, -1:, :]
|
||||
|
||||
if pos_emb is not None:
|
||||
x_att = self.self_attn(x_q, x, x, pos_emb, mask)
|
||||
else:
|
||||
x_att = self.self_attn(x_q, x, x, mask)
|
||||
|
||||
if self.concat_after:
|
||||
x_concat = torch.cat((x, x_att), dim=-1)
|
||||
x = residual + self.concat_linear(x_concat)
|
||||
else:
|
||||
x = residual + self.dropout(x_att)
|
||||
if not self.normalize_before:
|
||||
x = self.norm_mha(x)
|
||||
|
||||
# convolution module
|
||||
if self.conv_module is not None:
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.norm_conv(x)
|
||||
x = residual + self.dropout(self.conv_module(x))
|
||||
if not self.normalize_before:
|
||||
x = self.norm_conv(x)
|
||||
|
||||
# feed forward module
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.norm_ff(x)
|
||||
x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
|
||||
if not self.normalize_before:
|
||||
x = self.norm_ff(x)
|
||||
|
||||
if self.conv_module is not None:
|
||||
x = self.norm_final(x)
|
||||
|
||||
if cache is not None:
|
||||
x = torch.cat([cache, x], dim=1)
|
||||
|
||||
if pos_emb is not None:
|
||||
return (x, pos_emb), mask
|
||||
|
||||
return x, mask
|
33
ppg_extractor/encoder/layer_norm.py
Normal file
33
ppg_extractor/encoder/layer_norm.py
Normal file
|
@ -0,0 +1,33 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2019 Shigeki Karita
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Layer normalization module."""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class LayerNorm(torch.nn.LayerNorm):
|
||||
"""Layer normalization module.
|
||||
|
||||
:param int nout: output dim size
|
||||
:param int dim: dimension to be normalized
|
||||
"""
|
||||
|
||||
def __init__(self, nout, dim=-1):
|
||||
"""Construct an LayerNorm object."""
|
||||
super(LayerNorm, self).__init__(nout, eps=1e-12)
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, x):
|
||||
"""Apply layer normalization.
|
||||
|
||||
:param torch.Tensor x: input tensor
|
||||
:return: layer normalized tensor
|
||||
:rtype torch.Tensor
|
||||
"""
|
||||
if self.dim == -1:
|
||||
return super(LayerNorm, self).forward(x)
|
||||
return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)
|
105
ppg_extractor/encoder/multi_layer_conv.py
Normal file
105
ppg_extractor/encoder/multi_layer_conv.py
Normal file
|
@ -0,0 +1,105 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2019 Tomoki Hayashi
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Layer modules for FFT block in FastSpeech (Feed-forward Transformer)."""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class MultiLayeredConv1d(torch.nn.Module):
|
||||
"""Multi-layered conv1d for Transformer block.
|
||||
|
||||
This is a module of multi-leyered conv1d designed
|
||||
to replace positionwise feed-forward network
|
||||
in Transforner block, which is introduced in
|
||||
`FastSpeech: Fast, Robust and Controllable Text to Speech`_.
|
||||
|
||||
.. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
|
||||
https://arxiv.org/pdf/1905.09263.pdf
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
|
||||
"""Initialize MultiLayeredConv1d module.
|
||||
|
||||
Args:
|
||||
in_chans (int): Number of input channels.
|
||||
hidden_chans (int): Number of hidden channels.
|
||||
kernel_size (int): Kernel size of conv1d.
|
||||
dropout_rate (float): Dropout rate.
|
||||
|
||||
"""
|
||||
super(MultiLayeredConv1d, self).__init__()
|
||||
self.w_1 = torch.nn.Conv1d(
|
||||
in_chans,
|
||||
hidden_chans,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
)
|
||||
self.w_2 = torch.nn.Conv1d(
|
||||
hidden_chans,
|
||||
in_chans,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
)
|
||||
self.dropout = torch.nn.Dropout(dropout_rate)
|
||||
|
||||
def forward(self, x):
|
||||
"""Calculate forward propagation.
|
||||
|
||||
Args:
|
||||
x (Tensor): Batch of input tensors (B, ..., in_chans).
|
||||
|
||||
Returns:
|
||||
Tensor: Batch of output tensors (B, ..., hidden_chans).
|
||||
|
||||
"""
|
||||
x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
|
||||
return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1)
|
||||
|
||||
|
||||
class Conv1dLinear(torch.nn.Module):
|
||||
"""Conv1D + Linear for Transformer block.
|
||||
|
||||
A variant of MultiLayeredConv1d, which replaces second conv-layer to linear.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
|
||||
"""Initialize Conv1dLinear module.
|
||||
|
||||
Args:
|
||||
in_chans (int): Number of input channels.
|
||||
hidden_chans (int): Number of hidden channels.
|
||||
kernel_size (int): Kernel size of conv1d.
|
||||
dropout_rate (float): Dropout rate.
|
||||
|
||||
"""
|
||||
super(Conv1dLinear, self).__init__()
|
||||
self.w_1 = torch.nn.Conv1d(
|
||||
in_chans,
|
||||
hidden_chans,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
)
|
||||
self.w_2 = torch.nn.Linear(hidden_chans, in_chans)
|
||||
self.dropout = torch.nn.Dropout(dropout_rate)
|
||||
|
||||
def forward(self, x):
|
||||
"""Calculate forward propagation.
|
||||
|
||||
Args:
|
||||
x (Tensor): Batch of input tensors (B, ..., in_chans).
|
||||
|
||||
Returns:
|
||||
Tensor: Batch of output tensors (B, ..., hidden_chans).
|
||||
|
||||
"""
|
||||
x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
|
||||
return self.w_2(self.dropout(x))
|
31
ppg_extractor/encoder/positionwise_feed_forward.py
Normal file
31
ppg_extractor/encoder/positionwise_feed_forward.py
Normal file
|
@ -0,0 +1,31 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2019 Shigeki Karita
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Positionwise feed forward layer definition."""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class PositionwiseFeedForward(torch.nn.Module):
|
||||
"""Positionwise feed forward layer.
|
||||
|
||||
:param int idim: input dimenstion
|
||||
:param int hidden_units: number of hidden units
|
||||
:param float dropout_rate: dropout rate
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, idim, hidden_units, dropout_rate, activation=torch.nn.ReLU()):
|
||||
"""Construct an PositionwiseFeedForward object."""
|
||||
super(PositionwiseFeedForward, self).__init__()
|
||||
self.w_1 = torch.nn.Linear(idim, hidden_units)
|
||||
self.w_2 = torch.nn.Linear(hidden_units, idim)
|
||||
self.dropout = torch.nn.Dropout(dropout_rate)
|
||||
self.activation = activation
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward funciton."""
|
||||
return self.w_2(self.dropout(self.activation(self.w_1(x))))
|
30
ppg_extractor/encoder/repeat.py
Normal file
30
ppg_extractor/encoder/repeat.py
Normal file
|
@ -0,0 +1,30 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2019 Shigeki Karita
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Repeat the same layer definition."""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class MultiSequential(torch.nn.Sequential):
|
||||
"""Multi-input multi-output torch.nn.Sequential."""
|
||||
|
||||
def forward(self, *args):
|
||||
"""Repeat."""
|
||||
for m in self:
|
||||
args = m(*args)
|
||||
return args
|
||||
|
||||
|
||||
def repeat(N, fn):
|
||||
"""Repeat module N times.
|
||||
|
||||
:param int N: repeat time
|
||||
:param function fn: function to generate module
|
||||
:return: repeated modules
|
||||
:rtype: MultiSequential
|
||||
"""
|
||||
return MultiSequential(*[fn(n) for n in range(N)])
|
218
ppg_extractor/encoder/subsampling.py
Normal file
218
ppg_extractor/encoder/subsampling.py
Normal file
|
@ -0,0 +1,218 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2019 Shigeki Karita
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Subsampling layer definition."""
|
||||
import logging
|
||||
import torch
|
||||
|
||||
from espnet.nets.pytorch_backend.transformer.embedding import PositionalEncoding
|
||||
|
||||
|
||||
class Conv2dSubsampling(torch.nn.Module):
|
||||
"""Convolutional 2D subsampling (to 1/4 length or 1/2 length).
|
||||
|
||||
:param int idim: input dim
|
||||
:param int odim: output dim
|
||||
:param flaot dropout_rate: dropout rate
|
||||
:param torch.nn.Module pos_enc: custom position encoding layer
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, idim, odim, dropout_rate, pos_enc=None,
|
||||
subsample_by_2=False,
|
||||
):
|
||||
"""Construct an Conv2dSubsampling object."""
|
||||
super(Conv2dSubsampling, self).__init__()
|
||||
self.subsample_by_2 = subsample_by_2
|
||||
if subsample_by_2:
|
||||
self.conv = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(1, odim, kernel_size=5, stride=1, padding=2),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Conv2d(odim, odim, kernel_size=4, stride=2, padding=1),
|
||||
torch.nn.ReLU(),
|
||||
)
|
||||
self.out = torch.nn.Sequential(
|
||||
torch.nn.Linear(odim * (idim // 2), odim),
|
||||
pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
|
||||
)
|
||||
else:
|
||||
self.conv = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(1, odim, kernel_size=4, stride=2, padding=1),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Conv2d(odim, odim, kernel_size=4, stride=2, padding=1),
|
||||
torch.nn.ReLU(),
|
||||
)
|
||||
self.out = torch.nn.Sequential(
|
||||
torch.nn.Linear(odim * (idim // 4), odim),
|
||||
pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
|
||||
)
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
"""Subsample x.
|
||||
|
||||
:param torch.Tensor x: input tensor
|
||||
:param torch.Tensor x_mask: input mask
|
||||
:return: subsampled x and mask
|
||||
:rtype Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
"""
|
||||
x = x.unsqueeze(1) # (b, c, t, f)
|
||||
x = self.conv(x)
|
||||
b, c, t, f = x.size()
|
||||
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
||||
if x_mask is None:
|
||||
return x, None
|
||||
if self.subsample_by_2:
|
||||
return x, x_mask[:, :, ::2]
|
||||
else:
|
||||
return x, x_mask[:, :, ::2][:, :, ::2]
|
||||
|
||||
def __getitem__(self, key):
|
||||
"""Subsample x.
|
||||
|
||||
When reset_parameters() is called, if use_scaled_pos_enc is used,
|
||||
return the positioning encoding.
|
||||
|
||||
"""
|
||||
if key != -1:
|
||||
raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
|
||||
return self.out[key]
|
||||
|
||||
|
||||
class Conv2dNoSubsampling(torch.nn.Module):
|
||||
"""Convolutional 2D without subsampling.
|
||||
|
||||
:param int idim: input dim
|
||||
:param int odim: output dim
|
||||
:param flaot dropout_rate: dropout rate
|
||||
:param torch.nn.Module pos_enc: custom position encoding layer
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, idim, odim, dropout_rate, pos_enc=None):
|
||||
"""Construct an Conv2dSubsampling object."""
|
||||
super().__init__()
|
||||
logging.info("Encoder does not do down-sample on mel-spectrogram.")
|
||||
self.conv = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(1, odim, kernel_size=5, stride=1, padding=2),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Conv2d(odim, odim, kernel_size=5, stride=1, padding=2),
|
||||
torch.nn.ReLU(),
|
||||
)
|
||||
self.out = torch.nn.Sequential(
|
||||
torch.nn.Linear(odim * idim, odim),
|
||||
pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
|
||||
)
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
"""Subsample x.
|
||||
|
||||
:param torch.Tensor x: input tensor
|
||||
:param torch.Tensor x_mask: input mask
|
||||
:return: subsampled x and mask
|
||||
:rtype Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
"""
|
||||
x = x.unsqueeze(1) # (b, c, t, f)
|
||||
x = self.conv(x)
|
||||
b, c, t, f = x.size()
|
||||
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
||||
if x_mask is None:
|
||||
return x, None
|
||||
return x, x_mask
|
||||
|
||||
def __getitem__(self, key):
|
||||
"""Subsample x.
|
||||
|
||||
When reset_parameters() is called, if use_scaled_pos_enc is used,
|
||||
return the positioning encoding.
|
||||
|
||||
"""
|
||||
if key != -1:
|
||||
raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
|
||||
return self.out[key]
|
||||
|
||||
|
||||
class Conv2dSubsampling6(torch.nn.Module):
|
||||
"""Convolutional 2D subsampling (to 1/6 length).
|
||||
|
||||
:param int idim: input dim
|
||||
:param int odim: output dim
|
||||
:param flaot dropout_rate: dropout rate
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, idim, odim, dropout_rate):
|
||||
"""Construct an Conv2dSubsampling object."""
|
||||
super(Conv2dSubsampling6, self).__init__()
|
||||
self.conv = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(1, odim, 3, 2),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Conv2d(odim, odim, 5, 3),
|
||||
torch.nn.ReLU(),
|
||||
)
|
||||
self.out = torch.nn.Sequential(
|
||||
torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), odim),
|
||||
PositionalEncoding(odim, dropout_rate),
|
||||
)
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
"""Subsample x.
|
||||
|
||||
:param torch.Tensor x: input tensor
|
||||
:param torch.Tensor x_mask: input mask
|
||||
:return: subsampled x and mask
|
||||
:rtype Tuple[torch.Tensor, torch.Tensor]
|
||||
"""
|
||||
x = x.unsqueeze(1) # (b, c, t, f)
|
||||
x = self.conv(x)
|
||||
b, c, t, f = x.size()
|
||||
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
||||
if x_mask is None:
|
||||
return x, None
|
||||
return x, x_mask[:, :, :-2:2][:, :, :-4:3]
|
||||
|
||||
|
||||
class Conv2dSubsampling8(torch.nn.Module):
|
||||
"""Convolutional 2D subsampling (to 1/8 length).
|
||||
|
||||
:param int idim: input dim
|
||||
:param int odim: output dim
|
||||
:param flaot dropout_rate: dropout rate
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, idim, odim, dropout_rate):
|
||||
"""Construct an Conv2dSubsampling object."""
|
||||
super(Conv2dSubsampling8, self).__init__()
|
||||
self.conv = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(1, odim, 3, 2),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Conv2d(odim, odim, 3, 2),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Conv2d(odim, odim, 3, 2),
|
||||
torch.nn.ReLU(),
|
||||
)
|
||||
self.out = torch.nn.Sequential(
|
||||
torch.nn.Linear(odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim),
|
||||
PositionalEncoding(odim, dropout_rate),
|
||||
)
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
"""Subsample x.
|
||||
|
||||
:param torch.Tensor x: input tensor
|
||||
:param torch.Tensor x_mask: input mask
|
||||
:return: subsampled x and mask
|
||||
:rtype Tuple[torch.Tensor, torch.Tensor]
|
||||
"""
|
||||
x = x.unsqueeze(1) # (b, c, t, f)
|
||||
x = self.conv(x)
|
||||
b, c, t, f = x.size()
|
||||
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
||||
if x_mask is None:
|
||||
return x, None
|
||||
return x, x_mask[:, :, :-2:2][:, :, :-2:2][:, :, :-2:2]
|
18
ppg_extractor/encoder/swish.py
Normal file
18
ppg_extractor/encoder/swish.py
Normal file
|
@ -0,0 +1,18 @@
|
|||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
|
||||
# Northwestern Polytechnical University (Pengcheng Guo)
|
||||
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""Swish() activation function for Conformer."""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class Swish(torch.nn.Module):
|
||||
"""Construct an Swish object."""
|
||||
|
||||
def forward(self, x):
|
||||
"""Return Swich activation function."""
|
||||
return x * torch.sigmoid(x)
|
77
ppg_extractor/encoder/vgg.py
Normal file
77
ppg_extractor/encoder/vgg.py
Normal file
|
@ -0,0 +1,77 @@
|
|||
"""VGG2L definition for transformer-transducer."""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class VGG2L(torch.nn.Module):
|
||||
"""VGG2L module for transformer-transducer encoder."""
|
||||
|
||||
def __init__(self, idim, odim):
|
||||
"""Construct a VGG2L object.
|
||||
|
||||
Args:
|
||||
idim (int): dimension of inputs
|
||||
odim (int): dimension of outputs
|
||||
|
||||
"""
|
||||
super(VGG2L, self).__init__()
|
||||
|
||||
self.vgg2l = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(1, 64, 3, stride=1, padding=1),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Conv2d(64, 64, 3, stride=1, padding=1),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.MaxPool2d((3, 2)),
|
||||
torch.nn.Conv2d(64, 128, 3, stride=1, padding=1),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Conv2d(128, 128, 3, stride=1, padding=1),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.MaxPool2d((2, 2)),
|
||||
)
|
||||
|
||||
self.output = torch.nn.Linear(128 * ((idim // 2) // 2), odim)
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
"""VGG2L forward for x.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): input torch (B, T, idim)
|
||||
x_mask (torch.Tensor): (B, 1, T)
|
||||
|
||||
Returns:
|
||||
x (torch.Tensor): input torch (B, sub(T), attention_dim)
|
||||
x_mask (torch.Tensor): (B, 1, sub(T))
|
||||
|
||||
"""
|
||||
x = x.unsqueeze(1)
|
||||
x = self.vgg2l(x)
|
||||
|
||||
b, c, t, f = x.size()
|
||||
|
||||
x = self.output(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
||||
|
||||
if x_mask is None:
|
||||
return x, None
|
||||
else:
|
||||
x_mask = self.create_new_mask(x_mask, x)
|
||||
|
||||
return x, x_mask
|
||||
|
||||
def create_new_mask(self, x_mask, x):
|
||||
"""Create a subsampled version of x_mask.
|
||||
|
||||
Args:
|
||||
x_mask (torch.Tensor): (B, 1, T)
|
||||
x (torch.Tensor): (B, sub(T), attention_dim)
|
||||
|
||||
Returns:
|
||||
x_mask (torch.Tensor): (B, 1, sub(T))
|
||||
|
||||
"""
|
||||
x_t1 = x_mask.size(2) - (x_mask.size(2) % 3)
|
||||
x_mask = x_mask[:, :, :x_t1][:, :, ::3]
|
||||
|
||||
x_t2 = x_mask.size(2) - (x_mask.size(2) % 2)
|
||||
x_mask = x_mask[:, :, :x_t2][:, :, ::2]
|
||||
|
||||
return x_mask
|
298
ppg_extractor/encoders.py
Normal file
298
ppg_extractor/encoders.py
Normal file
|
@ -0,0 +1,298 @@
|
|||
import logging
|
||||
import six
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils.rnn import pack_padded_sequence
|
||||
from torch.nn.utils.rnn import pad_packed_sequence
|
||||
|
||||
from .e2e_asr_common import get_vgg2l_odim
|
||||
from .nets_utils import make_pad_mask, to_device
|
||||
|
||||
|
||||
class RNNP(torch.nn.Module):
|
||||
"""RNN with projection layer module
|
||||
|
||||
:param int idim: dimension of inputs
|
||||
:param int elayers: number of encoder layers
|
||||
:param int cdim: number of rnn units (resulted in cdim * 2 if bidirectional)
|
||||
:param int hdim: number of projection units
|
||||
:param np.ndarray subsample: list of subsampling numbers
|
||||
:param float dropout: dropout rate
|
||||
:param str typ: The RNN type
|
||||
"""
|
||||
|
||||
def __init__(self, idim, elayers, cdim, hdim, subsample, dropout, typ="blstm"):
|
||||
super(RNNP, self).__init__()
|
||||
bidir = typ[0] == "b"
|
||||
for i in six.moves.range(elayers):
|
||||
if i == 0:
|
||||
inputdim = idim
|
||||
else:
|
||||
inputdim = hdim
|
||||
rnn = torch.nn.LSTM(inputdim, cdim, dropout=dropout, num_layers=1, bidirectional=bidir,
|
||||
batch_first=True) if "lstm" in typ \
|
||||
else torch.nn.GRU(inputdim, cdim, dropout=dropout, num_layers=1, bidirectional=bidir, batch_first=True)
|
||||
setattr(self, "%s%d" % ("birnn" if bidir else "rnn", i), rnn)
|
||||
# bottleneck layer to merge
|
||||
if bidir:
|
||||
setattr(self, "bt%d" % i, torch.nn.Linear(2 * cdim, hdim))
|
||||
else:
|
||||
setattr(self, "bt%d" % i, torch.nn.Linear(cdim, hdim))
|
||||
|
||||
self.elayers = elayers
|
||||
self.cdim = cdim
|
||||
self.subsample = subsample
|
||||
self.typ = typ
|
||||
self.bidir = bidir
|
||||
|
||||
def forward(self, xs_pad, ilens, prev_state=None):
|
||||
"""RNNP forward
|
||||
|
||||
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
|
||||
:param torch.Tensor ilens: batch of lengths of input sequences (B)
|
||||
:param torch.Tensor prev_state: batch of previous RNN states
|
||||
:return: batch of hidden state sequences (B, Tmax, hdim)
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
logging.debug(self.__class__.__name__ + ' input lengths: ' + str(ilens))
|
||||
elayer_states = []
|
||||
for layer in six.moves.range(self.elayers):
|
||||
xs_pack = pack_padded_sequence(xs_pad, ilens, batch_first=True, enforce_sorted=False)
|
||||
rnn = getattr(self, ("birnn" if self.bidir else "rnn") + str(layer))
|
||||
rnn.flatten_parameters()
|
||||
if prev_state is not None and rnn.bidirectional:
|
||||
prev_state = reset_backward_rnn_state(prev_state)
|
||||
ys, states = rnn(xs_pack, hx=None if prev_state is None else prev_state[layer])
|
||||
elayer_states.append(states)
|
||||
# ys: utt list of frame x cdim x 2 (2: means bidirectional)
|
||||
ys_pad, ilens = pad_packed_sequence(ys, batch_first=True)
|
||||
sub = self.subsample[layer + 1]
|
||||
if sub > 1:
|
||||
ys_pad = ys_pad[:, ::sub]
|
||||
ilens = [int(i + 1) // sub for i in ilens]
|
||||
# (sum _utt frame_utt) x dim
|
||||
projected = getattr(self, 'bt' + str(layer)
|
||||
)(ys_pad.contiguous().view(-1, ys_pad.size(2)))
|
||||
if layer == self.elayers - 1:
|
||||
xs_pad = projected.view(ys_pad.size(0), ys_pad.size(1), -1)
|
||||
else:
|
||||
xs_pad = torch.tanh(projected.view(ys_pad.size(0), ys_pad.size(1), -1))
|
||||
|
||||
return xs_pad, ilens, elayer_states # x: utt list of frame x dim
|
||||
|
||||
|
||||
class RNN(torch.nn.Module):
|
||||
"""RNN module
|
||||
|
||||
:param int idim: dimension of inputs
|
||||
:param int elayers: number of encoder layers
|
||||
:param int cdim: number of rnn units (resulted in cdim * 2 if bidirectional)
|
||||
:param int hdim: number of final projection units
|
||||
:param float dropout: dropout rate
|
||||
:param str typ: The RNN type
|
||||
"""
|
||||
|
||||
def __init__(self, idim, elayers, cdim, hdim, dropout, typ="blstm"):
|
||||
super(RNN, self).__init__()
|
||||
bidir = typ[0] == "b"
|
||||
self.nbrnn = torch.nn.LSTM(idim, cdim, elayers, batch_first=True,
|
||||
dropout=dropout, bidirectional=bidir) if "lstm" in typ \
|
||||
else torch.nn.GRU(idim, cdim, elayers, batch_first=True, dropout=dropout,
|
||||
bidirectional=bidir)
|
||||
if bidir:
|
||||
self.l_last = torch.nn.Linear(cdim * 2, hdim)
|
||||
else:
|
||||
self.l_last = torch.nn.Linear(cdim, hdim)
|
||||
self.typ = typ
|
||||
|
||||
def forward(self, xs_pad, ilens, prev_state=None):
|
||||
"""RNN forward
|
||||
|
||||
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
|
||||
:param torch.Tensor ilens: batch of lengths of input sequences (B)
|
||||
:param torch.Tensor prev_state: batch of previous RNN states
|
||||
:return: batch of hidden state sequences (B, Tmax, eprojs)
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
logging.debug(self.__class__.__name__ + ' input lengths: ' + str(ilens))
|
||||
xs_pack = pack_padded_sequence(xs_pad, ilens, batch_first=True)
|
||||
self.nbrnn.flatten_parameters()
|
||||
if prev_state is not None and self.nbrnn.bidirectional:
|
||||
# We assume that when previous state is passed, it means that we're streaming the input
|
||||
# and therefore cannot propagate backward BRNN state (otherwise it goes in the wrong direction)
|
||||
prev_state = reset_backward_rnn_state(prev_state)
|
||||
ys, states = self.nbrnn(xs_pack, hx=prev_state)
|
||||
# ys: utt list of frame x cdim x 2 (2: means bidirectional)
|
||||
ys_pad, ilens = pad_packed_sequence(ys, batch_first=True)
|
||||
# (sum _utt frame_utt) x dim
|
||||
projected = torch.tanh(self.l_last(
|
||||
ys_pad.contiguous().view(-1, ys_pad.size(2))))
|
||||
xs_pad = projected.view(ys_pad.size(0), ys_pad.size(1), -1)
|
||||
return xs_pad, ilens, states # x: utt list of frame x dim
|
||||
|
||||
|
||||
def reset_backward_rnn_state(states):
|
||||
"""Sets backward BRNN states to zeroes - useful in processing of sliding windows over the inputs"""
|
||||
if isinstance(states, (list, tuple)):
|
||||
for state in states:
|
||||
state[1::2] = 0.
|
||||
else:
|
||||
states[1::2] = 0.
|
||||
return states
|
||||
|
||||
|
||||
class VGG2L(torch.nn.Module):
|
||||
"""VGG-like module
|
||||
|
||||
:param int in_channel: number of input channels
|
||||
"""
|
||||
|
||||
def __init__(self, in_channel=1, downsample=True):
|
||||
super(VGG2L, self).__init__()
|
||||
# CNN layer (VGG motivated)
|
||||
self.conv1_1 = torch.nn.Conv2d(in_channel, 64, 3, stride=1, padding=1)
|
||||
self.conv1_2 = torch.nn.Conv2d(64, 64, 3, stride=1, padding=1)
|
||||
self.conv2_1 = torch.nn.Conv2d(64, 128, 3, stride=1, padding=1)
|
||||
self.conv2_2 = torch.nn.Conv2d(128, 128, 3, stride=1, padding=1)
|
||||
|
||||
self.in_channel = in_channel
|
||||
self.downsample = downsample
|
||||
if downsample:
|
||||
self.stride = 2
|
||||
else:
|
||||
self.stride = 1
|
||||
|
||||
def forward(self, xs_pad, ilens, **kwargs):
|
||||
"""VGG2L forward
|
||||
|
||||
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
|
||||
:param torch.Tensor ilens: batch of lengths of input sequences (B)
|
||||
:return: batch of padded hidden state sequences (B, Tmax // 4, 128 * D // 4) if downsample
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
logging.debug(self.__class__.__name__ + ' input lengths: ' + str(ilens))
|
||||
|
||||
# x: utt x frame x dim
|
||||
# xs_pad = F.pad_sequence(xs_pad)
|
||||
|
||||
# x: utt x 1 (input channel num) x frame x dim
|
||||
xs_pad = xs_pad.view(xs_pad.size(0), xs_pad.size(1), self.in_channel,
|
||||
xs_pad.size(2) // self.in_channel).transpose(1, 2)
|
||||
|
||||
# NOTE: max_pool1d ?
|
||||
xs_pad = F.relu(self.conv1_1(xs_pad))
|
||||
xs_pad = F.relu(self.conv1_2(xs_pad))
|
||||
if self.downsample:
|
||||
xs_pad = F.max_pool2d(xs_pad, 2, stride=self.stride, ceil_mode=True)
|
||||
|
||||
xs_pad = F.relu(self.conv2_1(xs_pad))
|
||||
xs_pad = F.relu(self.conv2_2(xs_pad))
|
||||
if self.downsample:
|
||||
xs_pad = F.max_pool2d(xs_pad, 2, stride=self.stride, ceil_mode=True)
|
||||
if torch.is_tensor(ilens):
|
||||
ilens = ilens.cpu().numpy()
|
||||
else:
|
||||
ilens = np.array(ilens, dtype=np.float32)
|
||||
if self.downsample:
|
||||
ilens = np.array(np.ceil(ilens / 2), dtype=np.int64)
|
||||
ilens = np.array(
|
||||
np.ceil(np.array(ilens, dtype=np.float32) / 2), dtype=np.int64).tolist()
|
||||
|
||||
# x: utt_list of frame (remove zeropaded frames) x (input channel num x dim)
|
||||
xs_pad = xs_pad.transpose(1, 2)
|
||||
xs_pad = xs_pad.contiguous().view(
|
||||
xs_pad.size(0), xs_pad.size(1), xs_pad.size(2) * xs_pad.size(3))
|
||||
return xs_pad, ilens, None # no state in this layer
|
||||
|
||||
|
||||
class Encoder(torch.nn.Module):
|
||||
"""Encoder module
|
||||
|
||||
:param str etype: type of encoder network
|
||||
:param int idim: number of dimensions of encoder network
|
||||
:param int elayers: number of layers of encoder network
|
||||
:param int eunits: number of lstm units of encoder network
|
||||
:param int eprojs: number of projection units of encoder network
|
||||
:param np.ndarray subsample: list of subsampling numbers
|
||||
:param float dropout: dropout rate
|
||||
:param int in_channel: number of input channels
|
||||
"""
|
||||
|
||||
def __init__(self, etype, idim, elayers, eunits, eprojs, subsample, dropout, in_channel=1):
|
||||
super(Encoder, self).__init__()
|
||||
typ = etype.lstrip("vgg").rstrip("p")
|
||||
if typ not in ['lstm', 'gru', 'blstm', 'bgru']:
|
||||
logging.error("Error: need to specify an appropriate encoder architecture")
|
||||
|
||||
if etype.startswith("vgg"):
|
||||
if etype[-1] == "p":
|
||||
self.enc = torch.nn.ModuleList([VGG2L(in_channel),
|
||||
RNNP(get_vgg2l_odim(idim, in_channel=in_channel), elayers, eunits,
|
||||
eprojs,
|
||||
subsample, dropout, typ=typ)])
|
||||
logging.info('Use CNN-VGG + ' + typ.upper() + 'P for encoder')
|
||||
else:
|
||||
self.enc = torch.nn.ModuleList([VGG2L(in_channel),
|
||||
RNN(get_vgg2l_odim(idim, in_channel=in_channel), elayers, eunits,
|
||||
eprojs,
|
||||
dropout, typ=typ)])
|
||||
logging.info('Use CNN-VGG + ' + typ.upper() + ' for encoder')
|
||||
else:
|
||||
if etype[-1] == "p":
|
||||
self.enc = torch.nn.ModuleList(
|
||||
[RNNP(idim, elayers, eunits, eprojs, subsample, dropout, typ=typ)])
|
||||
logging.info(typ.upper() + ' with every-layer projection for encoder')
|
||||
else:
|
||||
self.enc = torch.nn.ModuleList([RNN(idim, elayers, eunits, eprojs, dropout, typ=typ)])
|
||||
logging.info(typ.upper() + ' without projection for encoder')
|
||||
|
||||
def forward(self, xs_pad, ilens, prev_states=None):
|
||||
"""Encoder forward
|
||||
|
||||
:param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
|
||||
:param torch.Tensor ilens: batch of lengths of input sequences (B)
|
||||
:param torch.Tensor prev_state: batch of previous encoder hidden states (?, ...)
|
||||
:return: batch of hidden state sequences (B, Tmax, eprojs)
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
if prev_states is None:
|
||||
prev_states = [None] * len(self.enc)
|
||||
assert len(prev_states) == len(self.enc)
|
||||
|
||||
current_states = []
|
||||
for module, prev_state in zip(self.enc, prev_states):
|
||||
xs_pad, ilens, states = module(xs_pad, ilens, prev_state=prev_state)
|
||||
current_states.append(states)
|
||||
|
||||
# make mask to remove bias value in padded part
|
||||
mask = to_device(self, make_pad_mask(ilens).unsqueeze(-1))
|
||||
|
||||
return xs_pad.masked_fill(mask, 0.0), ilens, current_states
|
||||
|
||||
|
||||
def encoder_for(args, idim, subsample):
|
||||
"""Instantiates an encoder module given the program arguments
|
||||
|
||||
:param Namespace args: The arguments
|
||||
:param int or List of integer idim: dimension of input, e.g. 83, or
|
||||
List of dimensions of inputs, e.g. [83,83]
|
||||
:param List or List of List subsample: subsample factors, e.g. [1,2,2,1,1], or
|
||||
List of subsample factors of each encoder. e.g. [[1,2,2,1,1], [1,2,2,1,1]]
|
||||
:rtype torch.nn.Module
|
||||
:return: The encoder module
|
||||
"""
|
||||
num_encs = getattr(args, "num_encs", 1) # use getattr to keep compatibility
|
||||
if num_encs == 1:
|
||||
# compatible with single encoder asr mode
|
||||
return Encoder(args.etype, idim, args.elayers, args.eunits, args.eprojs, subsample, args.dropout_rate)
|
||||
elif num_encs >= 1:
|
||||
enc_list = torch.nn.ModuleList()
|
||||
for idx in range(num_encs):
|
||||
enc = Encoder(args.etype[idx], idim[idx], args.elayers[idx], args.eunits[idx], args.eprojs, subsample[idx],
|
||||
args.dropout_rate[idx])
|
||||
enc_list.append(enc)
|
||||
return enc_list
|
||||
else:
|
||||
raise ValueError("Number of encoders needs to be more than one. {}".format(num_encs))
|
115
ppg_extractor/frontend.py
Normal file
115
ppg_extractor/frontend.py
Normal file
|
@ -0,0 +1,115 @@
|
|||
import copy
|
||||
from typing import Tuple
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch_complex.tensor import ComplexTensor
|
||||
|
||||
from .log_mel import LogMel
|
||||
from .stft import Stft
|
||||
|
||||
|
||||
class DefaultFrontend(torch.nn.Module):
|
||||
"""Conventional frontend structure for ASR
|
||||
|
||||
Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fs: 16000,
|
||||
n_fft: int = 1024,
|
||||
win_length: int = 800,
|
||||
hop_length: int = 160,
|
||||
center: bool = True,
|
||||
pad_mode: str = "reflect",
|
||||
normalized: bool = False,
|
||||
onesided: bool = True,
|
||||
n_mels: int = 80,
|
||||
fmin: int = None,
|
||||
fmax: int = None,
|
||||
htk: bool = False,
|
||||
norm=1,
|
||||
frontend_conf=None, #Optional[dict] = get_default_kwargs(Frontend),
|
||||
kaldi_padding_mode=False,
|
||||
downsample_rate: int = 1,
|
||||
):
|
||||
super().__init__()
|
||||
self.downsample_rate = downsample_rate
|
||||
|
||||
# Deepcopy (In general, dict shouldn't be used as default arg)
|
||||
frontend_conf = copy.deepcopy(frontend_conf)
|
||||
|
||||
self.stft = Stft(
|
||||
n_fft=n_fft,
|
||||
win_length=win_length,
|
||||
hop_length=hop_length,
|
||||
center=center,
|
||||
pad_mode=pad_mode,
|
||||
normalized=normalized,
|
||||
onesided=onesided,
|
||||
kaldi_padding_mode=kaldi_padding_mode
|
||||
)
|
||||
if frontend_conf is not None:
|
||||
self.frontend = Frontend(idim=n_fft // 2 + 1, **frontend_conf)
|
||||
else:
|
||||
self.frontend = None
|
||||
|
||||
self.logmel = LogMel(
|
||||
fs=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax, htk=htk, norm=norm,
|
||||
)
|
||||
self.n_mels = n_mels
|
||||
|
||||
def output_size(self) -> int:
|
||||
return self.n_mels
|
||||
|
||||
def forward(
|
||||
self, input: torch.Tensor, input_lengths: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# 1. Domain-conversion: e.g. Stft: time -> time-freq
|
||||
input_stft, feats_lens = self.stft(input, input_lengths)
|
||||
|
||||
assert input_stft.dim() >= 4, input_stft.shape
|
||||
# "2" refers to the real/imag parts of Complex
|
||||
assert input_stft.shape[-1] == 2, input_stft.shape
|
||||
|
||||
# Change torch.Tensor to ComplexTensor
|
||||
# input_stft: (..., F, 2) -> (..., F)
|
||||
input_stft = ComplexTensor(input_stft[..., 0], input_stft[..., 1])
|
||||
|
||||
# 2. [Option] Speech enhancement
|
||||
if self.frontend is not None:
|
||||
assert isinstance(input_stft, ComplexTensor), type(input_stft)
|
||||
# input_stft: (Batch, Length, [Channel], Freq)
|
||||
input_stft, _, mask = self.frontend(input_stft, feats_lens)
|
||||
|
||||
# 3. [Multi channel case]: Select a channel
|
||||
if input_stft.dim() == 4:
|
||||
# h: (B, T, C, F) -> h: (B, T, F)
|
||||
if self.training:
|
||||
# Select 1ch randomly
|
||||
ch = np.random.randint(input_stft.size(2))
|
||||
input_stft = input_stft[:, :, ch, :]
|
||||
else:
|
||||
# Use the first channel
|
||||
input_stft = input_stft[:, :, 0, :]
|
||||
|
||||
# 4. STFT -> Power spectrum
|
||||
# h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)
|
||||
input_power = input_stft.real ** 2 + input_stft.imag ** 2
|
||||
|
||||
# 5. Feature transform e.g. Stft -> Log-Mel-Fbank
|
||||
# input_power: (Batch, [Channel,] Length, Freq)
|
||||
# -> input_feats: (Batch, Length, Dim)
|
||||
input_feats, _ = self.logmel(input_power, feats_lens)
|
||||
|
||||
# NOTE(sx): pad
|
||||
max_len = input_feats.size(1)
|
||||
if self.downsample_rate > 1 and max_len % self.downsample_rate != 0:
|
||||
padding = self.downsample_rate - max_len % self.downsample_rate
|
||||
# print("Logmel: ", input_feats.size())
|
||||
input_feats = torch.nn.functional.pad(input_feats, (0, 0, 0, padding),
|
||||
"constant", 0)
|
||||
# print("Logmel(after padding): ",input_feats.size())
|
||||
feats_lens[torch.argmax(feats_lens)] = max_len + padding
|
||||
|
||||
return input_feats, feats_lens
|
74
ppg_extractor/log_mel.py
Normal file
74
ppg_extractor/log_mel.py
Normal file
|
@ -0,0 +1,74 @@
|
|||
import librosa
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing import Tuple
|
||||
|
||||
from .nets_utils import make_pad_mask
|
||||
|
||||
|
||||
class LogMel(torch.nn.Module):
|
||||
"""Convert STFT to fbank feats
|
||||
|
||||
The arguments is same as librosa.filters.mel
|
||||
|
||||
Args:
|
||||
fs: number > 0 [scalar] sampling rate of the incoming signal
|
||||
n_fft: int > 0 [scalar] number of FFT components
|
||||
n_mels: int > 0 [scalar] number of Mel bands to generate
|
||||
fmin: float >= 0 [scalar] lowest frequency (in Hz)
|
||||
fmax: float >= 0 [scalar] highest frequency (in Hz).
|
||||
If `None`, use `fmax = fs / 2.0`
|
||||
htk: use HTK formula instead of Slaney
|
||||
norm: {None, 1, np.inf} [scalar]
|
||||
if 1, divide the triangular mel weights by the width of the mel band
|
||||
(area normalization). Otherwise, leave all the triangles aiming for
|
||||
a peak value of 1.0
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fs: int = 16000,
|
||||
n_fft: int = 512,
|
||||
n_mels: int = 80,
|
||||
fmin: float = None,
|
||||
fmax: float = None,
|
||||
htk: bool = False,
|
||||
norm=1,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
fmin = 0 if fmin is None else fmin
|
||||
fmax = fs / 2 if fmax is None else fmax
|
||||
_mel_options = dict(
|
||||
sr=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax, htk=htk, norm=norm
|
||||
)
|
||||
self.mel_options = _mel_options
|
||||
|
||||
# Note(kamo): The mel matrix of librosa is different from kaldi.
|
||||
melmat = librosa.filters.mel(**_mel_options)
|
||||
# melmat: (D2, D1) -> (D1, D2)
|
||||
self.register_buffer("melmat", torch.from_numpy(melmat.T).float())
|
||||
inv_mel = np.linalg.pinv(melmat)
|
||||
self.register_buffer("inv_melmat", torch.from_numpy(inv_mel.T).float())
|
||||
|
||||
def extra_repr(self):
|
||||
return ", ".join(f"{k}={v}" for k, v in self.mel_options.items())
|
||||
|
||||
def forward(
|
||||
self, feat: torch.Tensor, ilens: torch.Tensor = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2)
|
||||
mel_feat = torch.matmul(feat, self.melmat)
|
||||
|
||||
logmel_feat = (mel_feat + 1e-20).log()
|
||||
# Zero padding
|
||||
if ilens is not None:
|
||||
logmel_feat = logmel_feat.masked_fill(
|
||||
make_pad_mask(ilens, logmel_feat, 1), 0.0
|
||||
)
|
||||
else:
|
||||
ilens = feat.new_full(
|
||||
[feat.size(0)], fill_value=feat.size(1), dtype=torch.long
|
||||
)
|
||||
return logmel_feat, ilens
|
465
ppg_extractor/nets_utils.py
Normal file
465
ppg_extractor/nets_utils.py
Normal file
|
@ -0,0 +1,465 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""Network related utility tools."""
|
||||
|
||||
import logging
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def to_device(m, x):
|
||||
"""Send tensor into the device of the module.
|
||||
|
||||
Args:
|
||||
m (torch.nn.Module): Torch module.
|
||||
x (Tensor): Torch tensor.
|
||||
|
||||
Returns:
|
||||
Tensor: Torch tensor located in the same place as torch module.
|
||||
|
||||
"""
|
||||
assert isinstance(m, torch.nn.Module)
|
||||
device = next(m.parameters()).device
|
||||
return x.to(device)
|
||||
|
||||
|
||||
def pad_list(xs, pad_value):
|
||||
"""Perform padding for the list of tensors.
|
||||
|
||||
Args:
|
||||
xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
|
||||
pad_value (float): Value for padding.
|
||||
|
||||
Returns:
|
||||
Tensor: Padded tensor (B, Tmax, `*`).
|
||||
|
||||
Examples:
|
||||
>>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
|
||||
>>> x
|
||||
[tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
|
||||
>>> pad_list(x, 0)
|
||||
tensor([[1., 1., 1., 1.],
|
||||
[1., 1., 0., 0.],
|
||||
[1., 0., 0., 0.]])
|
||||
|
||||
"""
|
||||
n_batch = len(xs)
|
||||
max_len = max(x.size(0) for x in xs)
|
||||
pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
|
||||
|
||||
for i in range(n_batch):
|
||||
pad[i, :xs[i].size(0)] = xs[i]
|
||||
|
||||
return pad
|
||||
|
||||
|
||||
def make_pad_mask(lengths, xs=None, length_dim=-1):
|
||||
"""Make mask tensor containing indices of padded part.
|
||||
|
||||
Args:
|
||||
lengths (LongTensor or List): Batch of lengths (B,).
|
||||
xs (Tensor, optional): The reference tensor. If set, masks will be the same shape as this tensor.
|
||||
length_dim (int, optional): Dimension indicator of the above tensor. See the example.
|
||||
|
||||
Returns:
|
||||
Tensor: Mask tensor containing indices of padded part.
|
||||
dtype=torch.uint8 in PyTorch 1.2-
|
||||
dtype=torch.bool in PyTorch 1.2+ (including 1.2)
|
||||
|
||||
Examples:
|
||||
With only lengths.
|
||||
|
||||
>>> lengths = [5, 3, 2]
|
||||
>>> make_non_pad_mask(lengths)
|
||||
masks = [[0, 0, 0, 0 ,0],
|
||||
[0, 0, 0, 1, 1],
|
||||
[0, 0, 1, 1, 1]]
|
||||
|
||||
With the reference tensor.
|
||||
|
||||
>>> xs = torch.zeros((3, 2, 4))
|
||||
>>> make_pad_mask(lengths, xs)
|
||||
tensor([[[0, 0, 0, 0],
|
||||
[0, 0, 0, 0]],
|
||||
[[0, 0, 0, 1],
|
||||
[0, 0, 0, 1]],
|
||||
[[0, 0, 1, 1],
|
||||
[0, 0, 1, 1]]], dtype=torch.uint8)
|
||||
>>> xs = torch.zeros((3, 2, 6))
|
||||
>>> make_pad_mask(lengths, xs)
|
||||
tensor([[[0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 1]],
|
||||
[[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 1, 1, 1]],
|
||||
[[0, 0, 1, 1, 1, 1],
|
||||
[0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
|
||||
|
||||
With the reference tensor and dimension indicator.
|
||||
|
||||
>>> xs = torch.zeros((3, 6, 6))
|
||||
>>> make_pad_mask(lengths, xs, 1)
|
||||
tensor([[[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 1, 1, 1]],
|
||||
[[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1]],
|
||||
[[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
|
||||
>>> make_pad_mask(lengths, xs, 2)
|
||||
tensor([[[0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 1],
|
||||
[0, 0, 0, 0, 0, 1]],
|
||||
[[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 1, 1, 1],
|
||||
[0, 0, 0, 1, 1, 1]],
|
||||
[[0, 0, 1, 1, 1, 1],
|
||||
[0, 0, 1, 1, 1, 1],
|
||||
[0, 0, 1, 1, 1, 1],
|
||||
[0, 0, 1, 1, 1, 1],
|
||||
[0, 0, 1, 1, 1, 1],
|
||||
[0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
|
||||
|
||||
"""
|
||||
if length_dim == 0:
|
||||
raise ValueError('length_dim cannot be 0: {}'.format(length_dim))
|
||||
|
||||
if not isinstance(lengths, list):
|
||||
lengths = lengths.tolist()
|
||||
bs = int(len(lengths))
|
||||
if xs is None:
|
||||
maxlen = int(max(lengths))
|
||||
else:
|
||||
maxlen = xs.size(length_dim)
|
||||
|
||||
seq_range = torch.arange(0, maxlen, dtype=torch.int64)
|
||||
seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
|
||||
seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
|
||||
mask = seq_range_expand >= seq_length_expand
|
||||
|
||||
if xs is not None:
|
||||
assert xs.size(0) == bs, (xs.size(0), bs)
|
||||
|
||||
if length_dim < 0:
|
||||
length_dim = xs.dim() + length_dim
|
||||
# ind = (:, None, ..., None, :, , None, ..., None)
|
||||
ind = tuple(slice(None) if i in (0, length_dim) else None
|
||||
for i in range(xs.dim()))
|
||||
mask = mask[ind].expand_as(xs).to(xs.device)
|
||||
return mask
|
||||
|
||||
|
||||
def make_non_pad_mask(lengths, xs=None, length_dim=-1):
|
||||
"""Make mask tensor containing indices of non-padded part.
|
||||
|
||||
Args:
|
||||
lengths (LongTensor or List): Batch of lengths (B,).
|
||||
xs (Tensor, optional): The reference tensor. If set, masks will be the same shape as this tensor.
|
||||
length_dim (int, optional): Dimension indicator of the above tensor. See the example.
|
||||
|
||||
Returns:
|
||||
ByteTensor: mask tensor containing indices of padded part.
|
||||
dtype=torch.uint8 in PyTorch 1.2-
|
||||
dtype=torch.bool in PyTorch 1.2+ (including 1.2)
|
||||
|
||||
Examples:
|
||||
With only lengths.
|
||||
|
||||
>>> lengths = [5, 3, 2]
|
||||
>>> make_non_pad_mask(lengths)
|
||||
masks = [[1, 1, 1, 1 ,1],
|
||||
[1, 1, 1, 0, 0],
|
||||
[1, 1, 0, 0, 0]]
|
||||
|
||||
With the reference tensor.
|
||||
|
||||
>>> xs = torch.zeros((3, 2, 4))
|
||||
>>> make_non_pad_mask(lengths, xs)
|
||||
tensor([[[1, 1, 1, 1],
|
||||
[1, 1, 1, 1]],
|
||||
[[1, 1, 1, 0],
|
||||
[1, 1, 1, 0]],
|
||||
[[1, 1, 0, 0],
|
||||
[1, 1, 0, 0]]], dtype=torch.uint8)
|
||||
>>> xs = torch.zeros((3, 2, 6))
|
||||
>>> make_non_pad_mask(lengths, xs)
|
||||
tensor([[[1, 1, 1, 1, 1, 0],
|
||||
[1, 1, 1, 1, 1, 0]],
|
||||
[[1, 1, 1, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0]],
|
||||
[[1, 1, 0, 0, 0, 0],
|
||||
[1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
|
||||
|
||||
With the reference tensor and dimension indicator.
|
||||
|
||||
>>> xs = torch.zeros((3, 6, 6))
|
||||
>>> make_non_pad_mask(lengths, xs, 1)
|
||||
tensor([[[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[0, 0, 0, 0, 0, 0]],
|
||||
[[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0]],
|
||||
[[1, 1, 1, 1, 1, 1],
|
||||
[1, 1, 1, 1, 1, 1],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0],
|
||||
[0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
|
||||
>>> make_non_pad_mask(lengths, xs, 2)
|
||||
tensor([[[1, 1, 1, 1, 1, 0],
|
||||
[1, 1, 1, 1, 1, 0],
|
||||
[1, 1, 1, 1, 1, 0],
|
||||
[1, 1, 1, 1, 1, 0],
|
||||
[1, 1, 1, 1, 1, 0],
|
||||
[1, 1, 1, 1, 1, 0]],
|
||||
[[1, 1, 1, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0],
|
||||
[1, 1, 1, 0, 0, 0]],
|
||||
[[1, 1, 0, 0, 0, 0],
|
||||
[1, 1, 0, 0, 0, 0],
|
||||
[1, 1, 0, 0, 0, 0],
|
||||
[1, 1, 0, 0, 0, 0],
|
||||
[1, 1, 0, 0, 0, 0],
|
||||
[1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
|
||||
|
||||
"""
|
||||
return ~make_pad_mask(lengths, xs, length_dim)
|
||||
|
||||
|
||||
def mask_by_length(xs, lengths, fill=0):
|
||||
"""Mask tensor according to length.
|
||||
|
||||
Args:
|
||||
xs (Tensor): Batch of input tensor (B, `*`).
|
||||
lengths (LongTensor or List): Batch of lengths (B,).
|
||||
fill (int or float): Value to fill masked part.
|
||||
|
||||
Returns:
|
||||
Tensor: Batch of masked input tensor (B, `*`).
|
||||
|
||||
Examples:
|
||||
>>> x = torch.arange(5).repeat(3, 1) + 1
|
||||
>>> x
|
||||
tensor([[1, 2, 3, 4, 5],
|
||||
[1, 2, 3, 4, 5],
|
||||
[1, 2, 3, 4, 5]])
|
||||
>>> lengths = [5, 3, 2]
|
||||
>>> mask_by_length(x, lengths)
|
||||
tensor([[1, 2, 3, 4, 5],
|
||||
[1, 2, 3, 0, 0],
|
||||
[1, 2, 0, 0, 0]])
|
||||
|
||||
"""
|
||||
assert xs.size(0) == len(lengths)
|
||||
ret = xs.data.new(*xs.size()).fill_(fill)
|
||||
for i, l in enumerate(lengths):
|
||||
ret[i, :l] = xs[i, :l]
|
||||
return ret
|
||||
|
||||
|
||||
def th_accuracy(pad_outputs, pad_targets, ignore_label):
|
||||
"""Calculate accuracy.
|
||||
|
||||
Args:
|
||||
pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
|
||||
pad_targets (LongTensor): Target label tensors (B, Lmax, D).
|
||||
ignore_label (int): Ignore label id.
|
||||
|
||||
Returns:
|
||||
float: Accuracy value (0.0 - 1.0).
|
||||
|
||||
"""
|
||||
pad_pred = pad_outputs.view(
|
||||
pad_targets.size(0),
|
||||
pad_targets.size(1),
|
||||
pad_outputs.size(1)).argmax(2)
|
||||
mask = pad_targets != ignore_label
|
||||
numerator = torch.sum(pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
|
||||
denominator = torch.sum(mask)
|
||||
return float(numerator) / float(denominator)
|
||||
|
||||
|
||||
def to_torch_tensor(x):
|
||||
"""Change to torch.Tensor or ComplexTensor from numpy.ndarray.
|
||||
|
||||
Args:
|
||||
x: Inputs. It should be one of numpy.ndarray, Tensor, ComplexTensor, and dict.
|
||||
|
||||
Returns:
|
||||
Tensor or ComplexTensor: Type converted inputs.
|
||||
|
||||
Examples:
|
||||
>>> xs = np.ones(3, dtype=np.float32)
|
||||
>>> xs = to_torch_tensor(xs)
|
||||
tensor([1., 1., 1.])
|
||||
>>> xs = torch.ones(3, 4, 5)
|
||||
>>> assert to_torch_tensor(xs) is xs
|
||||
>>> xs = {'real': xs, 'imag': xs}
|
||||
>>> to_torch_tensor(xs)
|
||||
ComplexTensor(
|
||||
Real:
|
||||
tensor([1., 1., 1.])
|
||||
Imag;
|
||||
tensor([1., 1., 1.])
|
||||
)
|
||||
|
||||
"""
|
||||
# If numpy, change to torch tensor
|
||||
if isinstance(x, np.ndarray):
|
||||
if x.dtype.kind == 'c':
|
||||
# Dynamically importing because torch_complex requires python3
|
||||
from torch_complex.tensor import ComplexTensor
|
||||
return ComplexTensor(x)
|
||||
else:
|
||||
return torch.from_numpy(x)
|
||||
|
||||
# If {'real': ..., 'imag': ...}, convert to ComplexTensor
|
||||
elif isinstance(x, dict):
|
||||
# Dynamically importing because torch_complex requires python3
|
||||
from torch_complex.tensor import ComplexTensor
|
||||
|
||||
if 'real' not in x or 'imag' not in x:
|
||||
raise ValueError("has 'real' and 'imag' keys: {}".format(list(x)))
|
||||
# Relative importing because of using python3 syntax
|
||||
return ComplexTensor(x['real'], x['imag'])
|
||||
|
||||
# If torch.Tensor, as it is
|
||||
elif isinstance(x, torch.Tensor):
|
||||
return x
|
||||
|
||||
else:
|
||||
error = ("x must be numpy.ndarray, torch.Tensor or a dict like "
|
||||
"{{'real': torch.Tensor, 'imag': torch.Tensor}}, "
|
||||
"but got {}".format(type(x)))
|
||||
try:
|
||||
from torch_complex.tensor import ComplexTensor
|
||||
except Exception:
|
||||
# If PY2
|
||||
raise ValueError(error)
|
||||
else:
|
||||
# If PY3
|
||||
if isinstance(x, ComplexTensor):
|
||||
return x
|
||||
else:
|
||||
raise ValueError(error)
|
||||
|
||||
|
||||
def get_subsample(train_args, mode, arch):
|
||||
"""Parse the subsampling factors from the training args for the specified `mode` and `arch`.
|
||||
|
||||
Args:
|
||||
train_args: argument Namespace containing options.
|
||||
mode: one of ('asr', 'mt', 'st')
|
||||
arch: one of ('rnn', 'rnn-t', 'rnn_mix', 'rnn_mulenc', 'transformer')
|
||||
|
||||
Returns:
|
||||
np.ndarray / List[np.ndarray]: subsampling factors.
|
||||
"""
|
||||
if arch == 'transformer':
|
||||
return np.array([1])
|
||||
|
||||
elif mode == 'mt' and arch == 'rnn':
|
||||
# +1 means input (+1) and layers outputs (train_args.elayer)
|
||||
subsample = np.ones(train_args.elayers + 1, dtype=np.int)
|
||||
logging.warning('Subsampling is not performed for machine translation.')
|
||||
logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
|
||||
return subsample
|
||||
|
||||
elif (mode == 'asr' and arch in ('rnn', 'rnn-t')) or \
|
||||
(mode == 'mt' and arch == 'rnn') or \
|
||||
(mode == 'st' and arch == 'rnn'):
|
||||
subsample = np.ones(train_args.elayers + 1, dtype=np.int)
|
||||
if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
|
||||
ss = train_args.subsample.split("_")
|
||||
for j in range(min(train_args.elayers + 1, len(ss))):
|
||||
subsample[j] = int(ss[j])
|
||||
else:
|
||||
logging.warning(
|
||||
'Subsampling is not performed for vgg*. It is performed in max pooling layers at CNN.')
|
||||
logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
|
||||
return subsample
|
||||
|
||||
elif mode == 'asr' and arch == 'rnn_mix':
|
||||
subsample = np.ones(train_args.elayers_sd + train_args.elayers + 1, dtype=np.int)
|
||||
if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
|
||||
ss = train_args.subsample.split("_")
|
||||
for j in range(min(train_args.elayers_sd + train_args.elayers + 1, len(ss))):
|
||||
subsample[j] = int(ss[j])
|
||||
else:
|
||||
logging.warning(
|
||||
'Subsampling is not performed for vgg*. It is performed in max pooling layers at CNN.')
|
||||
logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
|
||||
return subsample
|
||||
|
||||
elif mode == 'asr' and arch == 'rnn_mulenc':
|
||||
subsample_list = []
|
||||
for idx in range(train_args.num_encs):
|
||||
subsample = np.ones(train_args.elayers[idx] + 1, dtype=np.int)
|
||||
if train_args.etype[idx].endswith("p") and not train_args.etype[idx].startswith("vgg"):
|
||||
ss = train_args.subsample[idx].split("_")
|
||||
for j in range(min(train_args.elayers[idx] + 1, len(ss))):
|
||||
subsample[j] = int(ss[j])
|
||||
else:
|
||||
logging.warning(
|
||||
'Encoder %d: Subsampling is not performed for vgg*. '
|
||||
'It is performed in max pooling layers at CNN.', idx + 1)
|
||||
logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
|
||||
subsample_list.append(subsample)
|
||||
return subsample_list
|
||||
|
||||
else:
|
||||
raise ValueError('Invalid options: mode={}, arch={}'.format(mode, arch))
|
||||
|
||||
|
||||
def rename_state_dict(old_prefix: str, new_prefix: str, state_dict: Dict[str, torch.Tensor]):
|
||||
"""Replace keys of old prefix with new prefix in state dict."""
|
||||
# need this list not to break the dict iterator
|
||||
old_keys = [k for k in state_dict if k.startswith(old_prefix)]
|
||||
if len(old_keys) > 0:
|
||||
logging.warning(f'Rename: {old_prefix} -> {new_prefix}')
|
||||
for k in old_keys:
|
||||
v = state_dict.pop(k)
|
||||
new_k = k.replace(old_prefix, new_prefix)
|
||||
state_dict[new_k] = v
|
||||
|
||||
def get_activation(act):
|
||||
"""Return activation function."""
|
||||
# Lazy load to avoid unused import
|
||||
from .encoder.swish import Swish
|
||||
|
||||
activation_funcs = {
|
||||
"hardtanh": torch.nn.Hardtanh,
|
||||
"relu": torch.nn.ReLU,
|
||||
"selu": torch.nn.SELU,
|
||||
"swish": Swish,
|
||||
}
|
||||
|
||||
return activation_funcs[act]()
|
118
ppg_extractor/stft.py
Normal file
118
ppg_extractor/stft.py
Normal file
|
@ -0,0 +1,118 @@
|
|||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
|
||||
from .nets_utils import make_pad_mask
|
||||
|
||||
|
||||
class Stft(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
n_fft: int = 512,
|
||||
win_length: Union[int, None] = 512,
|
||||
hop_length: int = 128,
|
||||
center: bool = True,
|
||||
pad_mode: str = "reflect",
|
||||
normalized: bool = False,
|
||||
onesided: bool = True,
|
||||
kaldi_padding_mode=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.n_fft = n_fft
|
||||
if win_length is None:
|
||||
self.win_length = n_fft
|
||||
else:
|
||||
self.win_length = win_length
|
||||
self.hop_length = hop_length
|
||||
self.center = center
|
||||
self.pad_mode = pad_mode
|
||||
self.normalized = normalized
|
||||
self.onesided = onesided
|
||||
self.kaldi_padding_mode = kaldi_padding_mode
|
||||
if self.kaldi_padding_mode:
|
||||
self.win_length = 400
|
||||
|
||||
def extra_repr(self):
|
||||
return (
|
||||
f"n_fft={self.n_fft}, "
|
||||
f"win_length={self.win_length}, "
|
||||
f"hop_length={self.hop_length}, "
|
||||
f"center={self.center}, "
|
||||
f"pad_mode={self.pad_mode}, "
|
||||
f"normalized={self.normalized}, "
|
||||
f"onesided={self.onesided}"
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, input: torch.Tensor, ilens: torch.Tensor = None
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""STFT forward function.
|
||||
|
||||
Args:
|
||||
input: (Batch, Nsamples) or (Batch, Nsample, Channels)
|
||||
ilens: (Batch)
|
||||
Returns:
|
||||
output: (Batch, Frames, Freq, 2) or (Batch, Frames, Channels, Freq, 2)
|
||||
|
||||
"""
|
||||
bs = input.size(0)
|
||||
if input.dim() == 3:
|
||||
multi_channel = True
|
||||
# input: (Batch, Nsample, Channels) -> (Batch * Channels, Nsample)
|
||||
input = input.transpose(1, 2).reshape(-1, input.size(1))
|
||||
else:
|
||||
multi_channel = False
|
||||
|
||||
# output: (Batch, Freq, Frames, 2=real_imag)
|
||||
# or (Batch, Channel, Freq, Frames, 2=real_imag)
|
||||
if not self.kaldi_padding_mode:
|
||||
output = torch.stft(
|
||||
input,
|
||||
n_fft=self.n_fft,
|
||||
win_length=self.win_length,
|
||||
hop_length=self.hop_length,
|
||||
center=self.center,
|
||||
pad_mode=self.pad_mode,
|
||||
normalized=self.normalized,
|
||||
onesided=self.onesided,
|
||||
return_complex=False
|
||||
)
|
||||
else:
|
||||
# NOTE(sx): Use Kaldi-fasion padding, maybe wrong
|
||||
num_pads = self.n_fft - self.win_length
|
||||
input = torch.nn.functional.pad(input, (num_pads, 0))
|
||||
output = torch.stft(
|
||||
input,
|
||||
n_fft=self.n_fft,
|
||||
win_length=self.win_length,
|
||||
hop_length=self.hop_length,
|
||||
center=False,
|
||||
pad_mode=self.pad_mode,
|
||||
normalized=self.normalized,
|
||||
onesided=self.onesided,
|
||||
return_complex=False
|
||||
)
|
||||
|
||||
# output: (Batch, Freq, Frames, 2=real_imag)
|
||||
# -> (Batch, Frames, Freq, 2=real_imag)
|
||||
output = output.transpose(1, 2)
|
||||
if multi_channel:
|
||||
# output: (Batch * Channel, Frames, Freq, 2=real_imag)
|
||||
# -> (Batch, Frame, Channel, Freq, 2=real_imag)
|
||||
output = output.view(bs, -1, output.size(1), output.size(2), 2).transpose(
|
||||
1, 2
|
||||
)
|
||||
|
||||
if ilens is not None:
|
||||
if self.center:
|
||||
pad = self.win_length // 2
|
||||
ilens = ilens + 2 * pad
|
||||
olens = torch.div(ilens - self.win_length, self.hop_length, rounding_mode='floor') + 1
|
||||
# olens = ilens - self.win_length // self.hop_length + 1
|
||||
output.masked_fill_(make_pad_mask(olens, output, 1), 0.0)
|
||||
else:
|
||||
olens = None
|
||||
|
||||
return output, olens
|
82
ppg_extractor/utterance_mvn.py
Normal file
82
ppg_extractor/utterance_mvn.py
Normal file
|
@ -0,0 +1,82 @@
|
|||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
from .nets_utils import make_pad_mask
|
||||
|
||||
|
||||
class UtteranceMVN(torch.nn.Module):
|
||||
def __init__(
|
||||
self, norm_means: bool = True, norm_vars: bool = False, eps: float = 1.0e-20,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm_means = norm_means
|
||||
self.norm_vars = norm_vars
|
||||
self.eps = eps
|
||||
|
||||
def extra_repr(self):
|
||||
return f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, ilens: torch.Tensor = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Forward function
|
||||
|
||||
Args:
|
||||
x: (B, L, ...)
|
||||
ilens: (B,)
|
||||
|
||||
"""
|
||||
return utterance_mvn(
|
||||
x,
|
||||
ilens,
|
||||
norm_means=self.norm_means,
|
||||
norm_vars=self.norm_vars,
|
||||
eps=self.eps,
|
||||
)
|
||||
|
||||
|
||||
def utterance_mvn(
|
||||
x: torch.Tensor,
|
||||
ilens: torch.Tensor = None,
|
||||
norm_means: bool = True,
|
||||
norm_vars: bool = False,
|
||||
eps: float = 1.0e-20,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Apply utterance mean and variance normalization
|
||||
|
||||
Args:
|
||||
x: (B, T, D), assumed zero padded
|
||||
ilens: (B,)
|
||||
norm_means:
|
||||
norm_vars:
|
||||
eps:
|
||||
|
||||
"""
|
||||
if ilens is None:
|
||||
ilens = x.new_full([x.size(0)], x.size(1))
|
||||
ilens_ = ilens.to(x.device, x.dtype).view(-1, *[1 for _ in range(x.dim() - 1)])
|
||||
# Zero padding
|
||||
if x.requires_grad:
|
||||
x = x.masked_fill(make_pad_mask(ilens, x, 1), 0.0)
|
||||
else:
|
||||
x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0)
|
||||
# mean: (B, 1, D)
|
||||
mean = x.sum(dim=1, keepdim=True) / ilens_
|
||||
|
||||
if norm_means:
|
||||
x -= mean
|
||||
|
||||
if norm_vars:
|
||||
var = x.pow(2).sum(dim=1, keepdim=True) / ilens_
|
||||
std = torch.clamp(var.sqrt(), min=eps)
|
||||
x = x / std.sqrt()
|
||||
return x, ilens
|
||||
else:
|
||||
if norm_vars:
|
||||
y = x - mean
|
||||
y.masked_fill_(make_pad_mask(ilens, y, 1), 0.0)
|
||||
var = y.pow(2).sum(dim=1, keepdim=True) / ilens_
|
||||
std = torch.clamp(var.sqrt(), min=eps)
|
||||
x /= std
|
||||
return x, ilens
|
49
pre4ppg.py
Normal file
49
pre4ppg.py
Normal file
|
@ -0,0 +1,49 @@
|
|||
from pathlib import Path
|
||||
import argparse
|
||||
|
||||
from ppg2mel.preprocess import preprocess_dataset
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
|
||||
recognized_datasets = [
|
||||
"aidatatang_200zh",
|
||||
"aidatatang_200zh_s", # sample
|
||||
]
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Preprocesses audio files from datasets, to be used by the "
|
||||
"ppg2mel model for training.",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
parser.add_argument("datasets_root", type=Path, help=\
|
||||
"Path to the directory containing your datasets.")
|
||||
parser.add_argument("-d", "--dataset", type=str, default="aidatatang_200zh", help=\
|
||||
"Name of the dataset to process, allowing values: aidatatang_200zh.")
|
||||
parser.add_argument("-o", "--out_dir", type=Path, default=argparse.SUPPRESS, help=\
|
||||
"Path to the output directory that will contain the mel spectrograms, the audios and the "
|
||||
"embeds. Defaults to <datasets_root>/PPGVC/ppg2mel/")
|
||||
parser.add_argument("-n", "--n_processes", type=int, default=8, help=\
|
||||
"Number of processes in parallel.")
|
||||
# parser.add_argument("-s", "--skip_existing", action="store_true", help=\
|
||||
# "Whether to overwrite existing files with the same name. Useful if the preprocessing was "
|
||||
# "interrupted. ")
|
||||
# parser.add_argument("--hparams", type=str, default="", help=\
|
||||
# "Hyperparameter overrides as a comma-separated list of name-value pairs")
|
||||
# parser.add_argument("--no_trim", action="store_true", help=\
|
||||
# "Preprocess audio without trimming silences (not recommended).")
|
||||
parser.add_argument("-pf", "--ppg_encoder_model_fpath", type=Path, default="ppg_extractor/saved_models/24epoch.pt", help=\
|
||||
"Path your trained ppg encoder model.")
|
||||
parser.add_argument("-sf", "--speaker_encoder_model", type=Path, default="encoder/saved_models/pretrained_bak_5805000.pt", help=\
|
||||
"Path your trained speaker encoder model.")
|
||||
args = parser.parse_args()
|
||||
|
||||
assert args.dataset in recognized_datasets, 'is not supported, file a issue to propose a new one'
|
||||
|
||||
# Create directories
|
||||
assert args.datasets_root.exists()
|
||||
if not hasattr(args, "out_dir"):
|
||||
args.out_dir = args.datasets_root.joinpath("PPGVC", "ppg2mel")
|
||||
args.out_dir.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
preprocess_dataset(**vars(args))
|
|
@ -17,7 +17,10 @@ webrtcvad; platform_system != "Windows"
|
|||
pypinyin
|
||||
flask
|
||||
flask_wtf
|
||||
flask_cors
|
||||
flask_cors==3.0.10
|
||||
gevent==21.8.0
|
||||
flask_restx
|
||||
tensorboard
|
||||
PyYAML==5.4.1
|
||||
torch_complex
|
||||
espnet
|
142
run.py
Normal file
142
run.py
Normal file
|
@ -0,0 +1,142 @@
|
|||
import time
|
||||
import os
|
||||
import argparse
|
||||
import torch
|
||||
import numpy as np
|
||||
import glob
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
from ppg_extractor import load_model
|
||||
import librosa
|
||||
import soundfile as sf
|
||||
from utils.load_yaml import HpsYaml
|
||||
|
||||
from encoder.audio import preprocess_wav
|
||||
from encoder import inference as speacker_encoder
|
||||
from vocoder.hifigan import inference as vocoder
|
||||
from ppg2mel import MelDecoderMOLv2
|
||||
from utils.f0_utils import compute_f0, f02lf0, compute_mean_std, get_converted_lf0uv
|
||||
|
||||
|
||||
def _build_ppg2mel_model(model_config, model_file, device):
|
||||
ppg2mel_model = MelDecoderMOLv2(
|
||||
**model_config["model"]
|
||||
).to(device)
|
||||
ckpt = torch.load(model_file, map_location=device)
|
||||
ppg2mel_model.load_state_dict(ckpt["model"])
|
||||
ppg2mel_model.eval()
|
||||
return ppg2mel_model
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert(args):
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
output_dir = args.output_dir
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
step = os.path.basename(args.ppg2mel_model_file)[:-4].split("_")[-1]
|
||||
|
||||
# Build models
|
||||
print("Load PPG-model, PPG2Mel-model, Vocoder-model...")
|
||||
ppg_model = load_model(
|
||||
Path('./ppg_extractor/saved_models/24epoch.pt'),
|
||||
device,
|
||||
)
|
||||
ppg2mel_model = _build_ppg2mel_model(HpsYaml(args.ppg2mel_model_train_config), args.ppg2mel_model_file, device)
|
||||
# vocoder.load_model('./vocoder/saved_models/pretrained/g_hifigan.pt', "./vocoder/hifigan/config_16k_.json")
|
||||
vocoder.load_model('./vocoder/saved_models/24k/g_02830000.pt')
|
||||
# Data related
|
||||
ref_wav_path = args.ref_wav_path
|
||||
ref_wav = preprocess_wav(ref_wav_path)
|
||||
ref_fid = os.path.basename(ref_wav_path)[:-4]
|
||||
|
||||
# TODO: specify encoder
|
||||
speacker_encoder.load_model(Path("encoder/saved_models/pretrained_bak_5805000.pt"))
|
||||
ref_spk_dvec = speacker_encoder.embed_utterance(ref_wav)
|
||||
ref_spk_dvec = torch.from_numpy(ref_spk_dvec).unsqueeze(0).to(device)
|
||||
ref_lf0_mean, ref_lf0_std = compute_mean_std(f02lf0(compute_f0(ref_wav)))
|
||||
|
||||
source_file_list = sorted(glob.glob(f"{args.wav_dir}/*.wav"))
|
||||
print(f"Number of source utterances: {len(source_file_list)}.")
|
||||
|
||||
total_rtf = 0.0
|
||||
cnt = 0
|
||||
for src_wav_path in tqdm(source_file_list):
|
||||
# Load the audio to a numpy array:
|
||||
src_wav, _ = librosa.load(src_wav_path, sr=16000)
|
||||
src_wav_tensor = torch.from_numpy(src_wav).unsqueeze(0).float().to(device)
|
||||
src_wav_lengths = torch.LongTensor([len(src_wav)]).to(device)
|
||||
ppg = ppg_model(src_wav_tensor, src_wav_lengths)
|
||||
|
||||
lf0_uv = get_converted_lf0uv(src_wav, ref_lf0_mean, ref_lf0_std, convert=True)
|
||||
min_len = min(ppg.shape[1], len(lf0_uv))
|
||||
|
||||
ppg = ppg[:, :min_len]
|
||||
lf0_uv = lf0_uv[:min_len]
|
||||
|
||||
start = time.time()
|
||||
_, mel_pred, att_ws = ppg2mel_model.inference(
|
||||
ppg,
|
||||
logf0_uv=torch.from_numpy(lf0_uv).unsqueeze(0).float().to(device),
|
||||
spembs=ref_spk_dvec,
|
||||
)
|
||||
src_fid = os.path.basename(src_wav_path)[:-4]
|
||||
wav_fname = f"{output_dir}/vc_{src_fid}_ref_{ref_fid}_step{step}.wav"
|
||||
mel_len = mel_pred.shape[0]
|
||||
rtf = (time.time() - start) / (0.01 * mel_len)
|
||||
total_rtf += rtf
|
||||
cnt += 1
|
||||
# continue
|
||||
mel_pred= mel_pred.transpose(0, 1)
|
||||
y, output_sample_rate = vocoder.infer_waveform(mel_pred.cpu())
|
||||
sf.write(wav_fname, y.squeeze(), output_sample_rate, "PCM_16")
|
||||
|
||||
print("RTF:")
|
||||
print(total_rtf / cnt)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(description="Conversion from wave input")
|
||||
parser.add_argument(
|
||||
"--wav_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Source wave directory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ref_wav_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Reference wave file path.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ppg2mel_model_train_config", "-c",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Training config file (yaml file)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ppg2mel_model_file", "-m",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="ppg2mel model checkpoint file path"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir", "-o",
|
||||
type=str,
|
||||
default="vc_gens_vctk_oneshot",
|
||||
help="Output folder to save the converted wave."
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
convert(args)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
|
@ -3,16 +3,17 @@ from encoder import inference as encoder
|
|||
from synthesizer.inference import Synthesizer
|
||||
from vocoder.wavernn import inference as rnn_vocoder
|
||||
from vocoder.hifigan import inference as gan_vocoder
|
||||
import ppg_extractor as extractor
|
||||
import ppg2mel as convertor
|
||||
from pathlib import Path
|
||||
from time import perf_counter as timer
|
||||
from toolbox.utterance import Utterance
|
||||
from utils.f0_utils import compute_f0, f02lf0, compute_mean_std, get_converted_lf0uv
|
||||
import numpy as np
|
||||
import traceback
|
||||
import sys
|
||||
import torch
|
||||
import librosa
|
||||
import re
|
||||
from audioread.exceptions import NoBackendError
|
||||
|
||||
# 默认使用wavernn
|
||||
vocoder = rnn_vocoder
|
||||
|
@ -49,14 +50,20 @@ recognized_datasets = [
|
|||
MAX_WAVES = 15
|
||||
|
||||
class Toolbox:
|
||||
def __init__(self, datasets_root, enc_models_dir, syn_models_dir, voc_models_dir, seed, no_mp3_support):
|
||||
def __init__(self, datasets_root, enc_models_dir, syn_models_dir, voc_models_dir, extractor_models_dir, convertor_models_dir, seed, no_mp3_support, vc_mode):
|
||||
self.no_mp3_support = no_mp3_support
|
||||
self.vc_mode = vc_mode
|
||||
sys.excepthook = self.excepthook
|
||||
self.datasets_root = datasets_root
|
||||
self.utterances = set()
|
||||
self.current_generated = (None, None, None, None) # speaker_name, spec, breaks, wav
|
||||
|
||||
self.synthesizer = None # type: Synthesizer
|
||||
|
||||
# for ppg-based voice conversion
|
||||
self.extractor = None
|
||||
self.convertor = None # ppg2mel
|
||||
|
||||
self.current_wav = None
|
||||
self.waves_list = []
|
||||
self.waves_count = 0
|
||||
|
@ -70,9 +77,9 @@ class Toolbox:
|
|||
self.trim_silences = False
|
||||
|
||||
# Initialize the events and the interface
|
||||
self.ui = UI()
|
||||
self.ui = UI(vc_mode)
|
||||
self.style_idx = 0
|
||||
self.reset_ui(enc_models_dir, syn_models_dir, voc_models_dir, seed)
|
||||
self.reset_ui(enc_models_dir, syn_models_dir, voc_models_dir, extractor_models_dir, convertor_models_dir, seed)
|
||||
self.setup_events()
|
||||
self.ui.start()
|
||||
|
||||
|
@ -96,7 +103,11 @@ class Toolbox:
|
|||
self.ui.encoder_box.currentIndexChanged.connect(self.init_encoder)
|
||||
def func():
|
||||
self.synthesizer = None
|
||||
if self.vc_mode:
|
||||
self.ui.extractor_box.currentIndexChanged.connect(self.init_extractor)
|
||||
else:
|
||||
self.ui.synthesizer_box.currentIndexChanged.connect(func)
|
||||
|
||||
self.ui.vocoder_box.currentIndexChanged.connect(self.init_vocoder)
|
||||
|
||||
# Utterance selection
|
||||
|
@ -109,6 +120,11 @@ class Toolbox:
|
|||
self.ui.stop_button.clicked.connect(self.ui.stop)
|
||||
self.ui.record_button.clicked.connect(self.record)
|
||||
|
||||
# Source Utterance selection
|
||||
if self.vc_mode:
|
||||
func = lambda: self.load_soruce_button(self.ui.selected_utterance)
|
||||
self.ui.load_soruce_button.clicked.connect(func)
|
||||
|
||||
#Audio
|
||||
self.ui.setup_audio_devices(Synthesizer.sample_rate)
|
||||
|
||||
|
@ -120,11 +136,16 @@ class Toolbox:
|
|||
self.ui.waves_cb.currentIndexChanged.connect(self.set_current_wav)
|
||||
|
||||
# Generation
|
||||
self.ui.vocode_button.clicked.connect(self.vocode)
|
||||
self.ui.random_seed_checkbox.clicked.connect(self.update_seed_textbox)
|
||||
|
||||
if self.vc_mode:
|
||||
func = lambda: self.convert() or self.vocode()
|
||||
self.ui.convert_button.clicked.connect(func)
|
||||
else:
|
||||
func = lambda: self.synthesize() or self.vocode()
|
||||
self.ui.generate_button.clicked.connect(func)
|
||||
self.ui.synthesize_button.clicked.connect(self.synthesize)
|
||||
self.ui.vocode_button.clicked.connect(self.vocode)
|
||||
self.ui.random_seed_checkbox.clicked.connect(self.update_seed_textbox)
|
||||
|
||||
# UMAP legend
|
||||
self.ui.clear_button.clicked.connect(self.clear_utterances)
|
||||
|
@ -138,9 +159,9 @@ class Toolbox:
|
|||
def replay_last_wav(self):
|
||||
self.ui.play(self.current_wav, Synthesizer.sample_rate)
|
||||
|
||||
def reset_ui(self, encoder_models_dir, synthesizer_models_dir, vocoder_models_dir, seed):
|
||||
def reset_ui(self, encoder_models_dir, synthesizer_models_dir, vocoder_models_dir, extractor_models_dir, convertor_models_dir, seed):
|
||||
self.ui.populate_browser(self.datasets_root, recognized_datasets, 0, True)
|
||||
self.ui.populate_models(encoder_models_dir, synthesizer_models_dir, vocoder_models_dir)
|
||||
self.ui.populate_models(encoder_models_dir, synthesizer_models_dir, vocoder_models_dir, extractor_models_dir, convertor_models_dir, self.vc_mode)
|
||||
self.ui.populate_gen_options(seed, self.trim_silences)
|
||||
|
||||
def load_from_browser(self, fpath=None):
|
||||
|
@ -172,6 +193,9 @@ class Toolbox:
|
|||
|
||||
self.add_real_utterance(wav, name, speaker_name)
|
||||
|
||||
def load_soruce_button(self, utterance: Utterance):
|
||||
self.selected_source_utterance = utterance
|
||||
|
||||
def record(self):
|
||||
wav = self.ui.record_one(encoder.sampling_rate, 5)
|
||||
if wav is None:
|
||||
|
@ -196,7 +220,7 @@ class Toolbox:
|
|||
# Add the utterance
|
||||
utterance = Utterance(name, speaker_name, wav, spec, embed, partial_embeds, False)
|
||||
self.utterances.add(utterance)
|
||||
self.ui.register_utterance(utterance)
|
||||
self.ui.register_utterance(utterance, self.vc_mode)
|
||||
|
||||
# Plot it
|
||||
self.ui.draw_embed(embed, name, "current")
|
||||
|
@ -269,7 +293,7 @@ class Toolbox:
|
|||
self.ui.set_loading(i, seq_len)
|
||||
if self.ui.current_vocoder_fpath is not None:
|
||||
self.ui.log("")
|
||||
wav = vocoder.infer_waveform(spec, progress_callback=vocoder_progress)
|
||||
wav, sample_rate = vocoder.infer_waveform(spec, progress_callback=vocoder_progress)
|
||||
else:
|
||||
self.ui.log("Waveform generation with Griffin-Lim... ")
|
||||
wav = Synthesizer.griffin_lim(spec)
|
||||
|
@ -280,7 +304,7 @@ class Toolbox:
|
|||
b_ends = np.cumsum(np.array(breaks) * Synthesizer.hparams.hop_size)
|
||||
b_starts = np.concatenate(([0], b_ends[:-1]))
|
||||
wavs = [wav[start:end] for start, end, in zip(b_starts, b_ends)]
|
||||
breaks = [np.zeros(int(0.15 * Synthesizer.sample_rate))] * len(breaks)
|
||||
breaks = [np.zeros(int(0.15 * sample_rate))] * len(breaks)
|
||||
wav = np.concatenate([i for w, b in zip(wavs, breaks) for i in (w, b)])
|
||||
|
||||
# Trim excessive silences
|
||||
|
@ -289,7 +313,7 @@ class Toolbox:
|
|||
|
||||
# Play it
|
||||
wav = wav / np.abs(wav).max() * 0.97
|
||||
self.ui.play(wav, Synthesizer.sample_rate)
|
||||
self.ui.play(wav, sample_rate)
|
||||
|
||||
# Name it (history displayed in combobox)
|
||||
# TODO better naming for the combobox items?
|
||||
|
@ -331,6 +355,68 @@ class Toolbox:
|
|||
self.ui.draw_embed(embed, name, "generated")
|
||||
self.ui.draw_umap_projections(self.utterances)
|
||||
|
||||
def convert(self):
|
||||
self.ui.log("Extract PPG and Converting...")
|
||||
self.ui.set_loading(1)
|
||||
|
||||
# Init
|
||||
if self.convertor is None:
|
||||
self.init_convertor()
|
||||
if self.extractor is None:
|
||||
self.init_extractor()
|
||||
|
||||
src_wav = self.selected_source_utterance.wav
|
||||
|
||||
# Compute the ppg
|
||||
if not self.extractor is None:
|
||||
ppg = self.extractor.extract_from_wav(src_wav)
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
ref_wav = self.ui.selected_utterance.wav
|
||||
ref_lf0_mean, ref_lf0_std = compute_mean_std(f02lf0(compute_f0(ref_wav)))
|
||||
lf0_uv = get_converted_lf0uv(src_wav, ref_lf0_mean, ref_lf0_std, convert=True)
|
||||
min_len = min(ppg.shape[1], len(lf0_uv))
|
||||
ppg = ppg[:, :min_len]
|
||||
lf0_uv = lf0_uv[:min_len]
|
||||
_, mel_pred, att_ws = self.convertor.inference(
|
||||
ppg,
|
||||
logf0_uv=torch.from_numpy(lf0_uv).unsqueeze(0).float().to(device),
|
||||
spembs=torch.from_numpy(self.ui.selected_utterance.embed).unsqueeze(0).to(device),
|
||||
)
|
||||
mel_pred= mel_pred.transpose(0, 1)
|
||||
breaks = [mel_pred.shape[1]]
|
||||
mel_pred= mel_pred.detach().cpu().numpy()
|
||||
self.ui.draw_spec(mel_pred, "generated")
|
||||
self.current_generated = (self.ui.selected_utterance.speaker_name, mel_pred, breaks, None)
|
||||
self.ui.set_loading(0)
|
||||
|
||||
def init_extractor(self):
|
||||
if self.ui.current_extractor_fpath is None:
|
||||
return
|
||||
model_fpath = self.ui.current_extractor_fpath
|
||||
self.ui.log("Loading the extractor %s... " % model_fpath)
|
||||
self.ui.set_loading(1)
|
||||
start = timer()
|
||||
self.extractor = extractor.load_model(model_fpath)
|
||||
self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
|
||||
self.ui.set_loading(0)
|
||||
|
||||
def init_convertor(self):
|
||||
if self.ui.current_convertor_fpath is None:
|
||||
return
|
||||
model_fpath = self.ui.current_convertor_fpath
|
||||
# search a config file
|
||||
model_config_fpaths = list(model_fpath.parent.rglob("*.yaml"))
|
||||
if self.ui.current_convertor_fpath is None:
|
||||
return
|
||||
model_config_fpath = model_config_fpaths[0]
|
||||
self.ui.log("Loading the convertor %s... " % model_fpath)
|
||||
self.ui.set_loading(1)
|
||||
start = timer()
|
||||
self.convertor = convertor.load_model(model_config_fpath, model_fpath)
|
||||
self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
|
||||
self.ui.set_loading(0)
|
||||
|
||||
def init_encoder(self):
|
||||
model_fpath = self.ui.current_encoder_fpath
|
||||
|
||||
|
@ -358,12 +444,16 @@ class Toolbox:
|
|||
# Case of Griffin-lim
|
||||
if model_fpath is None:
|
||||
return
|
||||
|
||||
|
||||
# Sekect vocoder based on model name
|
||||
model_config_fpath = None
|
||||
if model_fpath.name[0] == "g":
|
||||
vocoder = gan_vocoder
|
||||
self.ui.log("set hifigan as vocoder")
|
||||
# search a config file
|
||||
model_config_fpaths = list(model_fpath.parent.rglob("*.json"))
|
||||
if self.ui.current_extractor_fpath is None:
|
||||
return
|
||||
model_config_fpath = model_config_fpaths[0]
|
||||
else:
|
||||
vocoder = rnn_vocoder
|
||||
self.ui.log("set wavernn as vocoder")
|
||||
|
@ -371,7 +461,7 @@ class Toolbox:
|
|||
self.ui.log("Loading the vocoder %s... " % model_fpath)
|
||||
self.ui.set_loading(1)
|
||||
start = timer()
|
||||
vocoder.load_model(model_fpath)
|
||||
vocoder.load_model(model_fpath, model_config_fpath)
|
||||
self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
|
||||
self.ui.set_loading(0)
|
||||
|
||||
|
|
|
@ -326,14 +326,35 @@ class UI(QDialog):
|
|||
def current_vocoder_fpath(self):
|
||||
return self.vocoder_box.itemData(self.vocoder_box.currentIndex())
|
||||
|
||||
@property
|
||||
def current_extractor_fpath(self):
|
||||
return self.extractor_box.itemData(self.extractor_box.currentIndex())
|
||||
|
||||
@property
|
||||
def current_convertor_fpath(self):
|
||||
return self.convertor_box.itemData(self.convertor_box.currentIndex())
|
||||
|
||||
def populate_models(self, encoder_models_dir: Path, synthesizer_models_dir: Path,
|
||||
vocoder_models_dir: Path):
|
||||
vocoder_models_dir: Path, extractor_models_dir: Path, convertor_models_dir: Path, vc_mode: bool):
|
||||
# Encoder
|
||||
encoder_fpaths = list(encoder_models_dir.glob("*.pt"))
|
||||
if len(encoder_fpaths) == 0:
|
||||
raise Exception("No encoder models found in %s" % encoder_models_dir)
|
||||
self.repopulate_box(self.encoder_box, [(f.stem, f) for f in encoder_fpaths])
|
||||
|
||||
if vc_mode:
|
||||
# Extractor
|
||||
extractor_fpaths = list(extractor_models_dir.glob("*.pt"))
|
||||
if len(extractor_fpaths) == 0:
|
||||
self.log("No extractor models found in %s" % extractor_fpaths)
|
||||
self.repopulate_box(self.extractor_box, [(f.stem, f) for f in extractor_fpaths])
|
||||
|
||||
# Convertor
|
||||
convertor_fpaths = list(convertor_models_dir.glob("*.pth"))
|
||||
if len(convertor_fpaths) == 0:
|
||||
self.log("No convertor models found in %s" % convertor_fpaths)
|
||||
self.repopulate_box(self.convertor_box, [(f.stem, f) for f in convertor_fpaths])
|
||||
else:
|
||||
# Synthesizer
|
||||
synthesizer_fpaths = list(synthesizer_models_dir.glob("**/*.pt"))
|
||||
if len(synthesizer_fpaths) == 0:
|
||||
|
@ -349,7 +370,7 @@ class UI(QDialog):
|
|||
def selected_utterance(self):
|
||||
return self.utterance_history.itemData(self.utterance_history.currentIndex())
|
||||
|
||||
def register_utterance(self, utterance: Utterance):
|
||||
def register_utterance(self, utterance: Utterance, vc_mode):
|
||||
self.utterance_history.blockSignals(True)
|
||||
self.utterance_history.insertItem(0, utterance.name, utterance)
|
||||
self.utterance_history.setCurrentIndex(0)
|
||||
|
@ -359,6 +380,9 @@ class UI(QDialog):
|
|||
self.utterance_history.removeItem(self.max_saved_utterances)
|
||||
|
||||
self.play_button.setDisabled(False)
|
||||
if vc_mode:
|
||||
self.convert_button.setDisabled(False)
|
||||
else:
|
||||
self.generate_button.setDisabled(False)
|
||||
self.synthesize_button.setDisabled(False)
|
||||
|
||||
|
@ -402,7 +426,7 @@ class UI(QDialog):
|
|||
else:
|
||||
self.seed_textbox.setEnabled(False)
|
||||
|
||||
def reset_interface(self):
|
||||
def reset_interface(self, vc_mode):
|
||||
self.draw_embed(None, None, "current")
|
||||
self.draw_embed(None, None, "generated")
|
||||
self.draw_spec(None, "current")
|
||||
|
@ -410,6 +434,9 @@ class UI(QDialog):
|
|||
self.draw_umap_projections(set())
|
||||
self.set_loading(0)
|
||||
self.play_button.setDisabled(True)
|
||||
if vc_mode:
|
||||
self.convert_button.setDisabled(True)
|
||||
else:
|
||||
self.generate_button.setDisabled(True)
|
||||
self.synthesize_button.setDisabled(True)
|
||||
self.vocode_button.setDisabled(True)
|
||||
|
@ -417,7 +444,7 @@ class UI(QDialog):
|
|||
self.export_wav_button.setDisabled(True)
|
||||
[self.log("") for _ in range(self.max_log_lines)]
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, vc_mode):
|
||||
## Initialize the application
|
||||
self.app = QApplication(sys.argv)
|
||||
super().__init__(None)
|
||||
|
@ -469,7 +496,7 @@ class UI(QDialog):
|
|||
source_groupbox = QGroupBox('Source(源音频)')
|
||||
source_layout = QGridLayout()
|
||||
source_groupbox.setLayout(source_layout)
|
||||
browser_layout.addWidget(source_groupbox, i, 0, 1, 4)
|
||||
browser_layout.addWidget(source_groupbox, i, 0, 1, 5)
|
||||
|
||||
self.dataset_box = QComboBox()
|
||||
source_layout.addWidget(QLabel("Dataset(数据集):"), i, 0)
|
||||
|
@ -510,25 +537,35 @@ class UI(QDialog):
|
|||
browser_layout.addWidget(self.play_button, i, 2)
|
||||
self.stop_button = QPushButton("Stop(暂停)")
|
||||
browser_layout.addWidget(self.stop_button, i, 3)
|
||||
if vc_mode:
|
||||
self.load_soruce_button = QPushButton("Select(选择为被转换的语音输入)")
|
||||
browser_layout.addWidget(self.load_soruce_button, i, 4)
|
||||
|
||||
i += 1
|
||||
model_groupbox = QGroupBox('Models(模型选择)')
|
||||
model_layout = QHBoxLayout()
|
||||
model_groupbox.setLayout(model_layout)
|
||||
browser_layout.addWidget(model_groupbox, i, 0, 1, 4)
|
||||
browser_layout.addWidget(model_groupbox, i, 0, 2, 5)
|
||||
|
||||
# Model and audio output selection
|
||||
self.encoder_box = QComboBox()
|
||||
model_layout.addWidget(QLabel("Encoder:"))
|
||||
model_layout.addWidget(self.encoder_box)
|
||||
self.synthesizer_box = QComboBox()
|
||||
if vc_mode:
|
||||
self.extractor_box = QComboBox()
|
||||
model_layout.addWidget(QLabel("Extractor:"))
|
||||
model_layout.addWidget(self.extractor_box)
|
||||
self.convertor_box = QComboBox()
|
||||
model_layout.addWidget(QLabel("Convertor:"))
|
||||
model_layout.addWidget(self.convertor_box)
|
||||
else:
|
||||
model_layout.addWidget(QLabel("Synthesizer:"))
|
||||
model_layout.addWidget(self.synthesizer_box)
|
||||
self.vocoder_box = QComboBox()
|
||||
model_layout.addWidget(QLabel("Vocoder:"))
|
||||
model_layout.addWidget(self.vocoder_box)
|
||||
|
||||
|
||||
#Replay & Save Audio
|
||||
i = 0
|
||||
output_layout.addWidget(QLabel("<b>Toolbox Output:</b>"), i, 0)
|
||||
|
@ -550,7 +587,7 @@ class UI(QDialog):
|
|||
|
||||
## Embed & spectrograms
|
||||
vis_layout.addStretch()
|
||||
|
||||
# TODO: add spectrograms for source
|
||||
gridspec_kw = {"width_ratios": [1, 4]}
|
||||
fig, self.current_ax = plt.subplots(1, 2, figsize=(10, 2.25), facecolor="#F0F0F0",
|
||||
gridspec_kw=gridspec_kw)
|
||||
|
@ -571,16 +608,23 @@ class UI(QDialog):
|
|||
self.text_prompt = QPlainTextEdit(default_text)
|
||||
gen_layout.addWidget(self.text_prompt, stretch=1)
|
||||
|
||||
if vc_mode:
|
||||
layout = QHBoxLayout()
|
||||
self.convert_button = QPushButton("Extract and Convert")
|
||||
layout.addWidget(self.convert_button)
|
||||
gen_layout.addLayout(layout)
|
||||
else:
|
||||
self.generate_button = QPushButton("Synthesize and vocode")
|
||||
gen_layout.addWidget(self.generate_button)
|
||||
|
||||
layout = QHBoxLayout()
|
||||
self.synthesize_button = QPushButton("Synthesize only")
|
||||
layout.addWidget(self.synthesize_button)
|
||||
|
||||
self.vocode_button = QPushButton("Vocode only")
|
||||
layout.addWidget(self.vocode_button)
|
||||
gen_layout.addLayout(layout)
|
||||
|
||||
|
||||
layout_seed = QGridLayout()
|
||||
self.random_seed_checkbox = QCheckBox("Random seed:")
|
||||
self.random_seed_checkbox.setToolTip("When checked, makes the synthesizer and vocoder deterministic.")
|
||||
|
@ -648,7 +692,7 @@ class UI(QDialog):
|
|||
self.resize(max_size)
|
||||
|
||||
## Finalize the display
|
||||
self.reset_interface()
|
||||
self.reset_interface(vc_mode)
|
||||
self.show()
|
||||
|
||||
def start(self):
|
||||
|
|
67
train.py
Normal file
67
train.py
Normal file
|
@ -0,0 +1,67 @@
|
|||
import sys
|
||||
import torch
|
||||
import argparse
|
||||
import numpy as np
|
||||
from utils.load_yaml import HpsYaml
|
||||
from ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver
|
||||
|
||||
# For reproducibility, comment these may speed up training
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
def main():
|
||||
# Arguments
|
||||
parser = argparse.ArgumentParser(description=
|
||||
'Training PPG2Mel VC model.')
|
||||
parser.add_argument('--config', type=str,
|
||||
help='Path to experiment config, e.g., config/vc.yaml')
|
||||
parser.add_argument('--name', default=None, type=str, help='Name for logging.')
|
||||
parser.add_argument('--logdir', default='log/', type=str,
|
||||
help='Logging path.', required=False)
|
||||
parser.add_argument('--ckpdir', default='ppg2mel/saved_models/', type=str,
|
||||
help='Checkpoint path.', required=False)
|
||||
parser.add_argument('--outdir', default='result/', type=str,
|
||||
help='Decode output path.', required=False)
|
||||
parser.add_argument('--load', default=None, type=str,
|
||||
help='Load pre-trained model (for training only)', required=False)
|
||||
parser.add_argument('--warm_start', action='store_true',
|
||||
help='Load model weights only, ignore specified layers.')
|
||||
parser.add_argument('--seed', default=0, type=int,
|
||||
help='Random seed for reproducable results.', required=False)
|
||||
parser.add_argument('--njobs', default=8, type=int,
|
||||
help='Number of threads for dataloader/decoding.', required=False)
|
||||
parser.add_argument('--cpu', action='store_true', help='Disable GPU training.')
|
||||
parser.add_argument('--no-pin', action='store_true',
|
||||
help='Disable pin-memory for dataloader')
|
||||
parser.add_argument('--test', action='store_true', help='Test the model.')
|
||||
parser.add_argument('--no-msg', action='store_true', help='Hide all messages.')
|
||||
parser.add_argument('--finetune', action='store_true', help='Finetune model')
|
||||
parser.add_argument('--oneshotvc', action='store_true', help='Oneshot VC model')
|
||||
parser.add_argument('--bilstm', action='store_true', help='BiLSTM VC model')
|
||||
parser.add_argument('--lsa', action='store_true', help='Use location-sensitive attention (LSA)')
|
||||
|
||||
###
|
||||
|
||||
paras = parser.parse_args()
|
||||
setattr(paras, 'gpu', not paras.cpu)
|
||||
setattr(paras, 'pin_memory', not paras.no_pin)
|
||||
setattr(paras, 'verbose', not paras.no_msg)
|
||||
# Make the config dict dot visitable
|
||||
config = HpsYaml(paras.config)
|
||||
|
||||
np.random.seed(paras.seed)
|
||||
torch.manual_seed(paras.seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(paras.seed)
|
||||
|
||||
print(">>> OneShot VC training ...")
|
||||
mode = "train"
|
||||
solver = Solver(config, paras, mode)
|
||||
solver.load_data()
|
||||
solver.set_model()
|
||||
solver.exec()
|
||||
print(">>> Oneshot VC train finished!")
|
||||
sys.exit(0)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
60
utils/audio_utils.py
Normal file
60
utils/audio_utils.py
Normal file
|
@ -0,0 +1,60 @@
|
|||
|
||||
import torch
|
||||
import torch.utils.data
|
||||
from scipy.io.wavfile import read
|
||||
from librosa.filters import mel as librosa_mel_fn
|
||||
|
||||
MAX_WAV_VALUE = 32768.0
|
||||
|
||||
|
||||
def load_wav(full_path):
|
||||
sampling_rate, data = read(full_path)
|
||||
return data, sampling_rate
|
||||
|
||||
def _dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
|
||||
return torch.log(torch.clamp(x, min=clip_val) * C)
|
||||
|
||||
|
||||
def _spectral_normalize_torch(magnitudes):
|
||||
output = _dynamic_range_compression_torch(magnitudes)
|
||||
return output
|
||||
|
||||
mel_basis = {}
|
||||
hann_window = {}
|
||||
|
||||
def mel_spectrogram(
|
||||
y,
|
||||
n_fft,
|
||||
num_mels,
|
||||
sampling_rate,
|
||||
hop_size,
|
||||
win_size,
|
||||
fmin,
|
||||
fmax,
|
||||
center=False,
|
||||
output_energy=False,
|
||||
):
|
||||
if torch.min(y) < -1.:
|
||||
print('min value is ', torch.min(y))
|
||||
if torch.max(y) > 1.:
|
||||
print('max value is ', torch.max(y))
|
||||
|
||||
global mel_basis, hann_window
|
||||
if fmax not in mel_basis:
|
||||
mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
|
||||
mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device)
|
||||
hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
|
||||
|
||||
y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
|
||||
y = y.squeeze(1)
|
||||
|
||||
spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
|
||||
center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False)
|
||||
spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
|
||||
mel_spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec)
|
||||
mel_spec = _spectral_normalize_torch(mel_spec)
|
||||
if output_energy:
|
||||
energy = torch.norm(spec, dim=1)
|
||||
return mel_spec, energy
|
||||
else:
|
||||
return mel_spec
|
214
utils/data_load.py
Normal file
214
utils/data_load.py
Normal file
|
@ -0,0 +1,214 @@
|
|||
import random
|
||||
import numpy as np
|
||||
import torch
|
||||
from utils.f0_utils import get_cont_lf0
|
||||
import resampy
|
||||
from .audio_utils import MAX_WAV_VALUE, load_wav, mel_spectrogram
|
||||
from librosa.util import normalize
|
||||
import os
|
||||
|
||||
|
||||
SAMPLE_RATE=16000
|
||||
|
||||
def read_fids(fid_list_f):
|
||||
with open(fid_list_f, 'r') as f:
|
||||
fids = [l.strip().split()[0] for l in f if l.strip()]
|
||||
return fids
|
||||
|
||||
class OneshotVcDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
meta_file: str,
|
||||
vctk_ppg_dir: str,
|
||||
libri_ppg_dir: str,
|
||||
vctk_f0_dir: str,
|
||||
libri_f0_dir: str,
|
||||
vctk_wav_dir: str,
|
||||
libri_wav_dir: str,
|
||||
vctk_spk_dvec_dir: str,
|
||||
libri_spk_dvec_dir: str,
|
||||
min_max_norm_mel: bool = False,
|
||||
mel_min: float = None,
|
||||
mel_max: float = None,
|
||||
ppg_file_ext: str = "ling_feat.npy",
|
||||
f0_file_ext: str = "f0.npy",
|
||||
wav_file_ext: str = "wav",
|
||||
):
|
||||
self.fid_list = read_fids(meta_file)
|
||||
self.vctk_ppg_dir = vctk_ppg_dir
|
||||
self.libri_ppg_dir = libri_ppg_dir
|
||||
self.vctk_f0_dir = vctk_f0_dir
|
||||
self.libri_f0_dir = libri_f0_dir
|
||||
self.vctk_wav_dir = vctk_wav_dir
|
||||
self.libri_wav_dir = libri_wav_dir
|
||||
self.vctk_spk_dvec_dir = vctk_spk_dvec_dir
|
||||
self.libri_spk_dvec_dir = libri_spk_dvec_dir
|
||||
|
||||
self.ppg_file_ext = ppg_file_ext
|
||||
self.f0_file_ext = f0_file_ext
|
||||
self.wav_file_ext = wav_file_ext
|
||||
|
||||
self.min_max_norm_mel = min_max_norm_mel
|
||||
if min_max_norm_mel:
|
||||
print("[INFO] Min-Max normalize Melspec.")
|
||||
assert mel_min is not None
|
||||
assert mel_max is not None
|
||||
self.mel_max = mel_max
|
||||
self.mel_min = mel_min
|
||||
|
||||
random.seed(1234)
|
||||
random.shuffle(self.fid_list)
|
||||
print(f'[INFO] Got {len(self.fid_list)} samples.')
|
||||
|
||||
def __len__(self):
|
||||
return len(self.fid_list)
|
||||
|
||||
def get_spk_dvec(self, fid):
|
||||
spk_name = fid
|
||||
if spk_name.startswith("p"):
|
||||
spk_dvec_path = f"{self.vctk_spk_dvec_dir}{os.sep}{spk_name}.npy"
|
||||
else:
|
||||
spk_dvec_path = f"{self.libri_spk_dvec_dir}{os.sep}{spk_name}.npy"
|
||||
return torch.from_numpy(np.load(spk_dvec_path))
|
||||
|
||||
def compute_mel(self, wav_path):
|
||||
audio, sr = load_wav(wav_path)
|
||||
if sr != SAMPLE_RATE:
|
||||
audio = resampy.resample(audio, sr, SAMPLE_RATE)
|
||||
audio = audio / MAX_WAV_VALUE
|
||||
audio = normalize(audio) * 0.95
|
||||
audio = torch.FloatTensor(audio).unsqueeze(0)
|
||||
melspec = mel_spectrogram(
|
||||
audio,
|
||||
n_fft=1024,
|
||||
num_mels=80,
|
||||
sampling_rate=SAMPLE_RATE,
|
||||
hop_size=160,
|
||||
win_size=1024,
|
||||
fmin=80,
|
||||
fmax=8000,
|
||||
)
|
||||
return melspec.squeeze(0).numpy().T
|
||||
|
||||
def bin_level_min_max_norm(self, melspec):
|
||||
# frequency bin level min-max normalization to [-4, 4]
|
||||
mel = (melspec - self.mel_min) / (self.mel_max - self.mel_min) * 8.0 - 4.0
|
||||
return np.clip(mel, -4., 4.)
|
||||
|
||||
def __getitem__(self, index):
|
||||
fid = self.fid_list[index]
|
||||
|
||||
# 1. Load features
|
||||
if fid.startswith("p"):
|
||||
# vctk
|
||||
sub = fid.split("_")[0]
|
||||
ppg = np.load(f"{self.vctk_ppg_dir}{os.sep}{fid}.{self.ppg_file_ext}")
|
||||
f0 = np.load(f"{self.vctk_f0_dir}{os.sep}{fid}.{self.f0_file_ext}")
|
||||
mel = self.compute_mel(f"{self.vctk_wav_dir}{os.sep}{sub}{os.sep}{fid}.{self.wav_file_ext}")
|
||||
else:
|
||||
# aidatatang
|
||||
sub = fid[5:10]
|
||||
ppg = np.load(f"{self.libri_ppg_dir}{os.sep}{fid}.{self.ppg_file_ext}")
|
||||
f0 = np.load(f"{self.libri_f0_dir}{os.sep}{fid}.{self.f0_file_ext}")
|
||||
mel = self.compute_mel(f"{self.libri_wav_dir}{os.sep}{sub}{os.sep}{fid}.{self.wav_file_ext}")
|
||||
if self.min_max_norm_mel:
|
||||
mel = self.bin_level_min_max_norm(mel)
|
||||
|
||||
f0, ppg, mel = self._adjust_lengths(f0, ppg, mel, fid)
|
||||
spk_dvec = self.get_spk_dvec(fid)
|
||||
|
||||
# 2. Convert f0 to continuous log-f0 and u/v flags
|
||||
uv, cont_lf0 = get_cont_lf0(f0, 10.0, False)
|
||||
# cont_lf0 = (cont_lf0 - np.amin(cont_lf0)) / (np.amax(cont_lf0) - np.amin(cont_lf0))
|
||||
# cont_lf0 = self.utt_mvn(cont_lf0)
|
||||
lf0_uv = np.concatenate([cont_lf0[:, np.newaxis], uv[:, np.newaxis]], axis=1)
|
||||
|
||||
# uv, cont_f0 = convert_continuous_f0(f0)
|
||||
# cont_f0 = (cont_f0 - np.amin(cont_f0)) / (np.amax(cont_f0) - np.amin(cont_f0))
|
||||
# lf0_uv = np.concatenate([cont_f0[:, np.newaxis], uv[:, np.newaxis]], axis=1)
|
||||
|
||||
# 3. Convert numpy array to torch.tensor
|
||||
ppg = torch.from_numpy(ppg)
|
||||
lf0_uv = torch.from_numpy(lf0_uv)
|
||||
mel = torch.from_numpy(mel)
|
||||
|
||||
return (ppg, lf0_uv, mel, spk_dvec, fid)
|
||||
|
||||
def check_lengths(self, f0, ppg, mel, fid):
|
||||
LEN_THRESH = 10
|
||||
assert abs(len(ppg) - len(f0)) <= LEN_THRESH, \
|
||||
f"{abs(len(ppg) - len(f0))}: for file {fid}"
|
||||
assert abs(len(mel) - len(f0)) <= LEN_THRESH, \
|
||||
f"{abs(len(mel) - len(f0))}: for file {fid}"
|
||||
|
||||
def _adjust_lengths(self, f0, ppg, mel, fid):
|
||||
self.check_lengths(f0, ppg, mel, fid)
|
||||
min_len = min(
|
||||
len(f0),
|
||||
len(ppg),
|
||||
len(mel),
|
||||
)
|
||||
f0 = f0[:min_len]
|
||||
ppg = ppg[:min_len]
|
||||
mel = mel[:min_len]
|
||||
return f0, ppg, mel
|
||||
|
||||
class MultiSpkVcCollate():
|
||||
"""Zero-pads model inputs and targets based on number of frames per step
|
||||
"""
|
||||
def __init__(self, n_frames_per_step=1, give_uttids=False,
|
||||
f02ppg_length_ratio=1, use_spk_dvec=False):
|
||||
self.n_frames_per_step = n_frames_per_step
|
||||
self.give_uttids = give_uttids
|
||||
self.f02ppg_length_ratio = f02ppg_length_ratio
|
||||
self.use_spk_dvec = use_spk_dvec
|
||||
|
||||
def __call__(self, batch):
|
||||
batch_size = len(batch)
|
||||
# Prepare different features
|
||||
ppgs = [x[0] for x in batch]
|
||||
lf0_uvs = [x[1] for x in batch]
|
||||
mels = [x[2] for x in batch]
|
||||
fids = [x[-1] for x in batch]
|
||||
if len(batch[0]) == 5:
|
||||
spk_ids = [x[3] for x in batch]
|
||||
if self.use_spk_dvec:
|
||||
# use d-vector
|
||||
spk_ids = torch.stack(spk_ids).float()
|
||||
else:
|
||||
# use one-hot ids
|
||||
spk_ids = torch.LongTensor(spk_ids)
|
||||
# Pad features into chunk
|
||||
ppg_lengths = [x.shape[0] for x in ppgs]
|
||||
mel_lengths = [x.shape[0] for x in mels]
|
||||
max_ppg_len = max(ppg_lengths)
|
||||
max_mel_len = max(mel_lengths)
|
||||
if max_mel_len % self.n_frames_per_step != 0:
|
||||
max_mel_len += (self.n_frames_per_step - max_mel_len % self.n_frames_per_step)
|
||||
ppg_dim = ppgs[0].shape[1]
|
||||
mel_dim = mels[0].shape[1]
|
||||
ppgs_padded = torch.FloatTensor(batch_size, max_ppg_len, ppg_dim).zero_()
|
||||
mels_padded = torch.FloatTensor(batch_size, max_mel_len, mel_dim).zero_()
|
||||
lf0_uvs_padded = torch.FloatTensor(batch_size, self.f02ppg_length_ratio * max_ppg_len, 2).zero_()
|
||||
stop_tokens = torch.FloatTensor(batch_size, max_mel_len).zero_()
|
||||
for i in range(batch_size):
|
||||
cur_ppg_len = ppgs[i].shape[0]
|
||||
cur_mel_len = mels[i].shape[0]
|
||||
ppgs_padded[i, :cur_ppg_len, :] = ppgs[i]
|
||||
lf0_uvs_padded[i, :self.f02ppg_length_ratio*cur_ppg_len, :] = lf0_uvs[i]
|
||||
mels_padded[i, :cur_mel_len, :] = mels[i]
|
||||
stop_tokens[i, cur_ppg_len-self.n_frames_per_step:] = 1
|
||||
if len(batch[0]) == 5:
|
||||
ret_tup = (ppgs_padded, lf0_uvs_padded, mels_padded, torch.LongTensor(ppg_lengths), \
|
||||
torch.LongTensor(mel_lengths), spk_ids, stop_tokens)
|
||||
if self.give_uttids:
|
||||
return ret_tup + (fids, )
|
||||
else:
|
||||
return ret_tup
|
||||
else:
|
||||
ret_tup = (ppgs_padded, lf0_uvs_padded, mels_padded, torch.LongTensor(ppg_lengths), \
|
||||
torch.LongTensor(mel_lengths), stop_tokens)
|
||||
if self.give_uttids:
|
||||
return ret_tup + (fids, )
|
||||
else:
|
||||
return ret_tup
|
124
utils/f0_utils.py
Normal file
124
utils/f0_utils.py
Normal file
|
@ -0,0 +1,124 @@
|
|||
import logging
|
||||
import numpy as np
|
||||
import pyworld
|
||||
from scipy.interpolate import interp1d
|
||||
from scipy.signal import firwin, get_window, lfilter
|
||||
|
||||
def compute_mean_std(lf0):
|
||||
nonzero_indices = np.nonzero(lf0)
|
||||
mean = np.mean(lf0[nonzero_indices])
|
||||
std = np.std(lf0[nonzero_indices])
|
||||
return mean, std
|
||||
|
||||
|
||||
def compute_f0(wav, sr=16000, frame_period=10.0):
|
||||
"""Compute f0 from wav using pyworld harvest algorithm."""
|
||||
wav = wav.astype(np.float64)
|
||||
f0, _ = pyworld.harvest(
|
||||
wav, sr, frame_period=frame_period, f0_floor=80.0, f0_ceil=600.0)
|
||||
return f0.astype(np.float32)
|
||||
|
||||
def f02lf0(f0):
|
||||
lf0 = f0.copy()
|
||||
nonzero_indices = np.nonzero(f0)
|
||||
lf0[nonzero_indices] = np.log(f0[nonzero_indices])
|
||||
return lf0
|
||||
|
||||
def get_converted_lf0uv(
|
||||
wav,
|
||||
lf0_mean_trg,
|
||||
lf0_std_trg,
|
||||
convert=True,
|
||||
):
|
||||
f0_src = compute_f0(wav)
|
||||
if not convert:
|
||||
uv, cont_lf0 = get_cont_lf0(f0_src)
|
||||
lf0_uv = np.concatenate([cont_lf0[:, np.newaxis], uv[:, np.newaxis]], axis=1)
|
||||
return lf0_uv
|
||||
|
||||
lf0_src = f02lf0(f0_src)
|
||||
lf0_mean_src, lf0_std_src = compute_mean_std(lf0_src)
|
||||
|
||||
lf0_vc = lf0_src.copy()
|
||||
lf0_vc[lf0_src > 0.0] = (lf0_src[lf0_src > 0.0] - lf0_mean_src) / lf0_std_src * lf0_std_trg + lf0_mean_trg
|
||||
f0_vc = lf0_vc.copy()
|
||||
f0_vc[lf0_src > 0.0] = np.exp(lf0_vc[lf0_src > 0.0])
|
||||
|
||||
uv, cont_lf0_vc = get_cont_lf0(f0_vc)
|
||||
lf0_uv = np.concatenate([cont_lf0_vc[:, np.newaxis], uv[:, np.newaxis]], axis=1)
|
||||
return lf0_uv
|
||||
|
||||
def low_pass_filter(x, fs, cutoff=70, padding=True):
|
||||
"""FUNCTION TO APPLY LOW PASS FILTER
|
||||
|
||||
Args:
|
||||
x (ndarray): Waveform sequence
|
||||
fs (int): Sampling frequency
|
||||
cutoff (float): Cutoff frequency of low pass filter
|
||||
|
||||
Return:
|
||||
(ndarray): Low pass filtered waveform sequence
|
||||
"""
|
||||
|
||||
nyquist = fs // 2
|
||||
norm_cutoff = cutoff / nyquist
|
||||
|
||||
# low cut filter
|
||||
numtaps = 255
|
||||
fil = firwin(numtaps, norm_cutoff)
|
||||
x_pad = np.pad(x, (numtaps, numtaps), 'edge')
|
||||
lpf_x = lfilter(fil, 1, x_pad)
|
||||
lpf_x = lpf_x[numtaps + numtaps // 2: -numtaps // 2]
|
||||
|
||||
return lpf_x
|
||||
|
||||
|
||||
def convert_continuos_f0(f0):
|
||||
"""CONVERT F0 TO CONTINUOUS F0
|
||||
|
||||
Args:
|
||||
f0 (ndarray): original f0 sequence with the shape (T)
|
||||
|
||||
Return:
|
||||
(ndarray): continuous f0 with the shape (T)
|
||||
"""
|
||||
# get uv information as binary
|
||||
uv = np.float32(f0 != 0)
|
||||
|
||||
# get start and end of f0
|
||||
if (f0 == 0).all():
|
||||
logging.warn("all of the f0 values are 0.")
|
||||
return uv, f0
|
||||
start_f0 = f0[f0 != 0][0]
|
||||
end_f0 = f0[f0 != 0][-1]
|
||||
|
||||
# padding start and end of f0 sequence
|
||||
start_idx = np.where(f0 == start_f0)[0][0]
|
||||
end_idx = np.where(f0 == end_f0)[0][-1]
|
||||
f0[:start_idx] = start_f0
|
||||
f0[end_idx:] = end_f0
|
||||
|
||||
# get non-zero frame index
|
||||
nz_frames = np.where(f0 != 0)[0]
|
||||
|
||||
# perform linear interpolation
|
||||
f = interp1d(nz_frames, f0[nz_frames])
|
||||
cont_f0 = f(np.arange(0, f0.shape[0]))
|
||||
|
||||
return uv, cont_f0
|
||||
|
||||
|
||||
def get_cont_lf0(f0, frame_period=10.0, lpf=False):
|
||||
uv, cont_f0 = convert_continuos_f0(f0)
|
||||
if lpf:
|
||||
cont_f0_lpf = low_pass_filter(cont_f0, int(1.0 / (frame_period * 0.001)), cutoff=20)
|
||||
cont_lf0_lpf = cont_f0_lpf.copy()
|
||||
nonzero_indices = np.nonzero(cont_lf0_lpf)
|
||||
cont_lf0_lpf[nonzero_indices] = np.log(cont_f0_lpf[nonzero_indices])
|
||||
# cont_lf0_lpf = np.log(cont_f0_lpf)
|
||||
return uv, cont_lf0_lpf
|
||||
else:
|
||||
nonzero_indices = np.nonzero(cont_f0)
|
||||
cont_lf0 = cont_f0.copy()
|
||||
cont_lf0[cont_f0>0] = np.log(cont_f0[cont_f0>0])
|
||||
return uv, cont_lf0
|
58
utils/load_yaml.py
Normal file
58
utils/load_yaml.py
Normal file
|
@ -0,0 +1,58 @@
|
|||
import yaml
|
||||
|
||||
|
||||
def load_hparams(filename):
|
||||
stream = open(filename, 'r')
|
||||
docs = yaml.safe_load_all(stream)
|
||||
hparams_dict = dict()
|
||||
for doc in docs:
|
||||
for k, v in doc.items():
|
||||
hparams_dict[k] = v
|
||||
return hparams_dict
|
||||
|
||||
def merge_dict(user, default):
|
||||
if isinstance(user, dict) and isinstance(default, dict):
|
||||
for k, v in default.items():
|
||||
if k not in user:
|
||||
user[k] = v
|
||||
else:
|
||||
user[k] = merge_dict(user[k], v)
|
||||
return user
|
||||
|
||||
class Dotdict(dict):
|
||||
"""
|
||||
a dictionary that supports dot notation
|
||||
as well as dictionary access notation
|
||||
usage: d = DotDict() or d = DotDict({'val1':'first'})
|
||||
set attributes: d.val2 = 'second' or d['val2'] = 'second'
|
||||
get attributes: d.val2 or d['val2']
|
||||
"""
|
||||
__getattr__ = dict.__getitem__
|
||||
__setattr__ = dict.__setitem__
|
||||
__delattr__ = dict.__delitem__
|
||||
|
||||
def __init__(self, dct=None):
|
||||
dct = dict() if not dct else dct
|
||||
for key, value in dct.items():
|
||||
if hasattr(value, 'keys'):
|
||||
value = Dotdict(value)
|
||||
self[key] = value
|
||||
|
||||
class HpsYaml(Dotdict):
|
||||
def __init__(self, yaml_file):
|
||||
super(Dotdict, self).__init__()
|
||||
hps = load_hparams(yaml_file)
|
||||
hp_dict = Dotdict(hps)
|
||||
for k, v in hp_dict.items():
|
||||
setattr(self, k, v)
|
||||
|
||||
__getattr__ = Dotdict.__getitem__
|
||||
__setattr__ = Dotdict.__setitem__
|
||||
__delattr__ = Dotdict.__delitem__
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
44
utils/util.py
Normal file
44
utils/util.py
Normal file
|
@ -0,0 +1,44 @@
|
|||
import matplotlib
|
||||
matplotlib.use('Agg')
|
||||
import time
|
||||
|
||||
class Timer():
|
||||
''' Timer for recording training time distribution. '''
|
||||
def __init__(self):
|
||||
self.prev_t = time.time()
|
||||
self.clear()
|
||||
|
||||
def set(self):
|
||||
self.prev_t = time.time()
|
||||
|
||||
def cnt(self, mode):
|
||||
self.time_table[mode] += time.time()-self.prev_t
|
||||
self.set()
|
||||
if mode == 'bw':
|
||||
self.click += 1
|
||||
|
||||
def show(self):
|
||||
total_time = sum(self.time_table.values())
|
||||
self.time_table['avg'] = total_time/self.click
|
||||
self.time_table['rd'] = 100*self.time_table['rd']/total_time
|
||||
self.time_table['fw'] = 100*self.time_table['fw']/total_time
|
||||
self.time_table['bw'] = 100*self.time_table['bw']/total_time
|
||||
msg = '{avg:.3f} sec/step (rd {rd:.1f}% | fw {fw:.1f}% | bw {bw:.1f}%)'.format(
|
||||
**self.time_table)
|
||||
self.clear()
|
||||
return msg
|
||||
|
||||
def clear(self):
|
||||
self.time_table = {'rd': 0, 'fw': 0, 'bw': 0}
|
||||
self.click = 0
|
||||
|
||||
# Reference : https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/e2e_asr.py#L168
|
||||
|
||||
def human_format(num):
|
||||
magnitude = 0
|
||||
while num >= 1000:
|
||||
magnitude += 1
|
||||
num /= 1000.0
|
||||
# add more suffixes if you need them
|
||||
return '{:3.1f}{}'.format(num, [' ', 'K', 'M', 'G', 'T', 'P'][magnitude])
|
||||
|
|
@ -3,14 +3,11 @@ from __future__ import absolute_import, division, print_function, unicode_litera
|
|||
import os
|
||||
import json
|
||||
import torch
|
||||
from scipy.io.wavfile import write
|
||||
from vocoder.hifigan.env import AttrDict
|
||||
from vocoder.hifigan.meldataset import mel_spectrogram, MAX_WAV_VALUE, load_wav
|
||||
from vocoder.hifigan.models import Generator
|
||||
import soundfile as sf
|
||||
|
||||
|
||||
generator = None # type: Generator
|
||||
output_sample_rate = None
|
||||
_device = None
|
||||
|
||||
|
||||
|
@ -22,16 +19,17 @@ def load_checkpoint(filepath, device):
|
|||
return checkpoint_dict
|
||||
|
||||
|
||||
def load_model(weights_fpath, verbose=True):
|
||||
global generator, _device
|
||||
def load_model(weights_fpath, config_fpath="./vocoder/saved_models/24k/config.json", verbose=True):
|
||||
global generator, _device, output_sample_rate
|
||||
|
||||
if verbose:
|
||||
print("Building hifigan")
|
||||
|
||||
with open("./vocoder/hifigan/config_16k_.json") as f:
|
||||
with open(config_fpath) as f:
|
||||
data = f.read()
|
||||
json_config = json.loads(data)
|
||||
h = AttrDict(json_config)
|
||||
output_sample_rate = h.sampling_rate
|
||||
torch.manual_seed(h.seed)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
|
@ -66,5 +64,5 @@ def infer_waveform(mel, progress_callback=None):
|
|||
audio = y_g_hat.squeeze()
|
||||
audio = audio.cpu().numpy()
|
||||
|
||||
return audio
|
||||
return audio, output_sample_rate
|
||||
|
||||
|
|
|
@ -71,6 +71,24 @@ class ResBlock2(torch.nn.Module):
|
|||
for l in self.convs:
|
||||
remove_weight_norm(l)
|
||||
|
||||
class InterpolationBlock(torch.nn.Module):
|
||||
def __init__(self, scale_factor, mode='nearest', align_corners=None, downsample=False):
|
||||
super(InterpolationBlock, self).__init__()
|
||||
self.downsample = downsample
|
||||
self.scale_factor = scale_factor
|
||||
self.mode = mode
|
||||
self.align_corners = align_corners
|
||||
|
||||
def forward(self, x):
|
||||
outputs = torch.nn.functional.interpolate(
|
||||
x,
|
||||
size=x.shape[-1] * self.scale_factor \
|
||||
if not self.downsample else x.shape[-1] // self.scale_factor,
|
||||
mode=self.mode,
|
||||
align_corners=self.align_corners,
|
||||
recompute_scale_factor=False
|
||||
)
|
||||
return outputs
|
||||
|
||||
class Generator(torch.nn.Module):
|
||||
def __init__(self, h):
|
||||
|
@ -82,14 +100,27 @@ class Generator(torch.nn.Module):
|
|||
resblock = ResBlock1 if h.resblock == '1' else ResBlock2
|
||||
|
||||
self.ups = nn.ModuleList()
|
||||
# for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
|
||||
# # self.ups.append(weight_norm(
|
||||
# # ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
|
||||
# # k, u, padding=(k-u)//2)))
|
||||
if h.sampling_rate == 24000:
|
||||
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
|
||||
self.ups.append(
|
||||
torch.nn.Sequential(
|
||||
InterpolationBlock(u),
|
||||
weight_norm(torch.nn.Conv1d(
|
||||
h.upsample_initial_channel//(2**i),
|
||||
h.upsample_initial_channel//(2**(i+1)),
|
||||
k, padding=(k-1)//2,
|
||||
))
|
||||
)
|
||||
)
|
||||
else:
|
||||
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
|
||||
# self.ups.append(weight_norm(
|
||||
# ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
|
||||
# k, u, padding=(k-u)//2)))
|
||||
self.ups.append(weight_norm(ConvTranspose1d(h.upsample_initial_channel//(2**i),
|
||||
h.upsample_initial_channel//(2**(i+1)),
|
||||
k, u, padding=(u//2 + u%2), output_padding=u%2)))
|
||||
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = h.upsample_initial_channel//(2**(i+1))
|
||||
|
@ -121,6 +152,9 @@ class Generator(torch.nn.Module):
|
|||
def remove_weight_norm(self):
|
||||
print('Removing weight norm...')
|
||||
for l in self.ups:
|
||||
if self.h.sampling_rate == 24000:
|
||||
remove_weight_norm(l[-1])
|
||||
else:
|
||||
remove_weight_norm(l)
|
||||
for l in self.resblocks:
|
||||
l.remove_weight_norm()
|
||||
|
|
|
@ -61,4 +61,4 @@ def infer_waveform(mel, normalize=True, batched=True, target=8000, overlap=800,
|
|||
mel = mel / hp.mel_max_abs_value
|
||||
mel = torch.from_numpy(mel[None, ...])
|
||||
wav = _model.generate(mel, batched, target, overlap, hp.mu_law, progress_callback)
|
||||
return wav
|
||||
return wav, hp.sample_rate
|
||||
|
|
|
@ -107,14 +107,15 @@ def webApp():
|
|||
embeds = [embed] * len(texts)
|
||||
specs = current_synt.synthesize_spectrograms(texts, embeds)
|
||||
spec = np.concatenate(specs, axis=1)
|
||||
sample_rate = Synthesizer.sample_rate
|
||||
if "vocoder" in request.form and request.form["vocoder"] == "WaveRNN":
|
||||
wav = rnn_vocoder.infer_waveform(spec)
|
||||
else:
|
||||
wav = gan_vocoder.infer_waveform(spec)
|
||||
wav, sample_rate = gan_vocoder.infer_waveform(spec)
|
||||
|
||||
# Return cooked wav
|
||||
out = io.BytesIO()
|
||||
write(out, Synthesizer.sample_rate, wav.astype(np.float32))
|
||||
write(out, sample_rate, wav.astype(np.float32))
|
||||
return Response(out, mimetype="audio/wav")
|
||||
|
||||
@app.route('/', methods=['GET'])
|
||||
|
|
Loading…
Reference in New Issue
Block a user