mirror of
https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00
b617a87ee4
* Init ppg extractor and ppg2mel * add preprocess and training * FIx known issues * Update __init__.py Allow to gen audio * Fix length issue * Fix bug of preparing fid * Fix sample issues * Add UI usage of PPG-vc
51 lines · 2.0 KiB · Python
from typing import Dict, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..utils.nets_utils import make_pad_mask
|
class MaskedMSELoss(nn.Module):
    """Masked mel-spectrogram MSE loss plus masked stop-token BCE loss.

    Frames beyond each utterance's true length (padding) are excluded from
    both losses via binary masks built from the per-utterance lengths.
    """

    def __init__(self, frames_per_step):
        """
        Args:
            frames_per_step: number of mel frames the decoder emits per step.
                Stop tokens are predicted once per decoder step, so stop
                targets and lengths are downsampled by this factor.
        """
        super().__init__()
        self.frames_per_step = frames_per_step
        # reduction='none' keeps per-element losses so padded positions can
        # be zeroed out by the mask before normalizing.
        self.mel_loss_criterion = nn.MSELoss(reduction='none')
        self.stop_loss_criterion = nn.BCEWithLogitsLoss(reduction='none')

    def get_mask(self, lengths, max_len=None):
        """Build a (B, max_len) float mask: 1.0 at valid positions, 0.0 at padding.

        Args:
            lengths: (B,) tensor of valid sequence lengths.
            max_len: width of the mask; defaults to ``lengths.max()``.

        Returns:
            Float tensor of shape (B, max_len) on ``lengths.device``.
        """
        if max_len is None:
            max_len = torch.max(lengths)
        batch_size = lengths.size(0)
        # Create the position index directly on the target device instead of
        # building it on CPU and transferring it afterwards.
        seq_range = torch.arange(0, max_len, device=lengths.device).long()
        seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
        seq_length_expand = lengths.unsqueeze(1).expand_as(seq_range_expand)
        return (seq_range_expand < seq_length_expand).float()

    def forward(self, mel_pred, mel_pred_postnet, mel_trg, lengths,
                stop_target, stop_pred):
        """Compute the masked mel loss and the masked stop-token loss.

        Args:
            mel_pred: (B, T, D) decoder mel output before the postnet.
            mel_pred_postnet: (B, T, D) mel output after the postnet.
            mel_trg: (B, T, D) ground-truth mel frames. T is assumed to be a
                multiple of ``frames_per_step`` (the reshape below requires it).
            lengths: (B,) valid frame counts per utterance.
            stop_target: (B, T) per-frame stop labels; one label per decoder
                step is kept (the first frame of each step).
            stop_pred: (B, T / frames_per_step) stop-token logits.

        Returns:
            Tuple ``(mel_loss, stop_loss)`` of scalar tensors. ``mel_loss`` is
            the sum of the pre- and post-net masked, length-normalized MSE.
        """
        ## process stop_target: keep one stop label per decoder step
        B = stop_target.size(0)
        stop_target = stop_target.reshape(B, -1, self.frames_per_step)[:, :, 0]
        stop_lengths = torch.ceil(lengths.float() / self.frames_per_step).long()
        stop_mask = self.get_mask(stop_lengths, int(mel_trg.size(1) / self.frames_per_step))

        # Targets are constants; make sure no gradient flows into them.
        mel_trg.requires_grad = False
        # (B, T, 1) -> (B, T, D): every mel bin of a padded frame is masked.
        mel_mask = self.get_mask(lengths, mel_trg.size(1)).unsqueeze(-1)
        mel_mask = mel_mask.expand_as(mel_trg)
        mel_loss_pre = (self.mel_loss_criterion(mel_pred, mel_trg) * mel_mask).sum() / mel_mask.sum()
        mel_loss_post = (self.mel_loss_criterion(mel_pred_postnet, mel_trg) * mel_mask).sum() / mel_mask.sum()

        mel_loss = mel_loss_pre + mel_loss_post

        # stop token loss: masked BCE, normalized by the number of valid steps
        stop_loss = torch.sum(self.stop_loss_criterion(stop_pred, stop_target) * stop_mask) / stop_mask.sum()

        return mel_loss, stop_loss