diff --git a/control/cli/ppg2mel_train.py b/control/cli/ppg2mel_train.py index 557bc26..2f17089 100644 --- a/control/cli/ppg2mel_train.py +++ b/control/cli/ppg2mel_train.py @@ -2,7 +2,7 @@ import sys import torch import argparse import numpy as np -from utils.load_yaml import HpsYaml +from utils.hparams import HpsYaml from models.ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver # For reproducibility, comment these may speed up training diff --git a/control/cli/train_ppg2mel.py b/control/cli/train_ppg2mel.py index 0a94e84..4a9eb4f 100644 --- a/control/cli/train_ppg2mel.py +++ b/control/cli/train_ppg2mel.py @@ -2,7 +2,7 @@ import sys import torch import argparse import numpy as np -from utils.load_yaml import HpsYaml +from utils.hparams import HpsYaml from models.ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver # For reproducibility, comment these may speed up training diff --git a/control/mkgui/train_vc.py b/control/mkgui/train_vc.py index 16a1582..24f0ae4 100644 --- a/control/mkgui/train_vc.py +++ b/control/mkgui/train_vc.py @@ -4,7 +4,7 @@ from pathlib import Path from enum import Enum from typing import Any, Tuple import numpy as np -from utils.load_yaml import HpsYaml +from utils.hparams import HpsYaml from utils.util import AttrDict import torch diff --git a/models/ppg2mel/__init__.py b/models/ppg2mel/__init__.py index cc54db8..731e461 100644 --- a/models/ppg2mel/__init__.py +++ b/models/ppg2mel/__init__.py @@ -15,7 +15,7 @@ from .rnn_decoder_mol import Decoder from .utils.cnn_postnet import Postnet from .utils.vc_utils import get_mask_from_lengths -from utils.load_yaml import HpsYaml +from utils.hparams import HpsYaml class MelDecoderMOLv2(AbsMelDecoder): """Use an encoder to preprocess ppg.""" diff --git a/models/ppg2mel/train.py b/models/ppg2mel/train.py index 80aef06..cdc29cc 100644 --- a/models/ppg2mel/train.py +++ b/models/ppg2mel/train.py @@ -2,7 +2,7 @@ import sys import torch import argparse import numpy as np -from utils.load_yaml import HpsYaml +from utils.hparams import HpsYaml from models.ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver # For reproducibility, comment these may speed up training diff --git a/models/ppg2mel/train/solver.py b/models/ppg2mel/train/solver.py index 9ca71cb..93c3d43 100644 --- a/models/ppg2mel/train/solver.py +++ b/models/ppg2mel/train/solver.py @@ -8,7 +8,6 @@ from torch.utils.tensorboard import SummaryWriter from .option import default_hparas from utils.util import human_format, Timer -from utils.load_yaml import HpsYaml class BaseSolver(): diff --git a/models/synthesizer/hparams.py b/models/synthesizer/hparams.py index 8bcdb63..ca3e635 100644 --- a/models/synthesizer/hparams.py +++ b/models/synthesizer/hparams.py @@ -1,36 +1,4 @@ -import ast -import pprint -import json - -class HParams(object): - def __init__(self, **kwargs): self.__dict__.update(kwargs) - def __setitem__(self, key, value): setattr(self, key, value) - def __getitem__(self, key): return getattr(self, key) - def __repr__(self): return pprint.pformat(self.__dict__) - - def parse(self, string): - # Overrides hparams from a comma-separated string of name=value pairs - if len(string) > 0: - overrides = [s.split("=") for s in string.split(",")] - keys, values = zip(*overrides) - keys = list(map(str.strip, keys)) - values = list(map(str.strip, values)) - for k in keys: - self.__dict__[k] = ast.literal_eval(values[keys.index(k)]) - return self - - def loadJson(self, dict): - print("\Loading the json with %s\n", dict) - for k in dict.keys(): - if k not in ["tts_schedule", "tts_finetune_layers"]: - self.__dict__[k] = dict[k] - return self - - def dumpJson(self, fp): - print("\Saving the json with %s\n", fp) - with fp.open("w", encoding="utf-8") as f: - json.dump(self.__dict__, f) - return self +from utils.hparams import HParams hparams = HParams( ### Signal Processing (used in both synthesizer and vocoder) @@ -104,7 +72,7 @@ hparams = HParams( ### SV2TTS speaker_embedding_size = 256, # Dimension for the speaker embedding silence_min_duration_split = 0.4, # Duration in seconds of a silence for an utterance to be split - utterance_min_duration = 1.6, # Duration in seconds below which utterances are discarded + utterance_min_duration = 0.5, # Duration in seconds below which utterances are discarded use_gst = True, # Whether to use global style token use_ser_for_gst = True, # Whether to use speaker embedding referenced for global style token ) diff --git a/models/synthesizer/inference.py b/models/synthesizer/inference.py index 888f31c..f1bedfb 100644 --- a/models/synthesizer/inference.py +++ b/models/synthesizer/inference.py @@ -10,7 +10,6 @@ from typing import Union, List import numpy as np import librosa from utils import logmmse -import json from pypinyin import lazy_pinyin, Style class Synthesizer: @@ -48,8 +47,7 @@ class Synthesizer: # Try to scan config file model_config_fpaths = list(self.model_fpath.parent.rglob("*.json")) if len(model_config_fpaths)>0 and model_config_fpaths[0].exists(): - with model_config_fpaths[0].open("r", encoding="utf-8") as f: - hparams.loadJson(json.load(f)) + hparams.loadJson(model_config_fpaths[0]) """ Instantiates and loads the model given the weights file that was passed in the constructor. """ diff --git a/models/synthesizer/models/base.py b/models/synthesizer/models/base.py index 13b32a1..13750df 100644 --- a/models/synthesizer/models/base.py +++ b/models/synthesizer/models/base.py @@ -48,7 +48,11 @@ class Base(nn.Module): def load(self, path, device, optimizer=None): # Use device of model params as location for loaded state checkpoint = torch.load(str(path), map_location=device) - self.load_state_dict(checkpoint["model_state"], strict=False) + if "model_state" in checkpoint: + state = checkpoint["model_state"] + else: + state = checkpoint["model"] + self.load_state_dict(state, strict=False) if "optimizer_state" in checkpoint and optimizer is not None: optimizer.load_state_dict(checkpoint["optimizer_state"]) diff --git a/models/synthesizer/models/sublayer/common/transforms.py b/models/synthesizer/models/sublayer/common/transforms.py new file mode 100644 index 0000000..4793d67 --- /dev/null +++ b/models/synthesizer/models/sublayer/common/transforms.py @@ -0,0 +1,193 @@ +import torch +from torch.nn import functional as F + +import numpy as np + + +DEFAULT_MIN_BIN_WIDTH = 1e-3 +DEFAULT_MIN_BIN_HEIGHT = 1e-3 +DEFAULT_MIN_DERIVATIVE = 1e-3 + + +def piecewise_rational_quadratic_transform(inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails=None, + tail_bound=1., + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE): + + if tails is None: + spline_fn = rational_quadratic_spline + spline_kwargs = {} + else: + spline_fn = unconstrained_rational_quadratic_spline + spline_kwargs = { + 'tails': tails, + 'tail_bound': tail_bound + } + + outputs, logabsdet = spline_fn( + inputs=inputs, + unnormalized_widths=unnormalized_widths, + unnormalized_heights=unnormalized_heights, + unnormalized_derivatives=unnormalized_derivatives, + inverse=inverse, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative, + **spline_kwargs + ) + return outputs, logabsdet + + +def searchsorted(bin_locations, inputs, eps=1e-6): + bin_locations[..., -1] += eps + return torch.sum( + inputs[..., None] >= bin_locations, + dim=-1 + ) - 1 + + +def unconstrained_rational_quadratic_spline(inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + tails='linear', + tail_bound=1., + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE): + inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound) + outside_interval_mask = ~inside_interval_mask + + outputs = torch.zeros_like(inputs) + logabsdet = torch.zeros_like(inputs) + + if tails == 'linear': + unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1)) + constant = np.log(np.exp(1 - min_derivative) - 1) + unnormalized_derivatives[..., 0] = constant + unnormalized_derivatives[..., -1] = constant + + outputs[outside_interval_mask] = inputs[outside_interval_mask] + logabsdet[outside_interval_mask] = 0 + else: + raise RuntimeError('{} tails are not implemented.'.format(tails)) + + outputs[inside_interval_mask], logabsdet[inside_interval_mask] = rational_quadratic_spline( + inputs=inputs[inside_interval_mask], + unnormalized_widths=unnormalized_widths[inside_interval_mask, :], + unnormalized_heights=unnormalized_heights[inside_interval_mask, :], + unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :], + inverse=inverse, + left=-tail_bound, right=tail_bound, bottom=-tail_bound, top=tail_bound, + min_bin_width=min_bin_width, + min_bin_height=min_bin_height, + min_derivative=min_derivative + ) + + return outputs, logabsdet + +def rational_quadratic_spline(inputs, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=False, + left=0., right=1., bottom=0., top=1., + min_bin_width=DEFAULT_MIN_BIN_WIDTH, + min_bin_height=DEFAULT_MIN_BIN_HEIGHT, + min_derivative=DEFAULT_MIN_DERIVATIVE): + if torch.min(inputs) < left or torch.max(inputs) > right: + raise ValueError('Input to a transform is not within its domain') + + num_bins = unnormalized_widths.shape[-1] + + if min_bin_width * num_bins > 1.0: + raise ValueError('Minimal bin width too large for the number of bins') + if min_bin_height * num_bins > 1.0: + raise ValueError('Minimal bin height too large for the number of bins') + + widths = F.softmax(unnormalized_widths, dim=-1) + widths = min_bin_width + (1 - min_bin_width * num_bins) * widths + cumwidths = torch.cumsum(widths, dim=-1) + cumwidths = F.pad(cumwidths, pad=(1, 0), mode='constant', value=0.0) + cumwidths = (right - left) * cumwidths + left + cumwidths[..., 0] = left + cumwidths[..., -1] = right + widths = cumwidths[..., 1:] - cumwidths[..., :-1] + + derivatives = min_derivative + F.softplus(unnormalized_derivatives) + + heights = F.softmax(unnormalized_heights, dim=-1) + heights = min_bin_height + (1 - min_bin_height * num_bins) * heights + cumheights = torch.cumsum(heights, dim=-1) + cumheights = F.pad(cumheights, pad=(1, 0), mode='constant', value=0.0) + cumheights = (top - bottom) * cumheights + bottom + cumheights[..., 0] = bottom + cumheights[..., -1] = top + heights = cumheights[..., 1:] - cumheights[..., :-1] + + if inverse: + bin_idx = searchsorted(cumheights, inputs)[..., None] + else: + bin_idx = searchsorted(cumwidths, inputs)[..., None] + + input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0] + input_bin_widths = widths.gather(-1, bin_idx)[..., 0] + + input_cumheights = cumheights.gather(-1, bin_idx)[..., 0] + delta = heights / widths + input_delta = delta.gather(-1, bin_idx)[..., 0] + + input_derivatives = derivatives.gather(-1, bin_idx)[..., 0] + input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0] + + input_heights = heights.gather(-1, bin_idx)[..., 0] + + if inverse: + a = (((inputs - input_cumheights) * (input_derivatives + + input_derivatives_plus_one + - 2 * input_delta) + + input_heights * (input_delta - input_derivatives))) + b = (input_heights * input_derivatives + - (inputs - input_cumheights) * (input_derivatives + + input_derivatives_plus_one + - 2 * input_delta)) + c = - input_delta * (inputs - input_cumheights) + + discriminant = b.pow(2) - 4 * a * c + assert (discriminant >= 0).all() + + root = (2 * c) / (-b - torch.sqrt(discriminant)) + outputs = root * input_bin_widths + input_cumwidths + + theta_one_minus_theta = root * (1 - root) + denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) + * theta_one_minus_theta) + derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * root.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - root).pow(2)) + logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) + + return outputs, -logabsdet + else: + theta = (inputs - input_cumwidths) / input_bin_widths + theta_one_minus_theta = theta * (1 - theta) + + numerator = input_heights * (input_delta * theta.pow(2) + + input_derivatives * theta_one_minus_theta) + denominator = input_delta + ((input_derivatives + input_derivatives_plus_one - 2 * input_delta) + * theta_one_minus_theta) + outputs = input_cumheights + numerator / denominator + + derivative_numerator = input_delta.pow(2) * (input_derivatives_plus_one * theta.pow(2) + + 2 * input_delta * theta_one_minus_theta + + input_derivatives * (1 - theta).pow(2)) + logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator) + + return outputs, logabsdet diff --git a/models/synthesizer/models/sublayer/vits_modules.py b/models/synthesizer/models/sublayer/vits_modules.py new file mode 100644 index 0000000..c84d83f --- /dev/null +++ b/models/synthesizer/models/sublayer/vits_modules.py @@ -0,0 +1,675 @@ +import math +import torch +from torch import nn +from torch.nn import functional as F + +from torch.nn import Conv1d +from torch.nn.utils import weight_norm, remove_weight_norm +from utils.util import init_weights, get_padding, convert_pad_shape, convert_pad_shape, subsequent_mask, fused_add_tanh_sigmoid_multiply +from .common.transforms import piecewise_rational_quadratic_transform + +LRELU_SLOPE = 0.1 + +class LayerNorm(nn.Module): + def __init__(self, channels, eps=1e-5): + super().__init__() + self.channels = channels + self.eps = eps + + self.gamma = nn.Parameter(torch.ones(channels)) + self.beta = nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + x = x.transpose(1, -1) + x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps) + return x.transpose(1, -1) + + +class ConvReluNorm(nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels, kernel_size, n_layers, p_dropout): + super().__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + assert n_layers > 1, "Number of layers should be larger than 0." + + self.conv_layers = nn.ModuleList() + self.norm_layers = nn.ModuleList() + self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size//2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.relu_drop = nn.Sequential( + nn.ReLU(), + nn.Dropout(p_dropout)) + for _ in range(n_layers-1): + self.conv_layers.append(nn.Conv1d(hidden_channels, hidden_channels, kernel_size, padding=kernel_size//2)) + self.norm_layers.append(LayerNorm(hidden_channels)) + self.proj = nn.Conv1d(hidden_channels, out_channels, 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask): + x_org = x + for i in range(self.n_layers): + x = self.conv_layers[i](x * x_mask) + x = self.norm_layers[i](x) + x = self.relu_drop(x) + x = x_org + self.proj(x) + return x * x_mask + + +class DDSConv(nn.Module): + """ + Dilated and Depth-Separable Convolution + """ + def __init__(self, channels, kernel_size, n_layers, p_dropout=0.): + super().__init__() + self.channels = channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + + self.drop = nn.Dropout(p_dropout) + self.convs_sep = nn.ModuleList() + self.convs_1x1 = nn.ModuleList() + self.norms_1 = nn.ModuleList() + self.norms_2 = nn.ModuleList() + for i in range(n_layers): + dilation = kernel_size ** i + padding = (kernel_size * dilation - dilation) // 2 + self.convs_sep.append(nn.Conv1d(channels, channels, kernel_size, + groups=channels, dilation=dilation, padding=padding + )) + self.convs_1x1.append(nn.Conv1d(channels, channels, 1)) + self.norms_1.append(LayerNorm(channels)) + self.norms_2.append(LayerNorm(channels)) + + def forward(self, x, x_mask, g=None): + if g is not None: + x = x + g + for i in range(self.n_layers): + y = self.convs_sep[i](x * x_mask) + y = self.norms_1[i](y) + y = F.gelu(y) + y = self.convs_1x1[i](y) + y = self.norms_2[i](y) + y = F.gelu(y) + y = self.drop(y) + x = x + y + return x * x_mask + + +class WN(torch.nn.Module): + def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0): + super(WN, self).__init__() + assert(kernel_size % 2 == 1) + self.hidden_channels =hidden_channels + self.kernel_size = kernel_size, + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + self.p_dropout = p_dropout + + self.in_layers = torch.nn.ModuleList() + self.res_skip_layers = torch.nn.ModuleList() + self.drop = nn.Dropout(p_dropout) + + if gin_channels != 0: + cond_layer = torch.nn.Conv1d(gin_channels, 2*hidden_channels*n_layers, 1) + self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight') + + for i in range(n_layers): + dilation = dilation_rate ** i + padding = int((kernel_size * dilation - dilation) / 2) + in_layer = torch.nn.Conv1d(hidden_channels, 2*hidden_channels, kernel_size, + dilation=dilation, padding=padding) + in_layer = torch.nn.utils.weight_norm(in_layer, name='weight') + self.in_layers.append(in_layer) + + # last one is not necessary + if i < n_layers - 1: + res_skip_channels = 2 * hidden_channels + else: + res_skip_channels = hidden_channels + + res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1) + res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight') + self.res_skip_layers.append(res_skip_layer) + + def forward(self, x, x_mask, g=None, **kwargs): + output = torch.zeros_like(x) + n_channels_tensor = torch.IntTensor([self.hidden_channels]) + + if g is not None: + g = self.cond_layer(g) + + for i in range(self.n_layers): + x_in = self.in_layers[i](x) + if g is not None: + cond_offset = i * 2 * self.hidden_channels + g_l = g[:,cond_offset:cond_offset+2*self.hidden_channels,:] + else: + g_l = torch.zeros_like(x_in) + + acts = fused_add_tanh_sigmoid_multiply( + x_in, + g_l, + n_channels_tensor) + acts = self.drop(acts) + + res_skip_acts = self.res_skip_layers[i](acts) + if i < self.n_layers - 1: + res_acts = res_skip_acts[:,:self.hidden_channels,:] + x = (x + res_acts) * x_mask + output = output + res_skip_acts[:,self.hidden_channels:,:] + else: + output = output + res_skip_acts + return output * x_mask + + def remove_weight_norm(self): + if self.gin_channels != 0: + torch.nn.utils.remove_weight_norm(self.cond_layer) + for l in self.in_layers: + torch.nn.utils.remove_weight_norm(l) + for l in self.res_skip_layers: + torch.nn.utils.remove_weight_norm(l) + + +class ResBlock1(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.convs1 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]))) + ]) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))) + ]) + self.convs2.apply(init_weights) + + def forward(self, x, x_mask=None): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c2(xt) + x = xt + x + if x_mask is not None: + x = x * x_mask + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2, self).__init__() + self.convs = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))) + ]) + self.convs.apply(init_weights) + + def forward(self, x, x_mask=None): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c(xt) + x = xt + x + if x_mask is not None: + x = x * x_mask + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class Log(nn.Module): + def forward(self, x, x_mask, reverse=False, **kwargs): + if not reverse: + y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask + logdet = torch.sum(-y, [1, 2]) + return y, logdet + else: + x = torch.exp(x) * x_mask + return x + + +class Flip(nn.Module): + def forward(self, x, *args, reverse=False, **kwargs): + x = torch.flip(x, [1]) + if not reverse: + logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device) + return x, logdet + else: + return x + + +class ElementwiseAffine(nn.Module): + def __init__(self, channels): + super().__init__() + self.channels = channels + self.m = nn.Parameter(torch.zeros(channels,1)) + self.logs = nn.Parameter(torch.zeros(channels,1)) + + def forward(self, x, x_mask, reverse=False, **kwargs): + if not reverse: + y = self.m + torch.exp(self.logs) * x + y = y * x_mask + logdet = torch.sum(self.logs * x_mask, [1,2]) + return y, logdet + else: + x = (x - self.m) * torch.exp(-self.logs) * x_mask + return x + + +class ResidualCouplingLayer(nn.Module): + def __init__(self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + p_dropout=0, + gin_channels=0, + mean_only=False): + assert channels % 2 == 0, "channels should be divisible by 2" + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.half_channels = channels // 2 + self.mean_only = mean_only + + self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1) + self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels) + self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1) + self.post.weight.data.zero_() + self.post.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = torch.split(x, [self.half_channels]*2, 1) + h = self.pre(x0) * x_mask + h = self.enc(h, x_mask, g=g) + stats = self.post(h) * x_mask + if not self.mean_only: + m, logs = torch.split(stats, [self.half_channels]*2, 1) + else: + m = stats + logs = torch.zeros_like(m) + + if not reverse: + x1 = m + x1 * torch.exp(logs) * x_mask + x = torch.cat([x0, x1], 1) + logdet = torch.sum(logs, [1,2]) + return x, logdet + else: + x1 = (x1 - m) * torch.exp(-logs) * x_mask + x = torch.cat([x0, x1], 1) + return x + + +class ConvFlow(nn.Module): + def __init__(self, in_channels, filter_channels, kernel_size, n_layers, num_bins=10, tail_bound=5.0): + super().__init__() + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.num_bins = num_bins + self.tail_bound = tail_bound + self.half_channels = in_channels // 2 + + self.pre = nn.Conv1d(self.half_channels, filter_channels, 1) + self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.) + self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask, g=None, reverse=False): + x0, x1 = torch.split(x, [self.half_channels]*2, 1) + h = self.pre(x0) + h = self.convs(h, x_mask, g=g) + h = self.proj(h) * x_mask + + b, c, t = x0.shape + h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?] + + unnormalized_widths = h[..., :self.num_bins] / math.sqrt(self.filter_channels) + unnormalized_heights = h[..., self.num_bins:2*self.num_bins] / math.sqrt(self.filter_channels) + unnormalized_derivatives = h[..., 2 * self.num_bins:] + + x1, logabsdet = piecewise_rational_quadratic_transform(x1, + unnormalized_widths, + unnormalized_heights, + unnormalized_derivatives, + inverse=reverse, + tails='linear', + tail_bound=self.tail_bound + ) + + x = torch.cat([x0, x1], 1) * x_mask + logdet = torch.sum(logabsdet * x_mask, [1,2]) + if not reverse: + return x, logdet + else: + return x + +class Encoder(nn.Module): + def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., window_size=4, **kwargs): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.window_size = window_size + + self.drop = nn.Dropout(p_dropout) + self.attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + for i in range(self.n_layers): + self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size)) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout)) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask): + attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + x = x * x_mask + for i in range(self.n_layers): + y = self.attn_layers[i](x, x, attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class Decoder(nn.Module): + def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0., proximal_bias=False, proximal_init=True, **kwargs): + super().__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init + + self.drop = nn.Dropout(p_dropout) + self.self_attn_layers = nn.ModuleList() + self.norm_layers_0 = nn.ModuleList() + self.encdec_attn_layers = nn.ModuleList() + self.norm_layers_1 = nn.ModuleList() + self.ffn_layers = nn.ModuleList() + self.norm_layers_2 = nn.ModuleList() + for i in range(self.n_layers): + self.self_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, proximal_bias=proximal_bias, proximal_init=proximal_init)) + self.norm_layers_0.append(LayerNorm(hidden_channels)) + self.encdec_attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout)) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout, causal=True)) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask, h, h_mask): + """ + x: decoder input + h: encoder output + """ + self_attn_mask = subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype) + encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + x = x * x_mask + for i in range(self.n_layers): + y = self.self_attn_layers[i](x, x, self_attn_mask) + y = self.drop(y) + x = self.norm_layers_0[i](x + y) + + y = self.encdec_attn_layers[i](x, h, encdec_attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class MultiHeadAttention(nn.Module): + def __init__(self, channels, out_channels, n_heads, p_dropout=0., window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False): + super().__init__() + assert channels % n_heads == 0 + + self.channels = channels + self.out_channels = out_channels + self.n_heads = n_heads + self.p_dropout = p_dropout + self.window_size = window_size + self.heads_share = heads_share + self.block_length = block_length + self.proximal_bias = proximal_bias + self.proximal_init = proximal_init + self.attn = None + + self.k_channels = channels // n_heads + self.conv_q = nn.Conv1d(channels, channels, 1) + self.conv_k = nn.Conv1d(channels, channels, 1) + self.conv_v = nn.Conv1d(channels, channels, 1) + self.conv_o = nn.Conv1d(channels, out_channels, 1) + self.drop = nn.Dropout(p_dropout) + + if window_size is not None: + n_heads_rel = 1 if heads_share else n_heads + rel_stddev = self.k_channels**-0.5 + self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) + self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev) + + nn.init.xavier_uniform_(self.conv_q.weight) + nn.init.xavier_uniform_(self.conv_k.weight) + nn.init.xavier_uniform_(self.conv_v.weight) + if proximal_init: + with torch.no_grad(): + self.conv_k.weight.copy_(self.conv_q.weight) + self.conv_k.bias.copy_(self.conv_q.bias) + + def forward(self, x, c, attn_mask=None): + q = self.conv_q(x) + k = self.conv_k(c) + v = self.conv_v(c) + + x, self.attn = self.attention(q, k, v, mask=attn_mask) + + x = self.conv_o(x) + return x + + def attention(self, query, key, value, mask=None): + # reshape [b, d, t] -> [b, n_h, t, d_k] + b, d, t_s, t_t = (*key.size(), query.size(2)) + query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) + key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + + scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1)) + if self.window_size is not None: + assert t_s == t_t, "Relative attention is only available for self-attention." + key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) + rel_logits = self._matmul_with_relative_keys(query /math.sqrt(self.k_channels), key_relative_embeddings) + scores_local = self._relative_position_to_absolute_position(rel_logits) + scores = scores + scores_local + if self.proximal_bias: + assert t_s == t_t, "Proximal bias is only available for self-attention." + scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype) + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e4) + if self.block_length is not None: + assert t_s == t_t, "Local attention is only available for self-attention." + block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length) + scores = scores.masked_fill(block_mask == 0, -1e4) + p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s] + p_attn = self.drop(p_attn) + output = torch.matmul(p_attn, value) + if self.window_size is not None: + relative_weights = self._absolute_position_to_relative_position(p_attn) + value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s) + output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings) + output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t] + return output, p_attn + + def _matmul_with_relative_values(self, x, y): + """ + x: [b, h, l, m] + y: [h or 1, m, d] + ret: [b, h, l, d] + """ + ret = torch.matmul(x, y.unsqueeze(0)) + return ret + + def _matmul_with_relative_keys(self, x, y): + """ + x: [b, h, l, d] + y: [h or 1, m, d] + ret: [b, h, l, m] + """ + ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) + return ret + + def _get_relative_embeddings(self, relative_embeddings, length): + max_relative_position = 2 * self.window_size + 1 + # Pad first before slice to avoid using cond ops. + pad_length = max(length - (self.window_size + 1), 0) + slice_start_position = max((self.window_size + 1) - length, 0) + slice_end_position = slice_start_position + 2 * length - 1 + if pad_length > 0: + padded_relative_embeddings = F.pad( + relative_embeddings, + convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]])) + else: + padded_relative_embeddings = relative_embeddings + used_relative_embeddings = padded_relative_embeddings[:,slice_start_position:slice_end_position] + return used_relative_embeddings + + def _relative_position_to_absolute_position(self, x): + """ + x: [b, h, l, 2*l-1] + ret: [b, h, l, l] + """ + batch, heads, length, _ = x.size() + # Concat columns of pad to shift from relative to absolute indexing. + x = F.pad(x, convert_pad_shape([[0,0],[0,0],[0,0],[0,1]])) + + # Concat extra elements so to add up to shape (len+1, 2*len-1). + x_flat = x.view([batch, heads, length * 2 * length]) + x_flat = F.pad(x_flat, convert_pad_shape([[0,0],[0,0],[0,length-1]])) + + # Reshape and slice out the padded elements. + x_final = x_flat.view([batch, heads, length+1, 2*length-1])[:, :, :length, length-1:] + return x_final + + def _absolute_position_to_relative_position(self, x): + """ + x: [b, h, l, l] + ret: [b, h, l, 2*l-1] + """ + batch, heads, length, _ = x.size() + # padd along column + x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length-1]])) + x_flat = x.view([batch, heads, length**2 + length*(length -1)]) + # add 0's in the beginning that will skew the elements after reshape + x_flat = F.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]])) + x_final = x_flat.view([batch, heads, length, 2*length])[:,:,:,1:] + return x_final + + def _attention_bias_proximal(self, length): + """Bias for self-attention to encourage attention to close positions. + Args: + length: an integer scalar. + Returns: + a Tensor with shape [1, 1, length, length] + """ + r = torch.arange(length, dtype=torch.float32) + diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) + return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) + + +class FFN(nn.Module): + def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., activation=None, causal=False): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.activation = activation + self.causal = causal + + if causal: + self.padding = self._causal_padding + else: + self.padding = self._same_padding + + self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size) + self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size) + self.drop = nn.Dropout(p_dropout) + + def forward(self, x, x_mask): + x = self.conv_1(self.padding(x * x_mask)) + if self.activation == "gelu": + x = x * torch.sigmoid(1.702 * x) + else: + x = torch.relu(x) + x = self.drop(x) + x = self.conv_2(self.padding(x * x_mask)) + return x * x_mask + + def _causal_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = self.kernel_size - 1 + pad_r = 0 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, convert_pad_shape(padding)) + return x + + def _same_padding(self, x): + if self.kernel_size == 1: + return x + pad_l = (self.kernel_size - 1) // 2 + pad_r = self.kernel_size // 2 + padding = [[0, 0], [0, 0], [pad_l, pad_r]] + x = F.pad(x, convert_pad_shape(padding)) + return x diff --git a/models/synthesizer/models/vits.py b/models/synthesizer/models/vits.py new file mode 100644 index 0000000..db4a917 --- /dev/null +++ b/models/synthesizer/models/vits.py @@ -0,0 +1,524 @@ +import math +import torch +from torch import nn +from torch.nn import functional as F + +from .sublayer.vits_modules import * +import monotonic_align + +from .base import Base +from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d +from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm +from utils.util import init_weights, get_padding, sequence_mask, rand_slice_segments, generate_path + + +class StochasticDurationPredictor(nn.Module): + def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, n_flows=4, gin_channels=0): + super().__init__() + filter_channels = in_channels # it needs to be removed from future version. + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.log_flow = Log() + self.flows = nn.ModuleList() + self.flows.append(ElementwiseAffine(2)) + for i in range(n_flows): + self.flows.append(ConvFlow(2, filter_channels, kernel_size, n_layers=3)) + self.flows.append(Flip()) + + self.post_pre = nn.Conv1d(1, filter_channels, 1) + self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1) + self.post_convs = DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) + self.post_flows = nn.ModuleList() + self.post_flows.append(ElementwiseAffine(2)) + for i in range(4): + self.post_flows.append(ConvFlow(2, filter_channels, kernel_size, n_layers=3)) + self.post_flows.append(Flip()) + + self.pre = nn.Conv1d(in_channels, filter_channels, 1) + self.proj = nn.Conv1d(filter_channels, filter_channels, 1) + self.convs = DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout) + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, filter_channels, 1) + + def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0): + x = torch.detach(x) + x = self.pre(x) + if g is not None: + g = torch.detach(g) + x = x + self.cond(g) + x = self.convs(x, x_mask) + x = self.proj(x) * x_mask + + if not reverse: + flows = self.flows + assert w is not None + + logdet_tot_q = 0 + h_w = self.post_pre(w) + h_w = self.post_convs(h_w, x_mask) + h_w = self.post_proj(h_w) * x_mask + e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask + z_q = e_q + for flow in self.post_flows: + z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w)) + logdet_tot_q += logdet_q + z_u, z1 = torch.split(z_q, [1, 1], 1) + u = torch.sigmoid(z_u) * x_mask + z0 = (w - u) * x_mask + logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1,2]) + logq = torch.sum(-0.5 * (math.log(2*math.pi) + (e_q**2)) * x_mask, [1,2]) - logdet_tot_q + + logdet_tot = 0 + z0, logdet = self.log_flow(z0, x_mask) + logdet_tot += logdet + z = torch.cat([z0, z1], 1) + for flow in flows: + z, logdet = flow(z, x_mask, g=x, reverse=reverse) + logdet_tot = logdet_tot + logdet + nll = torch.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - logdet_tot + return nll + logq # [b] + else: + flows = list(reversed(self.flows)) + flows = flows[:-2] + [flows[-1]] # remove a useless vflow + z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale + for flow in flows: + z = flow(z, x_mask, g=x, reverse=reverse) + z0, z1 = torch.split(z, [1, 1], 1) + logw = z0 + return logw + + +class DurationPredictor(nn.Module): + def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0): + super().__init__() + + self.in_channels = in_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.gin_channels = gin_channels + + self.drop = nn.Dropout(p_dropout) + self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size//2) + self.norm_1 = LayerNorm(filter_channels) + self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size//2) + self.norm_2 = LayerNorm(filter_channels) + self.proj = nn.Conv1d(filter_channels, 1, 1) + + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, in_channels, 1) + + def forward(self, x, x_mask, g=None): + x = torch.detach(x) + if g is not None: + g = torch.detach(g) + x = x + self.cond(g) + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.norm_1(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + x = torch.relu(x) + x = self.norm_2(x) + x = self.drop(x) + x = self.proj(x * x_mask) + return x * x_mask + + +class TextEncoder(nn.Module): + def __init__(self, + n_vocab, + out_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout): + super().__init__() + self.n_vocab = n_vocab + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + + self.emb = nn.Embedding(n_vocab, hidden_channels) + self.emo_proj = nn.Linear(1024, hidden_channels) + + nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5) + + self.encoder = Encoder( + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout) + self.proj= nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths, emo): + x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h] + x = x + self.emo_proj(emo.unsqueeze(1)) + x = torch.transpose(x, 1, -1) # [b, h, t] + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + + x = self.encoder(x * x_mask, x_mask) + stats = self.proj(x) * x_mask + + m, logs = torch.split(stats, self.out_channels, dim=1) + return x, m, logs, x_mask + + +class ResidualCouplingBlock(nn.Module): + def __init__(self, + channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + n_flows=4, + gin_channels=0): + super().__init__() + self.channels = channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.n_flows = n_flows + self.gin_channels = gin_channels + + self.flows = nn.ModuleList() + for i in range(n_flows): + self.flows.append(ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True)) + self.flows.append(Flip()) + + def forward(self, x, x_mask, g=None, reverse=False): + if not reverse: + for flow in self.flows: + x, _ = flow(x, x_mask, g=g, reverse=reverse) + else: + for flow in reversed(self.flows): + x = flow(x, x_mask, g=g, reverse=reverse) + return x + + +class PosteriorEncoder(nn.Module): + def __init__(self, + in_channels, + out_channels, + hidden_channels, + kernel_size, + dilation_rate, + n_layers, + gin_channels=0): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.hidden_channels = hidden_channels + self.kernel_size = kernel_size + self.dilation_rate = dilation_rate + self.n_layers = n_layers + self.gin_channels = gin_channels + + self.pre = nn.Conv1d(in_channels, hidden_channels, 1) + self.enc = WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels) + self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) + + def forward(self, x, x_lengths, g=None): + x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype) + x = self.pre(x) * x_mask + x = self.enc(x, x_mask, g=g) + stats = self.proj(x) * x_mask + m, logs = torch.split(stats, self.out_channels, dim=1) + z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask + return z, m, logs, x_mask + + +class Generator(torch.nn.Module): + def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0): + super(Generator, self).__init__() + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3) + resblock = ResBlock1 if resblock == '1' else ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append(weight_norm( + ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)), + k, u, padding=(k-u)//2))) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel//(2**(i+1)) + for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)): + self.resblocks.append(resblock(ch, k, d)) + + self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) + self.ups.apply(init_weights) + + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) + + def forward(self, x, g=None): + x = self.conv_pre(x) + if g is not None: + x = x + self.cond(g) + + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i*self.num_kernels+j](x) + else: + xs += self.resblocks[i*self.num_kernels+j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print('Removing weight norm...') + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + self.use_spectral_norm = use_spectral_norm + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))), + norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(get_padding(kernel_size, 1), 0))), + ]) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv1d(1, 16, 15, 1, padding=7)), + norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), + norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ]) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(MultiPeriodDiscriminator, self).__init__() + periods = [2,3,5,7,11] + + discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] + discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods] + self.discriminators = nn.ModuleList(discs) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + y_d_gs.append(y_d_g) + fmap_rs.append(fmap_r) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + +class Vits(Base): + """ + Synthesizer of Vits + """ + + def __init__(self, + n_vocab, + spec_channels, + segment_size, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + stop_threshold, + n_speakers=0, + gin_channels=0, + use_sdp=True, + **kwargs): + + super().__init__(stop_threshold) + self.n_vocab = n_vocab + self.spec_channels = spec_channels + self.inter_channels = inter_channels + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.resblock = resblock + self.resblock_kernel_sizes = resblock_kernel_sizes + self.resblock_dilation_sizes = resblock_dilation_sizes + self.upsample_rates = upsample_rates + self.upsample_initial_channel = upsample_initial_channel + self.upsample_kernel_sizes = upsample_kernel_sizes + self.segment_size = segment_size + self.n_speakers = n_speakers + self.gin_channels = gin_channels + + self.use_sdp = use_sdp + + self.enc_p = TextEncoder(n_vocab, + inter_channels, + hidden_channels, + filter_channels, + n_heads, + n_layers, + kernel_size, + p_dropout) + self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels) + self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels) + self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels) + + if use_sdp: + self.dp = StochasticDurationPredictor(hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels) + else: + self.dp = DurationPredictor(hidden_channels, 256, 3, 0.5, gin_channels=gin_channels) + + if n_speakers > 1: + self.emb_g = nn.Embedding(n_speakers, gin_channels) + + def forward(self, x, x_lengths, y, y_lengths, sid=None, emo=None): + + x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, emo) + if self.n_speakers > 0: + g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] + else: + g = None + + z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g) + z_p = self.flow(z, y_mask, g=g) + + with torch.no_grad(): + # negative cross-entropy + s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t] + neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True) # [b, 1, t_s] + neg_cent2 = torch.matmul(-0.5 * (z_p ** 2).transpose(1, 2), s_p_sq_r) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s] + neg_cent3 = torch.matmul(z_p.transpose(1, 2), (m_p * s_p_sq_r)) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s] + neg_cent4 = torch.sum(-0.5 * (m_p ** 2) * s_p_sq_r, [1], keepdim=True) # [b, 1, t_s] + neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4 + + attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) + attn = monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach() + + w = attn.sum(2) + if self.use_sdp: + l_length = self.dp(x, x_mask, w, g=g) + l_length = l_length / torch.sum(x_mask) + else: + logw_ = torch.log(w + 1e-6) * x_mask + logw = self.dp(x, x_mask, g=g) + l_length = torch.sum((logw - logw_)**2, [1,2]) / torch.sum(x_mask) # for averaging + + # expand prior + m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) + logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) + + z_slice, ids_slice = rand_slice_segments(z, y_lengths, self.segment_size) + o = self.dec(z_slice, g=g) + return o, l_length, attn, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q) + + def infer(self, x, x_lengths, sid=None, emo=None, noise_scale=1, length_scale=1, noise_scale_w=1., max_len=None): + x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths,emo) + if self.n_speakers > 0: + g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1] + else: + g = None + + if self.use_sdp: + logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) + else: + logw = self.dp(x, x_mask, g=g) + w = torch.exp(logw) * x_mask * length_scale + w_ceil = torch.ceil(w) + y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long() + y_mask = torch.unsqueeze(sequence_mask(y_lengths, None), 1).to(x_mask.dtype) + attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1) + attn = generate_path(w_ceil, attn_mask) + + m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] + logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2) # [b, t', t], [b, t, d] -> [b, d, t'] + + z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale + z = self.flow(z_p, y_mask, g=g, reverse=True) + o = self.dec((z * y_mask)[:,:,:max_len], g=g) + return o, attn, y_mask, (z, z_p, m_p, logs_p) + diff --git a/models/synthesizer/models/wav2emo.py b/models/synthesizer/models/wav2emo.py new file mode 100644 index 0000000..6760ccb --- /dev/null +++ b/models/synthesizer/models/wav2emo.py @@ -0,0 +1,50 @@ +import torch +import torch.nn as nn +from transformers.models.wav2vec2.modeling_wav2vec2 import ( + Wav2Vec2Model, + Wav2Vec2PreTrainedModel, +) + + +class RegressionHead(nn.Module): + r"""Classification head.""" + + def __init__(self, config): + super().__init__() + + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.dropout = nn.Dropout(config.final_dropout) + self.out_proj = nn.Linear(config.hidden_size, config.num_labels) + + def forward(self, features, **kwargs): + x = features + x = self.dropout(x) + x = self.dense(x) + x = torch.tanh(x) + x = self.dropout(x) + x = self.out_proj(x) + + return x + + +class EmotionExtractorModel(Wav2Vec2PreTrainedModel): + r"""Speech emotion classifier.""" + + def __init__(self, config): + super().__init__(config) + + self.config = config + self.wav2vec2 = Wav2Vec2Model(config) + self.classifier = RegressionHead(config) + self.init_weights() + + def forward( + self, + input_values, + ): + outputs = self.wav2vec2(input_values) + hidden_states = outputs[0] + hidden_states = torch.mean(hidden_states, dim=1) + logits = self.classifier(hidden_states) + + return hidden_states, logits diff --git a/models/synthesizer/preprocess.py b/models/synthesizer/preprocess.py index 5299781..bdc98a5 100644 --- a/models/synthesizer/preprocess.py +++ b/models/synthesizer/preprocess.py @@ -6,37 +6,42 @@ from pathlib import Path from tqdm import tqdm import numpy as np from models.encoder import inference as encoder -from models.synthesizer.preprocess_speaker import preprocess_speaker_general +from models.synthesizer.preprocess_audio import preprocess_general from models.synthesizer.preprocess_transcript import preprocess_transcript_aishell3, preprocess_transcript_magicdata data_info = { "aidatatang_200zh": { "subfolders": ["corpus/train"], "trans_filepath": "transcript/aidatatang_200_zh_transcript.txt", - "speak_func": preprocess_speaker_general + "speak_func": preprocess_general + }, + "aidatatang_200zh_s": { + "subfolders": ["corpus/train"], + "trans_filepath": "transcript/aidatatang_200_zh_transcript.txt", + "speak_func": preprocess_general }, "magicdata": { "subfolders": ["train"], "trans_filepath": "train/TRANS.txt", - "speak_func": preprocess_speaker_general, + "speak_func": preprocess_general, "transcript_func": preprocess_transcript_magicdata, }, "aishell3":{ "subfolders": ["train/wav"], "trans_filepath": "train/content.txt", - "speak_func": preprocess_speaker_general, + "speak_func": preprocess_general, "transcript_func": preprocess_transcript_aishell3, }, "data_aishell":{ "subfolders": ["wav/train"], "trans_filepath": "transcript/aishell_transcript_v0.8.txt", - "speak_func": preprocess_speaker_general + "speak_func": preprocess_general } } def preprocess_dataset(datasets_root: Path, out_dir: Path, n_processes: int, - skip_existing: bool, hparams, no_alignments: bool, - dataset: str): + skip_existing: bool, hparams, no_alignments: bool, + dataset: str, emotion_extract = False): dataset_info = data_info[dataset] # Gather the input directories dataset_root = datasets_root.joinpath(dataset) @@ -47,6 +52,8 @@ def preprocess_dataset(datasets_root: Path, out_dir: Path, n_processes: int, # Create the output directories for each output file type out_dir.joinpath("mels").mkdir(exist_ok=True) out_dir.joinpath("audio").mkdir(exist_ok=True) + if emotion_extract: + out_dir.joinpath("emo").mkdir(exist_ok=True) # Create a metadata file metadata_fpath = out_dir.joinpath("train.txt") @@ -68,12 +75,15 @@ def preprocess_dataset(datasets_root: Path, out_dir: Path, n_processes: int, dict_info[v[0]] = " ".join(v[1:]) speaker_dirs = list(chain.from_iterable(input_dir.glob("*") for input_dir in input_dirs)) + func = partial(dataset_info["speak_func"], out_dir=out_dir, skip_existing=skip_existing, - hparams=hparams, dict_info=dict_info, no_alignments=no_alignments) + hparams=hparams, dict_info=dict_info, no_alignments=no_alignments, emotion_extract=emotion_extract) job = Pool(n_processes).imap(func, speaker_dirs) + for speaker_metadata in tqdm(job, dataset, len(speaker_dirs), unit="speakers"): - for metadatum in speaker_metadata: - metadata_file.write("|".join(str(x) for x in metadatum) + "\n") + if speaker_metadata is not None: + for metadatum in speaker_metadata: + metadata_file.write("|".join(str(x) for x in metadatum) + "\n") metadata_file.close() # Verify the contents of the metadata file diff --git a/models/synthesizer/preprocess_speaker.py b/models/synthesizer/preprocess_audio.py similarity index 72% rename from models/synthesizer/preprocess_speaker.py rename to models/synthesizer/preprocess_audio.py index fcd829f..c8f7904 100644 --- a/models/synthesizer/preprocess_speaker.py +++ b/models/synthesizer/preprocess_audio.py @@ -9,6 +9,38 @@ from pypinyin import Style from pypinyin.contrib.neutral_tone import NeutralToneWith5Mixin from pypinyin.converter import DefaultConverter from pypinyin.core import Pinyin +import torch +from transformers import Wav2Vec2Processor +from .models.wav2emo import EmotionExtractorModel + +SAMPLE_RATE = 16000 + +# load model from hub +device = 'cuda' if torch.cuda.is_available() else "cpu" +model_name = 'audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim' +processor = Wav2Vec2Processor.from_pretrained(model_name) +model = EmotionExtractorModel.from_pretrained(model_name).to(device) +embs = [] +wavnames = [] + +def extract_emo( + x: np.ndarray, + sampling_rate: int, + embeddings: bool = False, +) -> np.ndarray: + r"""Predict emotions or extract embeddings from raw audio signal.""" + y = processor(x, sampling_rate=sampling_rate) + y = y['input_values'][0] + y = torch.from_numpy(y).to(device) + + # run through model + with torch.no_grad(): + y = model(y)[0 if embeddings else 1] + + # convert to numpy + y = y.detach().cpu().numpy() + + return y class PinyinConverter(NeutralToneWith5Mixin, DefaultConverter): pass @@ -16,8 +48,10 @@ class PinyinConverter(NeutralToneWith5Mixin, DefaultConverter): pinyin = Pinyin(PinyinConverter()).pinyin + + def _process_utterance(wav: np.ndarray, text: str, out_dir: Path, basename: str, - skip_existing: bool, hparams): + skip_existing: bool, hparams, emotion_extract: bool): ## FOR REFERENCE: # For you not to lose your head if you ever wish to change things here or implement your own # synthesizer. @@ -29,12 +63,13 @@ def _process_utterance(wav: np.ndarray, text: str, out_dir: Path, basename: str, # - Librosa pads the waveform before computing the mel spectrogram. Here, the waveform is saved # without extra padding. This means that you won't have an exact relation between the length # of the wav and of the mel spectrogram. See the vocoder data loader. - - + # Skip existing utterances if needed mel_fpath = out_dir.joinpath("mels", "mel-%s.npy" % basename) wav_fpath = out_dir.joinpath("audio", "audio-%s.npy" % basename) - if skip_existing and mel_fpath.exists() and wav_fpath.exists(): + emo_fpath = out_dir.joinpath("emo", "emo-%s.npy" % basename) + skip_emo_extract = not emotion_extract or (skip_existing and emo_fpath.exists()) + if skip_existing and mel_fpath.exists() and wav_fpath.exists() and skip_emo_extract: return None # Trim silence @@ -52,11 +87,14 @@ def _process_utterance(wav: np.ndarray, text: str, out_dir: Path, basename: str, # Skip utterances that are too long if mel_frames > hparams.max_mel_frames and hparams.clip_mels_length: return None - # Write the spectrogram, embed and audio to disk np.save(mel_fpath, mel_spectrogram.T, allow_pickle=False) np.save(wav_fpath, wav, allow_pickle=False) - + + if not skip_emo_extract: + emo = extract_emo(np.expand_dims(wav, 0), hparams.sample_rate, True) + np.save(emo_fpath, emo, allow_pickle=False) + # Return a tuple describing this training example return wav_fpath.name, mel_fpath.name, "embed-%s.npy" % basename, len(wav), mel_frames, text @@ -80,7 +118,7 @@ def _split_on_silences(wav_fpath, words, hparams): return wav, res -def preprocess_speaker_general(speaker_dir, out_dir: Path, skip_existing: bool, hparams, dict_info, no_alignments: bool): +def preprocess_general(speaker_dir, out_dir: Path, skip_existing: bool, hparams, dict_info, no_alignments: bool, emotion_extract: bool): metadata = [] extensions = ["*.wav", "*.flac", "*.mp3"] for extension in extensions: @@ -88,12 +126,12 @@ def preprocess_speaker_general(speaker_dir, out_dir: Path, skip_existing: bool, # Iterate over each wav for wav_fpath in wav_fpath_list: words = dict_info.get(wav_fpath.name.split(".")[0]) - words = dict_info.get(wav_fpath.name) if not words else words # try with wav + words = dict_info.get(wav_fpath.name) if not words else words # try with extension if not words: print("no wordS") continue sub_basename = "%s_%02d" % (wav_fpath.name, 0) wav, text = _split_on_silences(wav_fpath, words, hparams) metadata.append(_process_utterance(wav, text, out_dir, sub_basename, - skip_existing, hparams)) + skip_existing, hparams, emotion_extract)) return [m for m in metadata if m is not None] diff --git a/models/synthesizer/synthesize.py b/models/synthesizer/synthesize.py index 8c70b0f..7dc18b3 100644 --- a/models/synthesizer/synthesize.py +++ b/models/synthesizer/synthesize.py @@ -2,7 +2,6 @@ import torch from torch.utils.data import DataLoader from models.synthesizer.synthesizer_dataset import SynthesizerDataset, collate_synthesizer from models.synthesizer.models.tacotron import Tacotron -from models.synthesizer.utils.text import text_to_sequence from models.synthesizer.utils.symbols import symbols import numpy as np from pathlib import Path diff --git a/models/synthesizer/train.py b/models/synthesizer/train.py index 21e3961..fc72a67 100644 --- a/models/synthesizer/train.py +++ b/models/synthesizer/train.py @@ -78,8 +78,7 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int, # Try to scan config file model_config_fpaths = list(weights_fpath.parent.rglob("*.json")) if len(model_config_fpaths)>0 and model_config_fpaths[0].exists(): - with model_config_fpaths[0].open("r", encoding="utf-8") as f: - hparams.loadJson(json.load(f)) + hparams.loadJson(model_config_fpaths[0]) else: # save a config hparams.dumpJson(weights_fpath.parent.joinpath(run_id).with_suffix(".json")) diff --git a/models/synthesizer/train_vits.py b/models/synthesizer/train_vits.py new file mode 100644 index 0000000..d8324d9 --- /dev/null +++ b/models/synthesizer/train_vits.py @@ -0,0 +1,389 @@ +import os +from loguru import logger +import torch +import glob +from torch.nn import functional as F +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.cuda.amp import autocast, GradScaler +from utils.audio_utils import mel_spectrogram, spec_to_mel +from utils.loss import feature_loss, generator_loss, discriminator_loss, kl_loss +from utils.util import slice_segments, clip_grad_value_ +from models.synthesizer.vits_dataset import ( + VitsDataset, + VitsDatasetCollate, + DistributedBucketSampler +) +from models.synthesizer.models.vits import ( + Vits, + MultiPeriodDiscriminator, +) +from models.synthesizer.utils.symbols import symbols +from models.synthesizer.utils.plot import plot_spectrogram_to_numpy, plot_alignment_to_numpy +from pathlib import Path +from utils.hparams import HParams +import torch.multiprocessing as mp +import argparse + +# torch.backends.cudnn.benchmark = True +global_step = 0 + + +def new_train(): + """Assume Single Node Multi GPUs Training Only""" + assert torch.cuda.is_available(), "CPU training is not allowed." + + parser = argparse.ArgumentParser() + parser.add_argument("--syn_dir", type=str, default="../audiodata/SV2TTS/synthesizer", help= \ + "Path to the synthesizer directory that contains the ground truth mel spectrograms, " + "the wavs, the emos and the embeds.") + parser.add_argument("-m", "--model_dir", type=str, default="data/ckpt/synthesizer/vits", help=\ + "Path to the output directory that will contain the saved model weights and the logs.") + parser.add_argument('--ckptG', type=str, required=False, + help='original VITS G checkpoint path') + parser.add_argument('--ckptD', type=str, required=False, + help='original VITS D checkpoint path') + args, _ = parser.parse_known_args() + + datasets_root = Path(args.syn_dir) + hparams= HParams( + model_dir = args.model_dir, + ) + hparams.loadJson(Path(hparams.model_dir).joinpath("config.json")) + hparams.data["training_files"] = str(datasets_root.joinpath("train.txt")) + hparams.data["validation_files"] = str(datasets_root.joinpath("train.txt")) + hparams.data["datasets_root"] = str(datasets_root) + hparams.ckptG = args.ckptG + hparams.ckptD = args.ckptD + n_gpus = torch.cuda.device_count() + # for spawn + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '8899' + # mp.spawn(run, nprocs=n_gpus, args=(n_gpus, hparams)) + run(0, 1, hparams) + + +def load_checkpoint(checkpoint_path, model, optimizer=None, is_old=False): + assert os.path.isfile(checkpoint_path) + checkpoint_dict = torch.load(checkpoint_path, map_location='cpu') + iteration = checkpoint_dict['iteration'] + learning_rate = checkpoint_dict['learning_rate'] + if optimizer is not None: + if not is_old: + optimizer.load_state_dict(checkpoint_dict['optimizer']) + else: + new_opt_dict = optimizer.state_dict() + new_opt_dict_params = new_opt_dict['param_groups'][0]['params'] + new_opt_dict['param_groups'] = checkpoint_dict['optimizer']['param_groups'] + new_opt_dict['param_groups'][0]['params'] = new_opt_dict_params + optimizer.load_state_dict(new_opt_dict) + saved_state_dict = checkpoint_dict['model'] + if hasattr(model, 'module'): + state_dict = model.module.state_dict() + else: + state_dict = model.state_dict() + new_state_dict= {} + for k, v in state_dict.items(): + try: + new_state_dict[k] = saved_state_dict[k] + except: + logger.info("%s is not in the checkpoint" % k) + new_state_dict[k] = v + if hasattr(model, 'module'): + model.module.load_state_dict(new_state_dict, strict=False) + else: + model.load_state_dict(new_state_dict, strict=False) + logger.info("Loaded checkpoint '{}' (iteration {})" .format( + checkpoint_path, iteration)) + return model, optimizer, learning_rate, iteration + + +def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path): + logger.info("Saving model and optimizer state at iteration {} to {}".format( + iteration, checkpoint_path)) + if hasattr(model, 'module'): + state_dict = model.module.state_dict() + else: + state_dict = model.state_dict() + torch.save({'model': state_dict, + 'iteration': iteration, + 'optimizer': optimizer.state_dict(), + 'learning_rate': learning_rate}, checkpoint_path) + +def latest_checkpoint_path(dir_path, regex="G_*.pth"): + f_list = glob.glob(os.path.join(dir_path, regex)) + f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f)))) + x = f_list[-1] + print(x) + return x + +def run(rank, n_gpus, hps): + global global_step + if rank == 0: + logger.info(hps) + writer = SummaryWriter(log_dir=hps.model_dir) + writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval")) + + dist.init_process_group(backend='gloo', init_method='env://', world_size=n_gpus, rank=rank) + torch.manual_seed(hps.train.seed) + torch.cuda.set_device(rank) + train_dataset = VitsDataset(hps.data.training_files, hps.data) + train_sampler = DistributedBucketSampler( + train_dataset, + hps.train.batch_size, + [32, 300, 400, 500, 600, 700, 800, 900, 1000], + num_replicas=n_gpus, + rank=rank, + shuffle=True) + collate_fn = VitsDatasetCollate() + train_loader = DataLoader(train_dataset, num_workers=8, shuffle=False, pin_memory=True, + collate_fn=collate_fn, batch_sampler=train_sampler) + if rank == 0: + eval_dataset = VitsDataset(hps.data.validation_files, hps.data) + eval_loader = DataLoader(eval_dataset, num_workers=8, shuffle=False, + batch_size=hps.train.batch_size, pin_memory=True, + drop_last=False, collate_fn=collate_fn) + + net_g = Vits( + len(symbols), + hps.data.filter_length // 2 + 1, + hps.train.segment_size // hps.data.hop_length, + n_speakers=hps.data.n_speakers, + **hps.model).cuda(rank) + net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank) + optim_g = torch.optim.AdamW( + net_g.parameters(), + hps.train.learning_rate, + betas=hps.train.betas, + eps=hps.train.eps) + optim_d = torch.optim.AdamW( + net_d.parameters(), + hps.train.learning_rate, + betas=hps.train.betas, + eps=hps.train.eps) + net_g = DDP(net_g, device_ids=[rank]) + net_d = DDP(net_d, device_ids=[rank]) + ckptG = hps.ckptG + ckptD = hps.ckptD + try: + if ckptG is not None: + _, _, _, epoch_str = load_checkpoint(ckptG, net_g, optim_g, is_old=True) + print("加载原版VITS模型G记录点成功") + else: + _, _, _, epoch_str = load_checkpoint(latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, + optim_g) + if ckptD is not None: + _, _, _, epoch_str = load_checkpoint(ckptG, net_g, optim_g, is_old=True) + print("加载原版VITS模型D记录点成功") + else: + _, _, _, epoch_str = load_checkpoint(latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, + optim_d) + global_step = (epoch_str - 1) * len(train_loader) + except: + epoch_str = 1 + global_step = 0 + if ckptG is not None or ckptD is not None: + epoch_str = 1 + global_step = 0 + scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2) + scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2) + + scaler = GradScaler(enabled=hps.train.fp16_run) + + for epoch in range(epoch_str, hps.train.epochs + 1): + if rank == 0: + train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, + [train_loader, eval_loader], logger, [writer, writer_eval]) + else: + train_and_evaluate(rank, epoch, hps, [net_g, net_d], [optim_g, optim_d], [scheduler_g, scheduler_d], scaler, + [train_loader, None], None, None) + scheduler_g.step() + scheduler_d.step() + + +def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers): + net_g, net_d = nets + optim_g, optim_d = optims + scheduler_g, scheduler_d = schedulers + train_loader, eval_loader = loaders + if writers is not None: + writer, writer_eval = writers + train_loader.batch_sampler.set_epoch(epoch) + global global_step + + net_g.train() + net_d.train() + for batch_idx, (x, x_lengths, spec, spec_lengths, y, y_lengths, speakers, emo) in enumerate(train_loader): + logger.info(f'====> Step: 1 {batch_idx}') + x, x_lengths = x.cuda(rank, non_blocking=True), x_lengths.cuda(rank, non_blocking=True) + spec, spec_lengths = spec.cuda(rank, non_blocking=True), spec_lengths.cuda(rank, non_blocking=True) + y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(rank, non_blocking=True) + speakers = speakers.cuda(rank, non_blocking=True) + emo = emo.cuda(rank, non_blocking=True) + + with autocast(enabled=hps.train.fp16_run): + y_hat, l_length, attn, ids_slice, x_mask, z_mask, \ + (z, z_p, m_p, logs_p, m_q, logs_q) = net_g(x, x_lengths, spec, spec_lengths, speakers, emo) + + mel = spec_to_mel( + spec, + hps.data.filter_length, + hps.data.n_mel_channels, + hps.data.sampling_rate, + hps.data.mel_fmin, + hps.data.mel_fmax) + y_mel = slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length) + y_hat_mel = mel_spectrogram( + y_hat.squeeze(1), + hps.data.filter_length, + hps.data.n_mel_channels, + hps.data.sampling_rate, + hps.data.hop_length, + hps.data.win_length, + hps.data.mel_fmin, + hps.data.mel_fmax + ) + + y = slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice + + # Discriminator + y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach()) + with autocast(enabled=False): + loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g) + loss_disc_all = loss_disc + optim_d.zero_grad() + scaler.scale(loss_disc_all).backward() + scaler.unscale_(optim_d) + grad_norm_d = clip_grad_value_(net_d.parameters(), None) + scaler.step(optim_d) + logger.info(f'====> Step: 2 {batch_idx}') + + with autocast(enabled=hps.train.fp16_run): + # Generator + y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(y, y_hat) + with autocast(enabled=False): + loss_dur = torch.sum(l_length.float()) + loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel + loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl + + loss_fm = feature_loss(fmap_r, fmap_g) + loss_gen, losses_gen = generator_loss(y_d_hat_g) + loss_gen_all = loss_gen + loss_fm + loss_mel + loss_dur + loss_kl + optim_g.zero_grad() + scaler.scale(loss_gen_all.float()).backward() + scaler.unscale_(optim_g) + grad_norm_g = clip_grad_value_(net_g.parameters(), None) + scaler.step(optim_g) + scaler.update() + # logger.info(f'====> Step: 3 {batch_idx}') + if rank == 0: + if global_step % hps.train.log_interval == 0: + lr = optim_g.param_groups[0]['lr'] + losses = [loss_disc, loss_gen, loss_fm, loss_mel, loss_dur, loss_kl] + logger.info('Train Epoch: {} [{:.0f}%]'.format( + epoch, + 100. * batch_idx / len(train_loader))) + logger.info([x.item() for x in losses] + [global_step, lr]) + + scalar_dict = {"loss/g/total": loss_gen_all, "loss/d/total": loss_disc_all, "learning_rate": lr, + "grad_norm_d": grad_norm_d, "grad_norm_g": grad_norm_g} + scalar_dict.update( + {"loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/dur": loss_dur, "loss/g/kl": loss_kl}) + + scalar_dict.update({"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}) + scalar_dict.update({"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)}) + scalar_dict.update({"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}) + image_dict = { + "slice/mel_org": plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()), + "slice/mel_gen": plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()), + "all/mel": plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()), + "all/attn": plot_alignment_to_numpy(attn[0, 0].data.cpu().numpy()) + } + summarize( + writer=writer, + global_step=global_step, + images=image_dict, + scalars=scalar_dict) + + if global_step % hps.train.eval_interval == 0: + evaluate(hps, net_g, eval_loader, writer_eval) + save_checkpoint(net_g, optim_g, hps.train.learning_rate, epoch, + os.path.join(hps.model_dir, "G_{}.pth".format(global_step))) + save_checkpoint(net_d, optim_d, hps.train.learning_rate, epoch, + os.path.join(hps.model_dir, "D_{}.pth".format(global_step))) + global_step += 1 + + if rank == 0: + logger.info('====> Epoch: {}'.format(epoch)) + + +def evaluate(hps, generator, eval_loader, writer_eval): + generator.eval() + with torch.no_grad(): + for batch_idx, (x, x_lengths, spec, spec_lengths, y, y_lengths, speakers, emo) in enumerate(eval_loader): + x, x_lengths = x.cuda(0), x_lengths.cuda(0) + spec, spec_lengths = spec.cuda(0), spec_lengths.cuda(0) + y, y_lengths = y.cuda(0), y_lengths.cuda(0) + speakers = speakers.cuda(0) + emo = emo.cuda(0) + # remove else + x = x[:1] + x_lengths = x_lengths[:1] + spec = spec[:1] + spec_lengths = spec_lengths[:1] + y = y[:1] + y_lengths = y_lengths[:1] + speakers = speakers[:1] + emo = emo[:1] + break + y_hat, attn, mask, *_ = generator.module.infer(x, x_lengths, speakers, emo, max_len=1000) + y_hat_lengths = mask.sum([1, 2]).long() * hps.data.hop_length + + mel = spec_to_mel( + spec, + hps.data.filter_length, + hps.data.n_mel_channels, + hps.data.sampling_rate, + hps.data.mel_fmin, + hps.data.mel_fmax) + y_hat_mel = mel_spectrogram( + y_hat.squeeze(1).float(), + hps.data.filter_length, + hps.data.n_mel_channels, + hps.data.sampling_rate, + hps.data.hop_length, + hps.data.win_length, + hps.data.mel_fmin, + hps.data.mel_fmax + ) + image_dict = { + "gen/mel": plot_spectrogram_to_numpy(y_hat_mel[0].cpu().numpy()) + } + audio_dict = { + "gen/audio": y_hat[0, :, :y_hat_lengths[0]] + } + if global_step == 0: + image_dict.update({"gt/mel": plot_spectrogram_to_numpy(mel[0].cpu().numpy())}) + audio_dict.update({"gt/audio": y[0, :, :y_lengths[0]]}) + + summarize( + writer=writer_eval, + global_step=global_step, + images=image_dict, + audios=audio_dict, + audio_sampling_rate=hps.data.sampling_rate + ) + generator.train() + +def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050): + for k, v in scalars.items(): + writer.add_scalar(k, v, global_step) + for k, v in histograms.items(): + writer.add_histogram(k, v, global_step) + for k, v in images.items(): + writer.add_image(k, v, global_step, dataformats='HWC') + for k, v in audios.items(): + writer.add_audio(k, v, global_step, audio_sampling_rate) + diff --git a/models/synthesizer/utils/plot.py b/models/synthesizer/utils/plot.py index efdb567..355c478 100644 --- a/models/synthesizer/utils/plot.py +++ b/models/synthesizer/utils/plot.py @@ -3,6 +3,7 @@ matplotlib.use("Agg") import matplotlib.pyplot as plt import numpy as np +MATPLOTLIB_FLAG = False def split_title_line(title_text, max_words=5): """ @@ -112,4 +113,55 @@ def plot_spectrogram_and_trace(pred_spectrogram, path, title=None, split_title=F plt.tight_layout() plt.savefig(path, format="png") sw.add_figure("spectrogram", fig, step) - plt.close() \ No newline at end of file + plt.close() + + +def plot_spectrogram_to_numpy(spectrogram): + global MATPLOTLIB_FLAG + if not MATPLOTLIB_FLAG: + import matplotlib + matplotlib.use("Agg") + MATPLOTLIB_FLAG = True + import matplotlib.pylab as plt + import numpy as np + + fig, ax = plt.subplots(figsize=(10,2)) + im = ax.imshow(spectrogram, aspect="auto", origin="lower", + interpolation='none') + plt.colorbar(im, ax=ax) + plt.xlabel("Frames") + plt.ylabel("Channels") + plt.tight_layout() + + fig.canvas.draw() + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + plt.close() + return data + + +def plot_alignment_to_numpy(alignment, info=None): + global MATPLOTLIB_FLAG + if not MATPLOTLIB_FLAG: + import matplotlib + matplotlib.use("Agg") + MATPLOTLIB_FLAG = True + import matplotlib.pylab as plt + import numpy as np + + fig, ax = plt.subplots(figsize=(6, 4)) + im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower', + interpolation='none') + fig.colorbar(im, ax=ax) + xlabel = 'Decoder timestep' + if info is not None: + xlabel += '\n\n' + info + plt.xlabel(xlabel) + plt.ylabel('Encoder timestep') + plt.tight_layout() + + fig.canvas.draw() + data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='') + data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) + plt.close() + return data diff --git a/models/synthesizer/vits_dataset.py b/models/synthesizer/vits_dataset.py new file mode 100644 index 0000000..32702d1 --- /dev/null +++ b/models/synthesizer/vits_dataset.py @@ -0,0 +1,280 @@ +import os +import random +import numpy as np +import torch +import torch.utils.data + +from utils.audio_utils import spectrogram, load_wav +from utils.util import intersperse +from models.synthesizer.utils.text import text_to_sequence + + +"""Multi speaker version""" +class VitsDataset(torch.utils.data.Dataset): + """ + 1) loads audio, speaker_id, text pairs + 2) normalizes text and converts them to sequences of integers + 3) computes spectrograms from audio files. + """ + def __init__(self, audio_file_path, hparams): + with open(audio_file_path, encoding='utf-8') as f: + self.audio_metadata = [line.strip().split('|') for line in f] + self.text_cleaners = hparams.text_cleaners + self.max_wav_value = hparams.max_wav_value + self.sampling_rate = hparams.sampling_rate + self.filter_length = hparams.filter_length + self.hop_length = hparams.hop_length + self.win_length = hparams.win_length + self.sampling_rate = hparams.sampling_rate + + self.cleaned_text = getattr(hparams, "cleaned_text", False) + + self.add_blank = hparams.add_blank + self.datasets_root = hparams.datasets_root + + self.min_text_len = getattr(hparams, "min_text_len", 1) + self.max_text_len = getattr(hparams, "max_text_len", 190) + + random.seed(1234) + random.shuffle(self.audio_metadata) + self._filter() + + def _filter(self): + """ + Filter text & store spec lengths + """ + # Store spectrogram lengths for Bucketing + # wav_length ~= file_size / (wav_channels * Bytes per dim) = file_size / (1 * 2) + # spec_length = wav_length // hop_length + + audio_metadata_new = [] + lengths = [] + + # for audiopath, sid, text in self.audio_metadata: + sid = 0 + spk_to_sid = {} + for wav_fpath, mel_fpath, embed_path, wav_length, mel_frames, text in self.audio_metadata: + if self.min_text_len <= len(text) and len(text) <= self.max_text_len: + # TODO: for magic data only + speaker_name = wav_fpath.split("_")[1] + if speaker_name not in spk_to_sid: + sid += 1 + spk_to_sid[speaker_name] = sid + + audio_metadata_new.append([wav_fpath, mel_fpath, embed_path, wav_length, mel_frames, text, spk_to_sid[speaker_name]]) + lengths.append(os.path.getsize(f'{self.datasets_root}{os.sep}audio{os.sep}{wav_fpath}') // (2 * self.hop_length)) + print("found sid:%d", sid) + self.audio_metadata = audio_metadata_new + self.lengths = lengths + + def get_audio_text_speaker_pair(self, audio_metadata): + # separate filename, speaker_id and text + wav_fpath, text, sid = audio_metadata[0], audio_metadata[5], audio_metadata[6] + text = self.get_text(text) + + spec, wav = self.get_audio(f'{self.datasets_root}{os.sep}audio{os.sep}{wav_fpath}') + sid = self.get_sid(sid) + emo = torch.FloatTensor(np.load(f'{self.datasets_root}{os.sep}emo{os.sep}{wav_fpath.replace("audio", "emo")}')) + return (text, spec, wav, sid, emo) + + def get_audio(self, filename): + # audio, sampling_rate = load_wav(filename) + + # if sampling_rate != self.sampling_rate: + # raise ValueError("{} {} SR doesn't match target {} SR".format( + # sampling_rate, self.sampling_rate)) + # audio = torch.load(filename) + audio = torch.FloatTensor(np.load(filename).astype(np.float32)) + audio = audio.unsqueeze(0) + # audio_norm = audio / self.max_wav_value + # audio_norm = audio_norm.unsqueeze(0) + # spec_filename = filename.replace(".wav", ".spec.pt") + # if os.path.exists(spec_filename): + # spec = torch.load(spec_filename) + # else: + # spec = spectrogram(audio, self.filter_length, + # self.sampling_rate, self.hop_length, self.win_length, + # center=False) + # spec = torch.squeeze(spec, 0) + # torch.save(spec, spec_filename) + spec = spectrogram(audio, self.filter_length, self.hop_length, self.win_length, + center=False) + spec = torch.squeeze(spec, 0) + return spec, audio + + def get_text(self, text): + if self.cleaned_text: + text_norm = text_to_sequence(text, self.text_cleaners) + if self.add_blank: + text_norm = intersperse(text_norm, 0) + text_norm = torch.LongTensor(text_norm) + return text_norm + + def get_sid(self, sid): + sid = torch.LongTensor([int(sid)]) + return sid + + def __getitem__(self, index): + return self.get_audio_text_speaker_pair(self.audio_metadata[index]) + + def __len__(self): + return len(self.audio_metadata) + + +class VitsDatasetCollate(): + """ Zero-pads model inputs and targets + """ + def __init__(self, return_ids=False): + self.return_ids = return_ids + + def __call__(self, batch): + """Collate's training batch from normalized text, audio and speaker identities + PARAMS + ------ + batch: [text_normalized, spec_normalized, wav_normalized, sid] + """ + # Right zero-pad all one-hot text sequences to max input length + _, ids_sorted_decreasing = torch.sort( + torch.LongTensor([x[1].size(1) for x in batch]), + dim=0, descending=True) + + max_text_len = max([len(x[0]) for x in batch]) + max_spec_len = max([x[1].size(1) for x in batch]) + max_wav_len = max([x[2].size(1) for x in batch]) + + text_lengths = torch.LongTensor(len(batch)) + spec_lengths = torch.LongTensor(len(batch)) + wav_lengths = torch.LongTensor(len(batch)) + sid = torch.LongTensor(len(batch)) + + text_padded = torch.LongTensor(len(batch), max_text_len) + spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len) + wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len) + emo = torch.FloatTensor(len(batch), 1024) + + text_padded.zero_() + spec_padded.zero_() + wav_padded.zero_() + emo.zero_() + + for i in range(len(ids_sorted_decreasing)): + row = batch[ids_sorted_decreasing[i]] + + text = row[0] + text_padded[i, :text.size(0)] = text + text_lengths[i] = text.size(0) + + spec = row[1] + spec_padded[i, :, :spec.size(1)] = spec + spec_lengths[i] = spec.size(1) + + wav = row[2] + wav_padded[i, :, :wav.size(1)] = wav + wav_lengths[i] = wav.size(1) + + sid[i] = row[3] + + emo[i, :] = row[4] + + if self.return_ids: + return text_padded, text_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, sid, ids_sorted_decreasing + return text_padded, text_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, sid, emo + + +class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler): + """ + Maintain similar input lengths in a batch. + Length groups are specified by boundaries. + Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <=b2} or {x | b2 < length(x) <= b3}. + + It removes samples which are not included in the boundaries. + Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded. + """ + def __init__(self, dataset, batch_size, boundaries, num_replicas=None, rank=None, shuffle=True): + super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle) + self.lengths = dataset.lengths + self.batch_size = batch_size + self.boundaries = boundaries + + self.buckets, self.num_samples_per_bucket = self._create_buckets() + self.total_size = sum(self.num_samples_per_bucket) + self.num_samples = self.total_size // self.num_replicas + + def _create_buckets(self): + buckets = [[] for _ in range(len(self.boundaries) - 1)] + for i in range(len(self.lengths)): + length = self.lengths[i] + idx_bucket = self._bisect(length) + if idx_bucket != -1: + buckets[idx_bucket].append(i) + + for i in range(len(buckets) - 1, 0, -1): + if len(buckets[i]) == 0: + buckets.pop(i) + self.boundaries.pop(i+1) + + num_samples_per_bucket = [] + for i in range(len(buckets)): + len_bucket = len(buckets[i]) + total_batch_size = self.num_replicas * self.batch_size + rem = (total_batch_size - (len_bucket % total_batch_size)) % total_batch_size + num_samples_per_bucket.append(len_bucket + rem) + return buckets, num_samples_per_bucket + + def __iter__(self): + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + + indices = [] + if self.shuffle: + for bucket in self.buckets: + indices.append(torch.randperm(len(bucket), generator=g).tolist()) + else: + for bucket in self.buckets: + indices.append(list(range(len(bucket)))) + + batches = [] + for i in range(len(self.buckets)): + bucket = self.buckets[i] + len_bucket = len(bucket) + ids_bucket = indices[i] + num_samples_bucket = self.num_samples_per_bucket[i] + + # add extra samples to make it evenly divisible + rem = num_samples_bucket - len_bucket + ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[:(rem % len_bucket)] + + # subsample + ids_bucket = ids_bucket[self.rank::self.num_replicas] + + # batching + for j in range(len(ids_bucket) // self.batch_size): + batch = [bucket[idx] for idx in ids_bucket[j*self.batch_size:(j+1)*self.batch_size]] + batches.append(batch) + + if self.shuffle: + batch_ids = torch.randperm(len(batches), generator=g).tolist() + batches = [batches[i] for i in batch_ids] + self.batches = batches + + assert len(self.batches) * self.batch_size == self.num_samples + return iter(self.batches) + + def _bisect(self, x, lo=0, hi=None): + if hi is None: + hi = len(self.boundaries) - 1 + + if hi > lo: + mid = (hi + lo) // 2 + if self.boundaries[mid] < x and x <= self.boundaries[mid+1]: + return mid + elif x <= self.boundaries[mid]: + return self._bisect(x, lo, mid) + else: + return self._bisect(x, mid + 1, hi) + else: + return -1 + + def __len__(self): + return self.num_samples // self.batch_size diff --git a/models/vocoder/fregan/meldataset.py b/models/vocoder/fregan/meldataset.py index 53b2c94..df1964d 100644 --- a/models/vocoder/fregan/meldataset.py +++ b/models/vocoder/fregan/meldataset.py @@ -6,7 +6,7 @@ import torch.utils.data import numpy as np from librosa.util import normalize from scipy.io.wavfile import read -from librosa.filters import mel as librosa_mel_fn +from utils.audio_utils import mel_spectrogram MAX_WAV_VALUE = 32768.0 @@ -16,62 +16,6 @@ def load_wav(full_path): return data, sampling_rate -def dynamic_range_compression(x, C=1, clip_val=1e-5): - return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) - - -def dynamic_range_decompression(x, C=1): - return np.exp(x) / C - - -def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): - return torch.log(torch.clamp(x, min=clip_val) * C) - - -def dynamic_range_decompression_torch(x, C=1): - return torch.exp(x) / C - - -def spectral_normalize_torch(magnitudes): - output = dynamic_range_compression_torch(magnitudes) - return output - - -def spectral_de_normalize_torch(magnitudes): - output = dynamic_range_decompression_torch(magnitudes) - return output - - -mel_basis = {} -hann_window = {} - - -def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): - if torch.min(y) < -1.: - print('min value is ', torch.min(y)) - if torch.max(y) > 1.: - print('max value is ', torch.max(y)) - - global mel_basis, hann_window - if fmax not in mel_basis: - mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) - mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device) - hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) - - y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') - y = y.squeeze(1) - - spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)], - center=center, pad_mode='reflect', normalized=False, onesided=True) - - spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9)) - - spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec) - spec = spectral_normalize_torch(spec) - - return spec - - def get_dataset_filelist(a): #with open(a.input_training_file, 'r', encoding='utf-8') as fi: # training_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav') diff --git a/models/vocoder/fregan/train.py b/models/vocoder/fregan/train.py index 53025ca..529c6cd 100644 --- a/models/vocoder/fregan/train.py +++ b/models/vocoder/fregan/train.py @@ -13,7 +13,7 @@ from torch.nn.parallel import DistributedDataParallel from models.vocoder.fregan.meldataset import MelDataset, mel_spectrogram, get_dataset_filelist from models.vocoder.fregan.generator import FreGAN from models.vocoder.fregan.discriminator import ResWiseMultiPeriodDiscriminator, ResWiseMultiScaleDiscriminator -from models.vocoder.fregan.loss import feature_loss, generator_loss, discriminator_loss +from utils.loss import feature_loss, generator_loss, discriminator_loss from models.vocoder.fregan.utils import plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint diff --git a/models/vocoder/hifigan/meldataset.py b/models/vocoder/hifigan/meldataset.py index eb0682b..9b378ec 100644 --- a/models/vocoder/hifigan/meldataset.py +++ b/models/vocoder/hifigan/meldataset.py @@ -6,7 +6,7 @@ import torch.utils.data import numpy as np from librosa.util import normalize from scipy.io.wavfile import read -from librosa.filters import mel as librosa_mel_fn +from utils.audio_utils import mel_spectrogram MAX_WAV_VALUE = 32768.0 @@ -32,46 +32,6 @@ def dynamic_range_decompression_torch(x, C=1): return torch.exp(x) / C -def spectral_normalize_torch(magnitudes): - output = dynamic_range_compression_torch(magnitudes) - return output - - -def spectral_de_normalize_torch(magnitudes): - output = dynamic_range_decompression_torch(magnitudes) - return output - - -mel_basis = {} -hann_window = {} - - -def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False): - if torch.min(y) < -1.: - print('min value is ', torch.min(y)) - if torch.max(y) > 1.: - print('max value is ', torch.max(y)) - - global mel_basis, hann_window - if fmax not in mel_basis: - mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) - mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device) - hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) - - y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') - y = y.squeeze(1) - - spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)], - center=center, pad_mode='reflect', normalized=False, onesided=True) - - spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9)) - - spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec) - spec = spectral_normalize_torch(spec) - - return spec - - def get_dataset_filelist(a): # with open(a.input_training_file, 'r', encoding='utf-8') as fi: # training_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav') diff --git a/models/vocoder/hifigan/models.py b/models/vocoder/hifigan/models.py index 6da66ee..9f1a419 100644 --- a/models/vocoder/hifigan/models.py +++ b/models/vocoder/hifigan/models.py @@ -283,38 +283,3 @@ class MultiScaleDiscriminator(torch.nn.Module): fmap_gs.append(fmap_g) return y_d_rs, y_d_gs, fmap_rs, fmap_gs - - -def feature_loss(fmap_r, fmap_g): - loss = 0 - for dr, dg in zip(fmap_r, fmap_g): - for rl, gl in zip(dr, dg): - loss += torch.mean(torch.abs(rl - gl)) - - return loss*2 - - -def discriminator_loss(disc_real_outputs, disc_generated_outputs): - loss = 0 - r_losses = [] - g_losses = [] - for dr, dg in zip(disc_real_outputs, disc_generated_outputs): - r_loss = torch.mean((1-dr)**2) - g_loss = torch.mean(dg**2) - loss += (r_loss + g_loss) - r_losses.append(r_loss.item()) - g_losses.append(g_loss.item()) - - return loss, r_losses, g_losses - - -def generator_loss(disc_outputs): - loss = 0 - gen_losses = [] - for dg in disc_outputs: - l = torch.mean((1-dg)**2) - gen_losses.append(l) - loss += l - - return loss, gen_losses - diff --git a/models/vocoder/hifigan/train.py b/models/vocoder/hifigan/train.py index a2559b9..7a39071 100644 --- a/models/vocoder/hifigan/train.py +++ b/models/vocoder/hifigan/train.py @@ -13,8 +13,9 @@ import torch.multiprocessing as mp from torch.distributed import init_process_group from torch.nn.parallel import DistributedDataParallel from models.vocoder.hifigan.meldataset import MelDataset, mel_spectrogram, get_dataset_filelist -from models.vocoder.hifigan.models import Generator, MultiPeriodDiscriminator, MultiScaleDiscriminator, feature_loss, generator_loss,\ - discriminator_loss +from models.vocoder.hifigan.models import Generator, MultiPeriodDiscriminator, MultiScaleDiscriminator +from utils.loss import feature_loss, generator_loss, discriminator_loss + from models.vocoder.hifigan.utils import plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint torch.backends.cudnn.benchmark = True diff --git a/models/wav2emo/__init__.py b/models/wav2emo/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/pre.py b/pre.py index 1723430..a750c7a 100644 --- a/pre.py +++ b/pre.py @@ -1,16 +1,11 @@ -from models.synthesizer.preprocess import create_embeddings -from utils.argutils import print_args -from pathlib import Path -import argparse - -from models.synthesizer.preprocess import preprocess_dataset +from models.synthesizer.preprocess import create_embeddings, preprocess_dataset from models.synthesizer.hparams import hparams -from utils.argutils import print_args from pathlib import Path import argparse recognized_datasets = [ "aidatatang_200zh", + "aidatatang_200zh_s", "magicdata", "aishell3", "data_aishell" @@ -48,6 +43,8 @@ if __name__ == "__main__": parser.add_argument("-ne", "--n_processes_embed", type=int, default=1, help=\ "Number of processes in parallel.An encoder is created for each, so you may need to lower " "this value on GPUs with low memory. Set it to 1 if CUDA is unhappy") + parser.add_argument("-ee","--emotion_extract", action="store_true", help=\ + "Preprocess audio to extract emotional numpy (for emotional vits).") args = parser.parse_args() # Process the arguments @@ -74,4 +71,5 @@ if __name__ == "__main__": del args.n_processes_embed preprocess_dataset(**vars(args)) - create_embeddings(synthesizer_root=args.out_dir, n_processes=n_processes_embed, encoder_model_fpath=encoder_model_fpath) + create_embeddings(synthesizer_root=args.out_dir, n_processes=n_processes_embed, encoder_model_fpath=encoder_model_fpath) + diff --git a/requirements.txt b/requirements.txt index 459b0fa..ad087f0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,7 +20,7 @@ flask_wtf flask_cors==3.0.10 gevent==21.8.0 flask_restx -tensorboard +tensorboard==1.15 streamlit==1.8.0 PyYAML==5.4.1 torch_complex diff --git a/run.py b/run.py index 904029a..a2cfe96 100644 --- a/run.py +++ b/run.py @@ -2,14 +2,13 @@ import time import os import argparse import torch -import numpy as np import glob from pathlib import Path from tqdm import tqdm from models.ppg_extractor import load_model import librosa import soundfile as sf -from utils.load_yaml import HpsYaml +from utils.hparams import HpsYaml from models.encoder.audio import preprocess_wav from models.encoder import inference as speacker_encoder diff --git a/train.py b/train.py index b5499bb..0268607 100644 --- a/train.py +++ b/train.py @@ -1,9 +1,4 @@ -import sys -import torch import argparse -import numpy as np -from utils.load_yaml import HpsYaml -from models.ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver def main(): # Arguments @@ -17,6 +12,9 @@ def main(): if paras.type == "synth": from control.cli.synthesizer_train import new_train new_train() + if paras.type == "vits": + from models.synthesizer.train_vits import new_train + new_train() if __name__ == "__main__": main() diff --git a/utils/audio_utils.py b/utils/audio_utils.py index 1dbeddb..bed38b5 100644 --- a/utils/audio_utils.py +++ b/utils/audio_utils.py @@ -1,4 +1,4 @@ - +import numpy as np import torch import torch.utils.data from scipy.io.wavfile import read @@ -6,21 +6,50 @@ from librosa.filters import mel as librosa_mel_fn MAX_WAV_VALUE = 32768.0 +mel_basis = {} +hann_window = {} def load_wav(full_path): sampling_rate, data = read(full_path) return data, sampling_rate -def _dynamic_range_compression_torch(x, C=1, clip_val=1e-5): - return torch.log(torch.clamp(x, min=clip_val) * C) +def load_wav_to_torch(full_path): + sampling_rate, data = read(full_path) + return torch.FloatTensor(data.astype(np.float32)), sampling_rate -def _spectral_normalize_torch(magnitudes): - output = _dynamic_range_compression_torch(magnitudes) - return output +def spectrogram(y, n_fft, hop_size, win_size, center=False): + if torch.min(y) < -1.: + print('min value is ', torch.min(y)) + if torch.max(y) > 1.: + print('max value is ', torch.max(y)) + + global hann_window + dtype_device = str(y.dtype) + '_' + str(y.device) + wnsize_dtype_device = str(win_size) + '_' + dtype_device + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) + + y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') + y = y.squeeze(1) + + spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], + center=center, pad_mode='reflect', normalized=False, onesided=True) + + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + return spec + +def spec_to_mel(spec, n_fft, num_mels, sampling_rate, fmin, fmax): + global mel_basis + dtype_device = str(spec.dtype) + '_' + str(spec.device) + fmax_dtype_device = str(fmax) + '_' + dtype_device + if fmax_dtype_device not in mel_basis: + mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) + mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device) + spec = torch.matmul(mel_basis[fmax_dtype_device], spec) + spec = _spectral_normalize_torch(spec) + return spec -mel_basis = {} -hann_window = {} def mel_spectrogram( y, @@ -39,18 +68,27 @@ def mel_spectrogram( if torch.max(y) > 1.: print('max value is ', torch.max(y)) + # global mel_basis, hann_window + # if fmax not in mel_basis: + # mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) + # mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device) + # hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) global mel_basis, hann_window - if fmax not in mel_basis: + dtype_device = str(y.dtype) + '_' + str(y.device) + fmax_dtype_device = str(fmax) + '_' + dtype_device + wnsize_dtype_device = str(win_size) + '_' + dtype_device + if fmax_dtype_device not in mel_basis: mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) - mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device) - hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) + mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device) + if wnsize_dtype_device not in hann_window: + hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device) y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect') y = y.squeeze(1) spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)], center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=False) - spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9)) + spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-6)) mel_spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec) mel_spec = _spectral_normalize_torch(mel_spec) if output_energy: @@ -58,3 +96,12 @@ def mel_spectrogram( return mel_spec, energy else: return mel_spec + + +def _dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def _spectral_normalize_torch(magnitudes): + output = _dynamic_range_compression_torch(magnitudes) + return output diff --git a/utils/hparams.py b/utils/hparams.py new file mode 100644 index 0000000..04ff2fe --- /dev/null +++ b/utils/hparams.py @@ -0,0 +1,110 @@ +import yaml +import json +import ast + +def load_hparams_json(filename): + with open(filename, "r") as f: + data = f.read() + config = json.loads(data) + + hparams = HParams(**config) + return hparams + + +def load_hparams_yaml(filename): + stream = open(filename, 'r') + docs = yaml.safe_load_all(stream) + hparams_dict = dict() + for doc in docs: + for k, v in doc.items(): + hparams_dict[k] = v + return hparams_dict + +def merge_dict(user, default): + if isinstance(user, dict) and isinstance(default, dict): + for k, v in default.items(): + if k not in user: + user[k] = v + else: + user[k] = merge_dict(user[k], v) + return user + +class Dotdict(dict): + """ + a dictionary that supports dot notation + as well as dictionary access notation + usage: d = DotDict() or d = DotDict({'val1':'first'}) + set attributes: d.val2 = 'second' or d['val2'] = 'second' + get attributes: d.val2 or d['val2'] + """ + __getattr__ = dict.__getitem__ + __setattr__ = dict.__setitem__ + __delattr__ = dict.__delitem__ + + def __init__(self, dct=None): + dct = dict() if not dct else dct + for key, value in dct.items(): + if hasattr(value, 'keys'): + value = Dotdict(value) + self[key] = value + +class HpsYaml(Dotdict): + def __init__(self, yaml_file): + super(Dotdict, self).__init__() + hps = load_hparams_yaml(yaml_file) + hp_dict = Dotdict(hps) + for k, v in hp_dict.items(): + setattr(self, k, v) + + __getattr__ = Dotdict.__getitem__ + __setattr__ = Dotdict.__setitem__ + __delattr__ = Dotdict.__delitem__ + +class HParams(): + def __init__(self, **kwargs): + for k, v in kwargs.items(): + if type(v) == dict: + v = HParams(**v) + self[k] = v + def keys(self): + return self.__dict__.keys() + def __setitem__(self, key, value): setattr(self, key, value) + def __getitem__(self, key): return getattr(self, key) + def keys(self): return self.__dict__.keys() + def items(self): return self.__dict__.items() + def values(self): return self.__dict__.values() + def __contains__(self, key): return key in self.__dict__ + def __repr__(self): + return self.__dict__.__repr__() + + def parse(self, string): + # Overrides hparams from a comma-separated string of name=value pairs + if len(string) > 0: + overrides = [s.split("=") for s in string.split(",")] + keys, values = zip(*overrides) + keys = list(map(str.strip, keys)) + values = list(map(str.strip, values)) + for k in keys: + self.__dict__[k] = ast.literal_eval(values[keys.index(k)]) + return self + + def loadJson(self, fpath): + with fpath.open("r", encoding="utf-8") as f: + print("\Loading the json with %s\n", fpath) + data = json.load(f) + for k in data.keys(): + if k not in ["tts_schedule", "tts_finetune_layers"]: + v = data[k] + if type(v) == dict: + v = HParams(**v) + self.__dict__[k] = v + return self + + def dumpJson(self, fp): + print("\Saving the json with %s\n", fp) + with fp.open("w", encoding="utf-8") as f: + json.dump(self.__dict__, f) + return self + + + diff --git a/utils/load_yaml.py b/utils/load_yaml.py deleted file mode 100644 index 5792ff4..0000000 --- a/utils/load_yaml.py +++ /dev/null @@ -1,58 +0,0 @@ -import yaml - - -def load_hparams(filename): - stream = open(filename, 'r') - docs = yaml.safe_load_all(stream) - hparams_dict = dict() - for doc in docs: - for k, v in doc.items(): - hparams_dict[k] = v - return hparams_dict - -def merge_dict(user, default): - if isinstance(user, dict) and isinstance(default, dict): - for k, v in default.items(): - if k not in user: - user[k] = v - else: - user[k] = merge_dict(user[k], v) - return user - -class Dotdict(dict): - """ - a dictionary that supports dot notation - as well as dictionary access notation - usage: d = DotDict() or d = DotDict({'val1':'first'}) - set attributes: d.val2 = 'second' or d['val2'] = 'second' - get attributes: d.val2 or d['val2'] - """ - __getattr__ = dict.__getitem__ - __setattr__ = dict.__setitem__ - __delattr__ = dict.__delitem__ - - def __init__(self, dct=None): - dct = dict() if not dct else dct - for key, value in dct.items(): - if hasattr(value, 'keys'): - value = Dotdict(value) - self[key] = value - -class HpsYaml(Dotdict): - def __init__(self, yaml_file): - super(Dotdict, self).__init__() - hps = load_hparams(yaml_file) - hp_dict = Dotdict(hps) - for k, v in hp_dict.items(): - setattr(self, k, v) - - __getattr__ = Dotdict.__getitem__ - __setattr__ = Dotdict.__setitem__ - __delattr__ = Dotdict.__delitem__ - - - - - - - diff --git a/models/vocoder/fregan/loss.py b/utils/loss.py similarity index 65% rename from models/vocoder/fregan/loss.py rename to utils/loss.py index e37dc64..89d582a 100644 --- a/models/vocoder/fregan/loss.py +++ b/utils/loss.py @@ -32,4 +32,22 @@ def generator_loss(disc_outputs): gen_losses.append(l) loss += l - return loss, gen_losses \ No newline at end of file + return loss, gen_losses + + +def kl_loss(z_p, logs_q, m_p, logs_p, z_mask): + """ + z_p, logs_q: [b, h, t_t] + m_p, logs_p: [b, h, t_t] + """ + z_p = z_p.float() + logs_q = logs_q.float() + m_p = m_p.float() + logs_p = logs_p.float() + z_mask = z_mask.float() + + kl = logs_p - logs_q - 0.5 + kl += 0.5 * ((z_p - m_p)**2) * torch.exp(-2. * logs_p) + kl = torch.sum(kl * z_mask) + l = kl / torch.sum(z_mask) + return l diff --git a/utils/util.py b/utils/util.py index 3ec9190..56a47ce 100644 --- a/utils/util.py +++ b/utils/util.py @@ -125,4 +125,22 @@ def subsequent_mask(length): def intersperse(lst, item): result = [item] * (len(lst) * 2 + 1) result[1::2] = lst - return result \ No newline at end of file + return result + + +def clip_grad_value_(parameters, clip_value, norm_type=2): + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = list(filter(lambda p: p.grad is not None, parameters)) + norm_type = float(norm_type) + if clip_value is not None: + clip_value = float(clip_value) + + total_norm = 0 + for p in parameters: + param_norm = p.grad.data.norm(norm_type) + total_norm += param_norm.item() ** norm_type + if clip_value is not None: + p.grad.data.clamp_(min=-clip_value, max=clip_value) + total_norm = total_norm ** (1. / norm_type) + return total_norm diff --git a/vits.ipynb b/vits.ipynb new file mode 100644 index 0000000..c0ff3e6 --- /dev/null +++ b/vits.ipynb @@ -0,0 +1,408 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'log_interval': 2000, 'eval_interval': 4000, 'seed': 1234, 'epochs': 10000, 'learning_rate': 0.0001, 'betas': [0.8, 0.99], 'eps': 1e-09, 'batch_size': 16, 'fp16_run': True, 'lr_decay': 0.5, 'segment_size': 8192, 'init_lr_ratio': 1, 'warmup_epochs': 0, 'c_mel': 45, 'c_kl': 1.0}\n", + "Trainable Parameters: 0.000M\n" + ] + } + ], + "source": [ + "from utils.hparams import load_hparams_json\n", + "from utils.util import intersperse\n", + "import json\n", + "from models.synthesizer.models.vits import Vits\n", + "import torch\n", + "import numpy as np\n", + "import IPython.display as ipd\n", + "\n", + "# chinese_cleaners\n", + "_pad = '_'\n", + "_punctuation = ',。!?—…'\n", + "_letters = 'ㄅㄆㄇㄈㄉㄊㄋㄌㄍㄎㄏㄐㄑㄒㄓㄔㄕㄖㄗㄘㄙㄚㄛㄜㄝㄞㄟㄠㄡㄢㄣㄤㄥㄦㄧㄨㄩˉˊˇˋ˙ '\n", + "# Export all symbols:\n", + "symbols = [_pad] + list(_punctuation) + list(_letters)\n", + "\n", + "hps = load_hparams_json(\"data/ckpt/synthesizer/vits/config.json\")\n", + "print(hps.train)\n", + "model = Vits(\n", + " len(symbols),\n", + " hps[\"data\"][\"filter_length\"] // 2 + 1,\n", + " hps[\"train\"][\"segment_size\"] // hps[\"data\"][\"hop_length\"],\n", + " n_speakers=hps[\"data\"][\"n_speakers\"],\n", + " stop_threshold=0.5,\n", + " **hps[\"model\"])\n", + "_ = model.eval()\n", + "device = torch.device(\"cpu\")\n", + "model.load(\"data/ckpt/synthesizer/vits/G_208000.pth\", device)\n", + "\n", + "# 随机抽取情感参考音频的根目录\n", + "random_emotion_root = \"D:\\\\audiodata\\\\aidatatang_200zh\\\\corpus\\\\train\\\\G0017\"\n", + "import random, re\n", + "# import cn2an # remove dependency before production\n", + "from pypinyin import lazy_pinyin, BOPOMOFO\n", + "\n", + "_symbol_to_id = {s: i for i, s in enumerate(symbols)}\n", + "\n", + "# def number_to_chinese(text):\n", + "# numbers = re.findall(r'\\d+(?:\\.?\\d+)?', text)\n", + "# for number in numbers:\n", + "# text = text.replace(number, cn2an.an2cn(number), 1)\n", + "# return text\n", + "\n", + "def chinese_to_bopomofo(text, taiwanese=False):\n", + " text = text.replace('、', ',').replace(';', ',').replace(':', ',')\n", + " for word in list(text):\n", + " bopomofos = lazy_pinyin(word, BOPOMOFO)\n", + " if not re.search('[\\u4e00-\\u9fff]', word):\n", + " text += word\n", + " continue\n", + " for i in range(len(bopomofos)):\n", + " bopomofos[i] = re.sub(r'([\\u3105-\\u3129])$', r'\\1ˉ', bopomofos[i])\n", + " if text != '':\n", + " text += ' '\n", + " if taiwanese:\n", + " text += '#'+'#'.join(bopomofos)\n", + " else:\n", + " text += ''.join(bopomofos)\n", + " return text\n", + "\n", + "_latin_to_bopomofo = [(re.compile('%s' % x[0], re.IGNORECASE), x[1]) for x in [\n", + " ('a', 'ㄟˉ'),\n", + " ('b', 'ㄅㄧˋ'),\n", + " ('c', 'ㄙㄧˉ'),\n", + " ('d', 'ㄉㄧˋ'),\n", + " ('e', 'ㄧˋ'),\n", + " ('f', 'ㄝˊㄈㄨˋ'),\n", + " ('g', 'ㄐㄧˋ'),\n", + " ('h', 'ㄝˇㄑㄩˋ'),\n", + " ('i', 'ㄞˋ'),\n", + " ('j', 'ㄐㄟˋ'),\n", + " ('k', 'ㄎㄟˋ'),\n", + " ('l', 'ㄝˊㄛˋ'),\n", + " ('m', 'ㄝˊㄇㄨˋ'),\n", + " ('n', 'ㄣˉ'),\n", + " ('o', 'ㄡˉ'),\n", + " ('p', 'ㄆㄧˉ'),\n", + " ('q', 'ㄎㄧㄡˉ'),\n", + " ('r', 'ㄚˋ'),\n", + " ('s', 'ㄝˊㄙˋ'),\n", + " ('t', 'ㄊㄧˋ'),\n", + " ('u', 'ㄧㄡˉ'),\n", + " ('v', 'ㄨㄧˉ'),\n", + " ('w', 'ㄉㄚˋㄅㄨˋㄌㄧㄡˋ'),\n", + " ('x', 'ㄝˉㄎㄨˋㄙˋ'),\n", + " ('y', 'ㄨㄞˋ'),\n", + " ('z', 'ㄗㄟˋ')\n", + "]]\n", + "\n", + "def latin_to_bopomofo(text):\n", + " for regex, replacement in _latin_to_bopomofo:\n", + " text = re.sub(regex, replacement, text)\n", + " return text\n", + "\n", + "#TODO: add cleaner to support multilang\n", + "def chinese_cleaners(text, cleaner_names):\n", + " '''Pipeline for Chinese text'''\n", + " # text = number_to_chinese(text)\n", + " text = chinese_to_bopomofo(text)\n", + " text = latin_to_bopomofo(text)\n", + " if re.match('[ˉˊˇˋ˙]', text[-1]):\n", + " text += '。'\n", + " return text\n", + "\n", + "\n", + "def text_to_sequence(text, cleaner_names):\n", + " '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.\n", + " Args:\n", + " text: string to convert to a sequence\n", + " cleaner_names: names of the cleaner functions to run the text through\n", + " Returns:\n", + " List of integers corresponding to the symbols in the text\n", + " '''\n", + " sequence = []\n", + "\n", + " clean_text = chinese_cleaners(text, cleaner_names)\n", + " for symbol in clean_text:\n", + " if symbol not in _symbol_to_id.keys():\n", + " continue\n", + " symbol_id = _symbol_to_id[symbol]\n", + " sequence += [symbol_id]\n", + " return sequence\n", + "\n", + "import os\n", + "\n", + "def tts(txt, emotion, sid=0):\n", + " text_norm = text_to_sequence(txt, hps[\"data\"][\"text_cleaners\"])\n", + " if hps[\"data\"][\"add_blank\"]:\n", + " text_norm = intersperse(text_norm, 0)\n", + " stn_tst = torch.LongTensor(text_norm)\n", + "\n", + " with torch.no_grad(): #inference mode\n", + " x_tst = stn_tst.unsqueeze(0)\n", + " x_tst_lengths = torch.LongTensor([stn_tst.size(0)])\n", + " sid = torch.LongTensor([sid])\n", + " if emotion.endswith(\"wav\"):\n", + " from models.synthesizer.preprocess_audio import extract_emo\n", + " import librosa\n", + " wav, sr = librosa.load(emotion, 16000)\n", + " emo = torch.FloatTensor(extract_emo(np.expand_dims(wav, 0), sr, embeddings=True))\n", + " else:\n", + " print(\"emotion参数不正确\")\n", + "\n", + " audio = model.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=0.667, noise_scale_w=0.8, length_scale=1, emo=emo)[0][0,0].data.float().numpy()\n", + " ipd.display(ipd.Audio(audio, rate=hps[\"data\"][\"sampling_rate\"], normalize=False))\n", + "\n", + "\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "推理:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "txt = \"随机抽取的音频文件路径可以用于使用该情感合成其他句子\"\n", + "tts(txt, emotion='C:\\\\Users\\\\babys\\\\Desktop\\\\voicecollection\\\\secondround\\\\美玉.wav', sid=0)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "预处理:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using data from:\n", + " ..\\audiodata\\magicdata\\train\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "magicdata: 0%| | 0/1018 [00:00here for more info. View Jupyter log for further details." + ] + } + ], + "source": [ + "from models.synthesizer.preprocess import preprocess_dataset\n", + "from pathlib import Path\n", + "from utils.hparams import HParams\n", + "datasets_root = Path(\"../audiodata/\")\n", + "hparams=HParams(\n", + " sample_rate = 16000,\n", + " rescale = True,\n", + " max_mel_frames = 900,\n", + " rescaling_max = 0.9,\n", + "\n", + " utterance_min_duration = 1.6, # Duration in seconds below which utterances are discarded\n", + " ### Audio processing options\n", + " fmax = 7600, # Should not exceed (sample_rate // 2)\n", + " allow_clipping_in_normalization = True, # Used when signal_normalization = True\n", + " clip_mels_length = True, # If true, discards samples exceeding max_mel_frames\n", + " use_lws = False, # \"Fast spectrogram phase recovery using local weighted sums\"\n", + " symmetric_mels = True, # Sets mel range to [-max_abs_value, max_abs_value] if True,\n", + " # and [0, max_abs_value] if False\n", + " trim_silence = True, # Use with sample_rate of 16000 for best results\n", + "\n", + ")\n", + "preprocess_dataset(datasets_root=datasets_root, \n", + " out_dir=datasets_root.joinpath(\"SV2TTS\", \"synthesizer\"),\n", + " n_processes=8,\n", + " skip_existing=True, \n", + " hparams=hparams, \n", + " no_alignments=False, \n", + " dataset=\"magicdata\", \n", + " emotion_extract=True)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "训练:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\\Loading the json with %s\n", + " data\\ckpt\\synthesizer\\vits\\config.json\n" + ] + }, + { + "ename": "ProcessRaisedException", + "evalue": "\n\n-- Process 0 terminated with the following error:\nTraceback (most recent call last):\n File \"d:\\Users\\babys\\Anaconda3\\envs\\mo\\lib\\site-packages\\torch\\multiprocessing\\spawn.py\", line 59, in _wrap\n fn(i, *args)\n File \"d:\\Real-Time-Voice-Cloning-Chinese\\models\\synthesizer\\train_vits.py\", line 123, in run\n net_g = Vits(\nTypeError: __init__() missing 1 required positional argument: 'stop_threshold'\n", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mProcessRaisedException\u001b[0m Traceback (most recent call last)", + "\u001b[1;32md:\\Real-Time-Voice-Cloning-Chinese\\vits.ipynb Cell 7\u001b[0m in \u001b[0;36m\u001b[1;34m()\u001b[0m\n\u001b[0;32m 18\u001b[0m os\u001b[39m.\u001b[39menviron[\u001b[39m'\u001b[39m\u001b[39mMASTER_ADDR\u001b[39m\u001b[39m'\u001b[39m] \u001b[39m=\u001b[39m \u001b[39m'\u001b[39m\u001b[39mlocalhost\u001b[39m\u001b[39m'\u001b[39m\n\u001b[0;32m 19\u001b[0m os\u001b[39m.\u001b[39menviron[\u001b[39m'\u001b[39m\u001b[39mMASTER_PORT\u001b[39m\u001b[39m'\u001b[39m] \u001b[39m=\u001b[39m \u001b[39m'\u001b[39m\u001b[39m8899\u001b[39m\u001b[39m'\u001b[39m\n\u001b[1;32m---> 20\u001b[0m mp\u001b[39m.\u001b[39;49mspawn(run, nprocs\u001b[39m=\u001b[39;49mn_gpus, args\u001b[39m=\u001b[39;49m(n_gpus, hparams))\n", + "File \u001b[1;32md:\\Users\\babys\\Anaconda3\\envs\\mo\\lib\\site-packages\\torch\\multiprocessing\\spawn.py:230\u001b[0m, in \u001b[0;36mspawn\u001b[1;34m(fn, args, nprocs, join, daemon, start_method)\u001b[0m\n\u001b[0;32m 226\u001b[0m msg \u001b[39m=\u001b[39m (\u001b[39m'\u001b[39m\u001b[39mThis method only supports start_method=spawn (got: \u001b[39m\u001b[39m%s\u001b[39;00m\u001b[39m).\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m'\u001b[39m\n\u001b[0;32m 227\u001b[0m \u001b[39m'\u001b[39m\u001b[39mTo use a different start_method use:\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\\t\u001b[39;00m\u001b[39m\\t\u001b[39;00m\u001b[39m'\u001b[39m\n\u001b[0;32m 228\u001b[0m \u001b[39m'\u001b[39m\u001b[39m torch.multiprocessing.start_processes(...)\u001b[39m\u001b[39m'\u001b[39m \u001b[39m%\u001b[39m start_method)\n\u001b[0;32m 229\u001b[0m warnings\u001b[39m.\u001b[39mwarn(msg)\n\u001b[1;32m--> 230\u001b[0m \u001b[39mreturn\u001b[39;00m start_processes(fn, args, nprocs, join, daemon, start_method\u001b[39m=\u001b[39;49m\u001b[39m'\u001b[39;49m\u001b[39mspawn\u001b[39;49m\u001b[39m'\u001b[39;49m)\n", + "File \u001b[1;32md:\\Users\\babys\\Anaconda3\\envs\\mo\\lib\\site-packages\\torch\\multiprocessing\\spawn.py:188\u001b[0m, in \u001b[0;36mstart_processes\u001b[1;34m(fn, args, nprocs, join, daemon, start_method)\u001b[0m\n\u001b[0;32m 185\u001b[0m \u001b[39mreturn\u001b[39;00m context\n\u001b[0;32m 187\u001b[0m \u001b[39m# Loop on join until it returns True or raises an exception.\u001b[39;00m\n\u001b[1;32m--> 188\u001b[0m \u001b[39mwhile\u001b[39;00m \u001b[39mnot\u001b[39;00m context\u001b[39m.\u001b[39;49mjoin():\n\u001b[0;32m 189\u001b[0m \u001b[39mpass\u001b[39;00m\n", + "File \u001b[1;32md:\\Users\\babys\\Anaconda3\\envs\\mo\\lib\\site-packages\\torch\\multiprocessing\\spawn.py:150\u001b[0m, in \u001b[0;36mProcessContext.join\u001b[1;34m(self, timeout)\u001b[0m\n\u001b[0;32m 148\u001b[0m msg \u001b[39m=\u001b[39m \u001b[39m\"\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\\n\u001b[39;00m\u001b[39m-- Process \u001b[39m\u001b[39m%d\u001b[39;00m\u001b[39m terminated with the following error:\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\"\u001b[39m \u001b[39m%\u001b[39m error_index\n\u001b[0;32m 149\u001b[0m msg \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m original_trace\n\u001b[1;32m--> 150\u001b[0m \u001b[39mraise\u001b[39;00m ProcessRaisedException(msg, error_index, failed_process\u001b[39m.\u001b[39mpid)\n", + "\u001b[1;31mProcessRaisedException\u001b[0m: \n\n-- Process 0 terminated with the following error:\nTraceback (most recent call last):\n File \"d:\\Users\\babys\\Anaconda3\\envs\\mo\\lib\\site-packages\\torch\\multiprocessing\\spawn.py\", line 59, in _wrap\n fn(i, *args)\n File \"d:\\Real-Time-Voice-Cloning-Chinese\\models\\synthesizer\\train_vits.py\", line 123, in run\n net_g = Vits(\nTypeError: __init__() missing 1 required positional argument: 'stop_threshold'\n" + ] + }, + { + "ename": "", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[1;31mThe Kernel crashed while executing code in the the current cell or a previous cell. Please review the code in the cell(s) to identify a possible cause of the failure. Click here for more info. View Jupyter log for further details." + ] + } + ], + "source": [ + "from models.synthesizer.train_vits import run\n", + "from pathlib import Path\n", + "from utils.hparams import HParams\n", + "import torch, os\n", + "import torch.multiprocessing as mp\n", + "\n", + "datasets_root = Path(\"../audiodata/SV2TTS/synthesizer\")\n", + "hparams= HParams(\n", + " model_dir = \"data/ckpt/synthesizer/vits\",\n", + ")\n", + "hparams.loadJson(Path(hparams.model_dir).joinpath(\"config.json\"))\n", + "hparams.data[\"training_files\"] = str(datasets_root.joinpath(\"train.txt\"))\n", + "hparams.data[\"validation_files\"] = str(datasets_root.joinpath(\"train.txt\"))\n", + "hparams.data[\"datasets_root\"] = str(datasets_root)\n", + "\n", + "n_gpus = torch.cuda.device_count()\n", + "# for spawn\n", + "os.environ['MASTER_ADDR'] = 'localhost'\n", + "os.environ['MASTER_PORT'] = '8899'\n", + "mp.spawn(run, nprocs=n_gpus, args=(n_gpus, hparams))" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "挑选只有对应emo文件的meta数据" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "from pathlib import Path\n", + "import os\n", + "root = Path('../audiodata/SV2TTS/synthesizer')\n", + "dict_info = []\n", + "with open(root.joinpath(\"train.txt\"), \"r\", encoding=\"utf-8\") as dict_meta:\n", + " for raw in dict_meta:\n", + " if not raw:\n", + " continue\n", + " v = raw.split(\"|\")[0].replace(\"audio\",\"emo\")\n", + " emo_fpath = root.joinpath(\"emo\").joinpath(v)\n", + " if emo_fpath.exists():\n", + " dict_info.append(raw)\n", + " # else:\n", + " # print(emo_fpath)\n", + "# Iterate over each wav\n", + "meta2 = Path('../audiodata/SV2TTS/synthesizer/train2.txt')\n", + "metadata_file = meta2.open(\"w\", encoding=\"utf-8\")\n", + "for new_info in dict_info:\n", + " metadata_file.write(new_info)\n", + "metadata_file.close()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "mo", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + }, + "vscode": { + "interpreter": { + "hash": "788ab866da3baa6c99886d56abb59fe71b6a552bf52c65473ecf96c784704db8" + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}