mirror of
https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00
210 lines
7.8 KiB
Python
#!/usr/bin/env python3

# Copyright 2020 Songxiang Liu
# Apache 2.0

from pathlib import Path
from typing import List, Optional

import torch
import torch.nn.functional as F
import numpy as np

from .utils.abs_model import AbsMelDecoder
from .rnn_decoder_mol import Decoder
from .utils.cnn_postnet import Postnet
from .utils.vc_utils import get_mask_from_lengths

from utils.hparams import HpsYaml

class MelDecoderMOLv2(AbsMelDecoder):
    """Mel decoder that preprocesses PPGs with a downsampling encoder.

    Bottleneck features (PPGs) and 2-channel log-F0/UV features are each
    passed through a Conv1d encoder that downsamples time by
    ``encoder_down_factor`` (the product of ``encoder_downsample_rates``).
    The two encodings are summed, concatenated with an L2-normalised speaker
    embedding, projected back to ``encoder_dim``, and fed to the ``Decoder``
    from ``rnn_decoder_mol``. A residual ``Postnet`` refines the decoder's
    mel output.
    """

    def __init__(
        self,
        num_speakers: int,
        spk_embed_dim: int,
        bottle_neck_feature_dim: int,
        encoder_dim: int = 256,
        encoder_downsample_rates: Optional[List[int]] = None,
        attention_rnn_dim: int = 512,
        decoder_rnn_dim: int = 512,
        num_decoder_rnn_layer: int = 1,
        concat_context_to_last: bool = True,
        prenet_dims: Optional[List[int]] = None,
        num_mixtures: int = 5,
        frames_per_step: int = 2,
        mask_padding: bool = True,
    ):
        """Build the encoders, speaker projection, decoder and post-net.

        Args:
            num_speakers: Not referenced by this class.
            spk_embed_dim: Size of the speaker embedding passed as ``spembs``.
            bottle_neck_feature_dim: Channel count of the PPG input.
            encoder_dim: Channel count of both encoders and the decoder input.
            encoder_downsample_rates: One strided-conv stage is built per
                entry; ``None`` means ``[2, 2]``.
            attention_rnn_dim, decoder_rnn_dim, num_decoder_rnn_layer,
            concat_context_to_last, num_mixtures, frames_per_step:
                Forwarded to ``Decoder``.
            prenet_dims: Forwarded to ``Decoder``; ``None`` means ``[256, 128]``.
            mask_padding: Whether ``parse_output`` zeroes padded frames.
        """
        super().__init__()
        # None sentinels instead of shared mutable list defaults; copy
        # caller-supplied sequences so submodules never alias them.
        encoder_downsample_rates = list(
            [2, 2] if encoder_downsample_rates is None else encoder_downsample_rates
        )
        prenet_dims = list([256, 128] if prenet_dims is None else prenet_dims)

        self.mask_padding = mask_padding
        self.bottle_neck_feature_dim = bottle_neck_feature_dim
        self.num_mels = 80
        # Total temporal downsampling; matches the number of conv stages
        # actually built below (1 when no rates are given).
        self.encoder_down_factor = int(np.prod(encoder_downsample_rates))
        self.frames_per_step = frames_per_step
        self.use_spk_dvec = True

        # PPG encoder.
        self.bnf_prenet = self._build_downsampling_encoder(
            bottle_neck_feature_dim, encoder_dim, encoder_downsample_rates
        )
        # Pitch encoder over the 2-channel log-F0/UV input. It has the same
        # topology so its output length matches the PPG encoder's; the two
        # outputs are summed in _encode_inputs.
        self.pitch_convs = self._build_downsampling_encoder(
            2, encoder_dim, encoder_downsample_rates
        )

        # Projects [encoder features ; speaker embedding] back to encoder_dim.
        self.reduce_proj = torch.nn.Linear(encoder_dim + spk_embed_dim, encoder_dim)

        # Decoder
        self.decoder = Decoder(
            enc_dim=encoder_dim,
            num_mels=self.num_mels,
            frames_per_step=frames_per_step,
            attention_rnn_dim=attention_rnn_dim,
            decoder_rnn_dim=decoder_rnn_dim,
            num_decoder_rnn_layer=num_decoder_rnn_layer,
            prenet_dims=prenet_dims,
            num_mixtures=num_mixtures,
            use_stop_tokens=True,
            concat_context_to_last=concat_context_to_last,
            encoder_down_factor=self.encoder_down_factor,
        )

        # Mel-Spec Postnet: some residual CNN layers
        self.postnet = Postnet()

    @staticmethod
    def _build_downsampling_encoder(in_dim, encoder_dim, downsample_rates):
        """Return a Conv1d stack mapping (B, in_dim, T) to (B, encoder_dim, T').

        The stack is a bias-free 1x1 projection, then one stage per rate:
        Conv1d(kernel=2*rate, stride=rate, padding=rate//2). Every conv is
        followed by LeakyReLU(0.1) and a parameter-free InstanceNorm1d. For
        even rates each stage yields floor(T / rate) frames.

        The layer order determines the ``state_dict`` keys of saved
        checkpoints, so it must not be changed.
        """
        layers = [
            torch.nn.Conv1d(in_dim, encoder_dim, kernel_size=1, bias=False),
            torch.nn.LeakyReLU(0.1),
            torch.nn.InstanceNorm1d(encoder_dim, affine=False),
        ]
        for rate in downsample_rates:
            layers += [
                torch.nn.Conv1d(
                    encoder_dim, encoder_dim,
                    kernel_size=2 * rate,
                    stride=rate,
                    padding=rate // 2,
                ),
                torch.nn.LeakyReLU(0.1),
                torch.nn.InstanceNorm1d(encoder_dim, affine=False),
            ]
        return torch.nn.Sequential(*layers)

    def _encode_inputs(self, bottle_neck_features, logf0_uv, spembs):
        """Encode PPGs, pitch and speaker embedding into decoder inputs.

        Args:
            bottle_neck_features: (B, T, bottle_neck_feature_dim).
            logf0_uv: (B, T, 2) log-F0 / UV features.
            spembs: (B, spk_embed_dim) speaker embeddings; L2-normalised here.

        Returns:
            (B, T', encoder_dim) tensor, where T' is T after downsampling.

        Raises:
            ValueError: If ``logf0_uv`` or ``spembs`` is ``None``.
        """
        # Explicit checks: ``assert`` is stripped under ``python -O``.
        if logf0_uv is None:
            raise ValueError("logf0_uv is required by MelDecoderMOLv2")
        if spembs is None:
            raise ValueError("spembs (speaker embeddings) are required by MelDecoderMOLv2")

        # Conv1d expects channels first, hence the transposes around each encoder.
        ppg_feats = self.bnf_prenet(bottle_neck_features.transpose(1, 2)).transpose(1, 2)
        pitch_feats = self.pitch_convs(logf0_uv.transpose(1, 2)).transpose(1, 2)
        encoded = ppg_feats + pitch_feats

        # Broadcast one normalised speaker vector to every encoder frame.
        spk_embeds = F.normalize(spembs).unsqueeze(1).expand(-1, encoded.size(1), -1)
        return self.reduce_proj(torch.cat([encoded, spk_embeds], dim=-1))

    def parse_output(self, outputs, output_lengths=None):
        """Zero the padded frames of the mel outputs in place.

        Args:
            outputs: List whose first two entries are the decoder mel output
                and the post-net mel output. Each is (B, T_out, num_mels), as
                implied by the mask expansion below. Other entries are
                returned untouched.
            output_lengths: Per-utterance valid frame counts, or ``None`` to
                skip masking.

        Returns:
            ``outputs`` itself, modified in place.
        """
        if self.mask_padding and output_lengths is not None:
            # NOTE(review): assumes get_mask_from_lengths marks valid frames
            # with True, so the negation selects padding -- confirm.
            mask = ~get_mask_from_lengths(output_lengths, outputs[0].size(1))
            mask = mask.unsqueeze(2).expand(mask.size(0), mask.size(1), self.num_mels)
            # Filling through .data keeps this masking out of the autograd graph.
            outputs[0].data.masked_fill_(mask, 0.0)
            outputs[1].data.masked_fill_(mask, 0.0)
        return outputs

    def forward(
        self,
        bottle_neck_features: torch.Tensor,
        feature_lengths: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        logf0_uv: torch.Tensor = None,
        spembs: torch.Tensor = None,
        output_att_ws: bool = False,
    ):
        """Run the training-time decoding pass.

        Args:
            bottle_neck_features: (B, T, bottle_neck_feature_dim) PPGs.
            feature_lengths: Per-utterance lengths of ``bottle_neck_features``.
                They are floor-divided by ``encoder_down_factor`` before being
                passed to the decoder.
            speech: Passed straight to the decoder. It is presumably the
                target mel frames used for teacher forcing -- confirm against
                ``Decoder``.
            speech_lengths: Per-utterance target lengths, used to mask padding.
            logf0_uv: (B, T, 2) log-F0 / UV features (required).
            spembs: (B, spk_embed_dim) speaker embeddings (required).
            output_att_ws: Also return the decoder alignments.

        Returns:
            ``[mel_outputs, mel_outputs_postnet, predicted_stop]``, plus
            ``alignments`` when ``output_att_ws`` is true. Padded mel frames
            are zeroed when ``mask_padding`` is set.

        Raises:
            ValueError: If ``logf0_uv`` or ``spembs`` is ``None``.
        """
        decoder_inputs = self._encode_inputs(bottle_neck_features, logf0_uv, spembs)

        # Valid encoder frames per utterance after downsampling.
        T_dec = torch.div(feature_lengths, self.encoder_down_factor, rounding_mode='floor')
        mel_outputs, predicted_stop, alignments = self.decoder(decoder_inputs, speech, T_dec)

        # Post-processing: the post-net predicts a residual, with channels first.
        mel_outputs_postnet = mel_outputs + self.postnet(
            mel_outputs.transpose(1, 2)
        ).transpose(1, 2)

        outputs = [mel_outputs, mel_outputs_postnet, predicted_stop]
        if output_att_ws:
            outputs.append(alignments)
        return self.parse_output(outputs, speech_lengths)

    def inference(
        self,
        bottle_neck_features: torch.Tensor,
        logf0_uv: torch.Tensor = None,
        spembs: torch.Tensor = None,
    ):
        """Generate mels without targets.

        Args:
            bottle_neck_features: (B, T, bottle_neck_feature_dim) PPGs.
            logf0_uv: (B, T, 2) log-F0 / UV features (required).
            spembs: (B, spk_embed_dim) speaker embeddings (required).

        Returns:
            ``(mel_outputs[0], mel_outputs_postnet[0], alignments[0])``, i.e.
            the results for the first batch item only.

        Raises:
            ValueError: If ``logf0_uv`` or ``spembs`` is ``None``.
        """
        decoder_inputs = self._encode_inputs(bottle_neck_features, logf0_uv, spembs)

        # Decoder
        if decoder_inputs.size(0) > 1:
            mel_outputs, alignments = self.decoder.inference_batched(decoder_inputs)
        else:
            mel_outputs, alignments = self.decoder.inference(decoder_inputs)

        # Post-processing: residual refinement by the post-net.
        mel_outputs_postnet = mel_outputs + self.postnet(
            mel_outputs.transpose(1, 2)
        ).transpose(1, 2)

        # NOTE(review): only batch item 0 is returned even when
        # inference_batched ran. Callers passing B > 1 lose the other items.
        return mel_outputs[0], mel_outputs_postnet[0], alignments[0]


def load_model(model_file, device=None):
    """Build a MelDecoderMOLv2 from a checkpoint and a nearby YAML config.

    The config is the first ``*.yaml`` file that ``rglob`` finds under the
    checkpoint's directory. Its ``"model"`` section supplies the constructor
    keyword arguments.

    Args:
        model_file: Path (``str`` or ``pathlib.Path``) to a checkpoint whose
            ``"model"`` entry holds the state dict.
        device: Target device; defaults to CUDA when available, else CPU.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    Raises:
        FileNotFoundError: If no ``*.yaml`` config exists under the
            checkpoint's directory.
    """
    model_file = Path(model_file)

    # Search for a config file next to (or below) the checkpoint.
    model_config_fpaths = list(model_file.parent.rglob("*.yaml"))
    if not model_config_fpaths:
        raise FileNotFoundError(
            f"No model yaml config found for convertor in {model_file.parent}"
        )
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model_config = HpsYaml(model_config_fpaths[0])
    ppg2mel_model = MelDecoderMOLv2(**model_config["model"]).to(device)

    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    ckpt = torch.load(model_file, map_location=device)
    ppg2mel_model.load_state_dict(ckpt["model"])
    ppg2mel_model.eval()
    return ppg2mel_model