# Mirror of https://github.com/babysor/MockingBird.git (synced 2024-03-22)
import torch
import torch.nn as nn


class MaskedMSELoss(nn.Module):
    def __init__(self, frames_per_step):
        super().__init__()
        # Number of mel frames the decoder emits per step (reduction factor);
        # stop logits are produced once per step, not once per frame.
        self.frames_per_step = frames_per_step
        # Elementwise reductions so padding can be masked out before averaging.
        self.mel_loss_criterion = nn.MSELoss(reduction='none')
        self.stop_loss_criterion = nn.BCEWithLogitsLoss(reduction='none')

    def get_mask(self, lengths, max_len=None):
        # lengths: (B,) tensor of true sequence lengths.
        # Returns a float mask of shape (B, max_len): 1.0 at valid positions,
        # 0.0 over padding.
        if max_len is None:
            max_len = int(torch.max(lengths).item())
        batch_size = lengths.size(0)
        seq_range = torch.arange(max_len, device=lengths.device)
        seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
        seq_length_expand = lengths.unsqueeze(1).expand_as(seq_range_expand)
        return (seq_range_expand < seq_length_expand).float()
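
    # Example (hypothetical values):
    #   get_mask(torch.tensor([2, 4]), max_len=4)
    #   -> tensor([[1., 1., 0., 0.],
    #              [1., 1., 1., 1.]])
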
    def forward(self, mel_pred, mel_pred_postnet, mel_trg, lengths,
                stop_target, stop_pred):
        # mel_pred, mel_pred_postnet, mel_trg: (B, T, D); lengths: (B,)
        # stop_target: (B, T); stop_pred: (B, T // frames_per_step)

        # Collapse per-frame stop targets to one target per decoder step,
        # matching the rate at which stop logits are produced.
        B = stop_target.size(0)
        stop_target = stop_target.reshape(B, -1, self.frames_per_step)[:, :, 0]
        stop_lengths = torch.ceil(lengths.float() / self.frames_per_step).long()
        stop_mask = self.get_mask(stop_lengths, mel_trg.size(1) // self.frames_per_step)

        # Targets carry no gradient.
        mel_trg = mel_trg.detach()
        # (B, T, 1) -> (B, T, D): broadcast the frame mask over mel channels.
        mel_mask = self.get_mask(lengths, mel_trg.size(1)).unsqueeze(-1)
        mel_mask = mel_mask.expand_as(mel_trg)
        # Masked mean over valid positions only, before and after the postnet.
        mel_loss_pre = (self.mel_loss_criterion(mel_pred, mel_trg) * mel_mask).sum() / mel_mask.sum()
        mel_loss_post = (self.mel_loss_criterion(mel_pred_postnet, mel_trg) * mel_mask).sum() / mel_mask.sum()
        mel_loss = mel_loss_pre + mel_loss_post

        # Stop-token loss, masked the same way at the decoder-step rate.
        stop_loss = torch.sum(self.stop_loss_criterion(stop_pred, stop_target) * stop_mask) / stop_mask.sum()

        return mel_loss, stop_loss
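

# Minimal smoke test (not part of the original file): a hedged sketch of how
# MaskedMSELoss is typically invoked. Shapes, frames_per_step, and the stop
# target convention below are illustrative assumptions, not values taken from
# the MockingBird training code.
if __name__ == "__main__":
    torch.manual_seed(0)

    B, T, D = 2, 12, 80      # batch size, padded mel frames, mel channels (assumed)
    frames_per_step = 3      # decoder reduction factor (assumed)

    criterion = MaskedMSELoss(frames_per_step)

    mel_pred = torch.randn(B, T, D)
    mel_pred_postnet = torch.randn(B, T, D)
    mel_trg = torch.randn(B, T, D)
    lengths = torch.tensor([12, 9])   # true (unpadded) frame counts per utterance

    # One stop logit per decoder step; targets set to 1 over each utterance's
    # final decoder step (assumed convention).
    stop_pred = torch.randn(B, T // frames_per_step)
    stop_target = torch.zeros(B, T)
    stop_target[0, 9:] = 1.0   # last step of a 12-frame utterance covers frames 9-11
    stop_target[1, 6:] = 1.0   # last step of a 9-frame utterance covers frames 6-8

    mel_loss, stop_loss = criterion(mel_pred, mel_pred_postnet, mel_trg,
                                    lengths, stop_target, stop_pred)
    print(f"mel_loss={mel_loss.item():.4f}, stop_loss={stop_loss.item():.4f}")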