#!/usr/bin/env python3

# Copyright 2020 Songxiang Liu
# Apache 2.0

from pathlib import Path
from typing import List, Optional

import torch
import torch.nn.functional as F

import numpy as np

from .utils.abs_model import AbsMelDecoder
from .rnn_decoder_mol import Decoder
from .utils.cnn_postnet import Postnet
from .utils.vc_utils import get_mask_from_lengths

from utils.load_yaml import HpsYaml


def _build_downsample_stack(in_channels, channels, rates):
    """Build the 1x1 projection + two strided-conv stages used by the prenets.

    The layer order (and therefore the ``Sequential`` indices that appear in
    ``state_dict`` keys) is: Conv1d, LeakyReLU, InstanceNorm1d, repeated three
    times. Only ``rates[0]`` and ``rates[1]`` are consumed.
    """
    layers = [
        torch.nn.Conv1d(in_channels, channels, kernel_size=1, bias=False),
        torch.nn.LeakyReLU(0.1),
        torch.nn.InstanceNorm1d(channels, affine=False),
    ]
    # Index explicitly (rather than slicing) so a too-short ``rates`` still
    # fails loudly with IndexError instead of silently building fewer stages.
    for rate in (rates[0], rates[1]):
        layers += [
            torch.nn.Conv1d(
                channels, channels,
                kernel_size=2 * rate,
                stride=rate,
                padding=rate // 2,
            ),
            torch.nn.LeakyReLU(0.1),
            torch.nn.InstanceNorm1d(channels, affine=False),
        ]
    return torch.nn.Sequential(*layers)


class MelDecoderMOLv2(AbsMelDecoder):
    """Use an encoder to preprocess ppg.

    Bottleneck features and (log-F0, UV) features are each passed through a
    strided convolutional prenet, summed, concatenated with a normalized
    speaker embedding, projected back to ``encoder_dim`` and handed to the
    ``Decoder``. A ``Postnet`` then predicts a residual that is added to the
    decoder's mel output.
    """

    def __init__(
        self,
        num_speakers: int,
        spk_embed_dim: int,
        bottle_neck_feature_dim: int,
        encoder_dim: int = 256,
        encoder_downsample_rates: Optional[List[int]] = None,
        attention_rnn_dim: int = 512,
        decoder_rnn_dim: int = 512,
        num_decoder_rnn_layer: int = 1,
        concat_context_to_last: bool = True,
        prenet_dims: Optional[List[int]] = None,
        num_mixtures: int = 5,
        frames_per_step: int = 2,
        mask_padding: bool = True,
    ):
        """Build the prenets, projection, decoder and postnet.

        Args:
            num_speakers: Accepted for config compatibility; not used by the
                code in this class.
            spk_embed_dim: Size of the speaker embedding concatenated to the
                encoder output before ``reduce_proj``.
            bottle_neck_feature_dim: Channel count of the bottleneck features.
            encoder_dim: Channel count of both convolutional prenets.
            encoder_downsample_rates: Strides of the two downsampling
                convolutions. Defaults to ``[2, 2]``.
            attention_rnn_dim: Forwarded to ``Decoder``.
            decoder_rnn_dim: Forwarded to ``Decoder``.
            num_decoder_rnn_layer: Forwarded to ``Decoder``.
            concat_context_to_last: Forwarded to ``Decoder``.
            prenet_dims: Forwarded to ``Decoder``. Defaults to ``[256, 128]``.
            num_mixtures: Forwarded to ``Decoder``.
            frames_per_step: Forwarded to ``Decoder``.
            mask_padding: When true, ``parse_output`` zeroes padded frames.
        """
        super().__init__()
        # Fresh lists per instance: the previous mutable defaults were shared
        # across every call and handed to ``Decoder`` by reference.
        if encoder_downsample_rates is None:
            encoder_downsample_rates = [2, 2]
        if prenet_dims is None:
            prenet_dims = [256, 128]

        self.mask_padding = mask_padding
        self.bottle_neck_feature_dim = bottle_neck_feature_dim
        self.num_mels = 80
        # Product of *all* rates. NOTE(review): the conv stacks below only use
        # the first two rates, so a list of any other length makes this factor
        # disagree with the real downsampling (or raises IndexError).
        self.encoder_down_factor = np.cumprod(encoder_downsample_rates)[-1]
        self.frames_per_step = frames_per_step
        self.use_spk_dvec = True

        input_dim = bottle_neck_feature_dim

        # Downsampling convolution
        self.bnf_prenet = _build_downsample_stack(
            input_dim, encoder_dim, encoder_downsample_rates)
        decoder_enc_dim = encoder_dim
        # Same architecture as ``bnf_prenet`` but with 2 input channels, so the
        # two outputs can be added elementwise.
        self.pitch_convs = _build_downsample_stack(
            2, encoder_dim, encoder_downsample_rates)

        self.reduce_proj = torch.nn.Linear(encoder_dim + spk_embed_dim, encoder_dim)

        # Decoder
        self.decoder = Decoder(
            enc_dim=decoder_enc_dim,
            num_mels=self.num_mels,
            frames_per_step=frames_per_step,
            attention_rnn_dim=attention_rnn_dim,
            decoder_rnn_dim=decoder_rnn_dim,
            num_decoder_rnn_layer=num_decoder_rnn_layer,
            prenet_dims=prenet_dims,
            num_mixtures=num_mixtures,
            use_stop_tokens=True,
            concat_context_to_last=concat_context_to_last,
            encoder_down_factor=self.encoder_down_factor,
        )

        # Mel-Spec Postnet: some residual CNN layers
        self.postnet = Postnet()

    def parse_output(self, outputs, output_lengths=None):
        """Zero the padded frames of the first two entries of ``outputs``.

        Masking only happens when ``self.mask_padding`` is true and
        ``output_lengths`` is given. The tensors are modified in place and the
        same ``outputs`` list is returned.
        """
        if self.mask_padding and output_lengths is not None:
            # presumably get_mask_from_lengths marks valid frames as True, so
            # the inversion selects padding -- confirm against vc_utils.
            mask = ~get_mask_from_lengths(output_lengths, outputs[0].size(1))
            mask = mask.unsqueeze(2).expand(mask.size(0), mask.size(1), self.num_mels)
            outputs[0].data.masked_fill_(mask, 0.0)
            outputs[1].data.masked_fill_(mask, 0.0)
        return outputs

    def _encode_conditioned_inputs(self, bottle_neck_features, logf0_uv, spembs):
        """Fuse bottleneck, pitch and speaker features into decoder inputs."""
        # Explicit checks instead of ``assert``: asserts vanish under ``-O``
        # and a missing ``logf0_uv`` used to surface as an AttributeError.
        if logf0_uv is None:
            raise ValueError("logf0_uv is required")
        if spembs is None:
            raise ValueError("spembs is required")

        # Conv1d wants channels on dim 1, hence the transposes around it.
        encoded = self.bnf_prenet(bottle_neck_features.transpose(1, 2)).transpose(1, 2)
        pitch = self.pitch_convs(logf0_uv.transpose(1, 2)).transpose(1, 2)
        encoded = encoded + pitch

        # Broadcast the per-utterance speaker vector over every encoder frame.
        spk_embeds = F.normalize(spembs).unsqueeze(1).expand(-1, encoded.size(1), -1)
        encoded = torch.cat([encoded, spk_embeds], dim=-1)
        return self.reduce_proj(encoded)

    def _refine_with_postnet(self, mel_outputs):
        """Return ``mel_outputs`` plus the residual predicted by the postnet."""
        residual = self.postnet(mel_outputs.transpose(1, 2)).transpose(1, 2)
        return mel_outputs + residual

    def forward(
        self,
        bottle_neck_features: torch.Tensor,
        feature_lengths: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        logf0_uv: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
        output_att_ws: bool = False,
    ):
        """Run the teacher-forced training pass.

        Args:
            bottle_neck_features: (B, T, bottle_neck_feature_dim) features.
            feature_lengths: Per-utterance lengths of ``bottle_neck_features``.
            speech: Target passed to the decoder alongside the encoder output.
            speech_lengths: Lengths used to mask padded output frames.
            logf0_uv: (B, T, 2) pitch features. Required despite the default.
            spembs: (B, spk_embed_dim) speaker embeddings. Required despite
                the default.
            output_att_ws: Also return the decoder alignments when true.

        Returns:
            ``[mel_outputs, mel_outputs_postnet, predicted_stop]``, with
            ``alignments`` appended when ``output_att_ws`` is true.

        Raises:
            ValueError: If ``logf0_uv`` or ``spembs`` is ``None``.
        """
        decoder_inputs = self._encode_conditioned_inputs(
            bottle_neck_features, logf0_uv, spembs)

        # Per-utterance encoder lengths after downsampling.
        T_dec = torch.div(
            feature_lengths, int(self.encoder_down_factor), rounding_mode='floor')
        mel_outputs, predicted_stop, alignments = self.decoder(
            decoder_inputs, speech, T_dec)

        ## Post-processing
        mel_outputs_postnet = self._refine_with_postnet(mel_outputs)

        outputs = [mel_outputs, mel_outputs_postnet, predicted_stop]
        if output_att_ws:
            outputs.append(alignments)
        return self.parse_output(outputs, speech_lengths)

    def inference(
        self,
        bottle_neck_features: torch.Tensor,
        logf0_uv: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
    ):
        """Generate mel frames without teacher forcing.

        Uses ``decoder.inference_batched`` when the batch holds more than one
        item and ``decoder.inference`` otherwise.

        Returns:
            ``(mel_outputs[0], mel_outputs_postnet[0], alignments[0])`` -- only
            the first element of each result is returned, whatever the batch
            size.

        Raises:
            ValueError: If ``logf0_uv`` or ``spembs`` is ``None``.
        """
        decoder_inputs = self._encode_conditioned_inputs(
            bottle_neck_features, logf0_uv, spembs)

        ## Decoder
        if decoder_inputs.size(0) > 1:
            mel_outputs, alignments = self.decoder.inference_batched(decoder_inputs)
        else:
            mel_outputs, alignments = self.decoder.inference(decoder_inputs,)

        ## Post-processing
        mel_outputs_postnet = self._refine_with_postnet(mel_outputs)
        return mel_outputs[0], mel_outputs_postnet[0], alignments[0]


def load_model(model_file, device=None):
    """Load a ``MelDecoderMOLv2`` checkpoint and return it in eval mode.

    The hyper-parameters are read from the first ``*.yaml`` file found
    (recursively) in the directory that contains ``model_file``; its
    ``"model"`` section is passed to the constructor as keyword arguments.

    Args:
        model_file: Path (``str`` or ``pathlib.Path``) to the checkpoint. The
            checkpoint must hold the weights under the ``"model"`` key.
        device: Target device. Defaults to CUDA when available, else CPU.

    Raises:
        FileNotFoundError: If no YAML config is found next to the checkpoint.
    """
    model_file = Path(model_file)

    # search a config file
    model_config_fpaths = list(model_file.parent.rglob("*.yaml"))
    if not model_config_fpaths:
        # The original raised a bare string, which is itself a TypeError.
        raise FileNotFoundError("No model yaml config found for convertor")
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE: rglob order is filesystem-dependent; with several YAML files the
    # chosen one is not guaranteed to be stable across machines.
    model_config = HpsYaml(model_config_fpaths[0])
    ppg2mel_model = MelDecoderMOLv2(
        **model_config["model"]
    ).to(device)
    # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(model_file, map_location=device)
    ppg2mel_model.load_state_dict(ckpt["model"])
    ppg2mel_model.eval()
    return ppg2mel_model