mirror of
https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00
74a3fc97d0
Need readme
289 lines
11 KiB
Python
289 lines
11 KiB
Python
import os, sys
|
|
# sys.path.append('/home/shaunxliu/projects/nnsp')
|
|
import matplotlib
|
|
matplotlib.use("Agg")
|
|
import matplotlib.pyplot as plt
|
|
from matplotlib.ticker import MaxNLocator
|
|
import torch
|
|
from torch.utils.data import DataLoader
|
|
import numpy as np
|
|
from .solver import BaseSolver
|
|
from utils.data_load import OneshotVcDataset, MultiSpkVcCollate
|
|
# from src.rnn_ppg2mel import BiRnnPpg2MelModel
|
|
# from src.mel_decoder_mol_encAddlf0 import MelDecoderMOL
|
|
from .loss import MaskedMSELoss
|
|
from .optim import Optimizer
|
|
from utils.util import human_format
|
|
from models.ppg2mel import MelDecoderMOLv2
|
|
|
|
|
|
class Solver(BaseSolver):
    """Solver that trains a ``MelDecoderMOLv2`` model.

    It builds the train/dev/plot data loaders, sets up the model, the
    ``MaskedMSELoss`` criterion and the optimizer, runs the training loop,
    and periodically validates, saving checkpoints and attention plots.
    """

    def __init__(self, config, paras, mode):
        """Initialise the base solver and prepare the attention-plot directory.

        Args:
            config: Experiment configuration, forwarded to ``BaseSolver``.
            paras: Runtime parameters, forwarded to ``BaseSolver``.
            mode: Solver mode, forwarded to ``BaseSolver``.
        """
        super().__init__(config, paras, mode)
        # Number of dev utterances whose attention map is plotted per validation.
        self.num_att_plots = 5
        self.att_ws_dir = f"{self.logdir}/att_ws"
        os.makedirs(self.att_ws_dir, exist_ok=True)
        # Lowest dev loss seen so far; used to decide when to save a "best" checkpoint.
        self.best_loss = np.inf

    def fetch_data(self, data):
        """Move every element of ``data`` to ``self.device`` and return them as a list."""
        return [item.to(self.device) for item in data]

    def load_data(self):
        """Build the data loaders used for training, validation and plotting.

        Sets ``self.train_dataloader``, ``self.dev_dataloader`` and
        ``self.plot_dataloader``. The plot loader iterates the dev set one
        utterance at a time and additionally yields utterance ids.
        """
        data_cfg = self.config.data
        # Everything except the meta file is shared by the train and dev sets.
        shared_dataset_args = dict(
            vctk_ppg_dir=data_cfg.vctk_ppg_dir,
            libri_ppg_dir=data_cfg.libri_ppg_dir,
            vctk_f0_dir=data_cfg.vctk_f0_dir,
            libri_f0_dir=data_cfg.libri_f0_dir,
            vctk_wav_dir=data_cfg.vctk_wav_dir,
            libri_wav_dir=data_cfg.libri_wav_dir,
            vctk_spk_dvec_dir=data_cfg.vctk_spk_dvec_dir,
            libri_spk_dvec_dir=data_cfg.libri_spk_dvec_dir,
            ppg_file_ext=data_cfg.ppg_file_ext,
            min_max_norm_mel=data_cfg.min_max_norm_mel,
            mel_min=data_cfg.mel_min,
            mel_max=data_cfg.mel_max,
        )
        train_dataset = OneshotVcDataset(
            meta_file=data_cfg.train_fid_list, **shared_dataset_args
        )
        dev_dataset = OneshotVcDataset(
            meta_file=data_cfg.dev_fid_list, **shared_dataset_args
        )

        frames_per_step = self.config.model.frames_per_step
        batch_size = self.config.hparas.batch_size
        num_workers = self.paras.njobs

        self.train_dataloader = DataLoader(
            train_dataset,
            num_workers=num_workers,
            shuffle=True,
            batch_size=batch_size,
            pin_memory=False,
            drop_last=True,
            collate_fn=MultiSpkVcCollate(frames_per_step, use_spk_dvec=True),
        )
        self.dev_dataloader = DataLoader(
            dev_dataset,
            num_workers=num_workers,
            shuffle=False,
            batch_size=batch_size,
            pin_memory=False,
            drop_last=False,
            collate_fn=MultiSpkVcCollate(frames_per_step, use_spk_dvec=True),
        )
        self.plot_dataloader = DataLoader(
            dev_dataset,
            num_workers=num_workers,
            shuffle=False,
            batch_size=1,
            pin_memory=False,
            drop_last=False,
            collate_fn=MultiSpkVcCollate(
                frames_per_step, use_spk_dvec=True, give_uttids=True
            ),
        )
        msg = "Have prepared training set and dev set."
        self.verbose(msg)

    def load_pretrained_params(self):
        """Overwrite model weights with those of a pretrained checkpoint.

        The checkpoint path is ``self.config.data.pretrain_model_file``; its
        ``"model"`` entry is merged into the current state dict, skipping
        every key that starts with one of the ignored layer prefixes.
        """
        pretrain_model_file = self.config.data.pretrain_model_file
        print("Load pretrained model from: ", pretrain_model_file)
        # str.startswith accepts a tuple, so one pass filters all prefixes.
        ignore_layer_prefixes = ("speaker_embedding_table",)
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        pretrain_ckpt = torch.load(
            pretrain_model_file, map_location=self.device
        )["model"]
        model_dict = self.model.state_dict()
        print(self.model)

        # 1. filter out unnecessary keys
        pretrain_ckpt = {
            key: value
            for key, value in pretrain_ckpt.items()
            if not key.startswith(ignore_layer_prefixes)
        }
        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrain_ckpt)
        # 3. load the new state dict
        self.model.load_state_dict(model_dict)

    def set_model(self):
        """Set up the model, the loss criterion and the optimizer."""
        # Model
        print("[INFO] Model name: ", self.config["model_name"])
        self.model = MelDecoderMOLv2(
            **self.config["model"]
        ).to(self.device)
        # NOTE: load_pretrained_params() is intentionally not called here.

        model_params = [{'params': self.model.parameters()}]

        # Loss criterion
        self.loss_criterion = MaskedMSELoss(self.config.model.frames_per_step)

        # Optimizer
        self.optimizer = Optimizer(model_params, **self.config["hparas"])
        self.verbose(self.optimizer.create_msg())

        # Automatically load pre-trained model if self.paras.load is given
        self.load_ckpt()

    def _forward_losses(self, batch):
        """Run the model on an already-fetched batch and return (mel_loss, stop_loss)."""
        ppgs, lf0_uvs, mels, in_lengths, \
            out_lengths, spk_ids, stop_tokens = batch
        mel_outputs, mel_outputs_postnet, predicted_stop = self.model(
            ppgs,
            in_lengths,
            mels,
            out_lengths,
            lf0_uvs,
            spk_ids
        )
        mel_loss, stop_loss = self.loss_criterion(
            mel_outputs,
            mel_outputs_postnet,
            mels,
            out_lengths,
            stop_tokens,
            predicted_stop
        )
        return mel_loss, stop_loss

    def exec(self):
        """Run the training loop until ``self.step`` reaches ``self.max_step``.

        Logs progress every ``self.PROGRESS_STEP`` steps and validates at the
        first step and every ``self.valid_step`` steps afterwards.

        Raises:
            RuntimeError: If the training loader yields no batches, which
                would otherwise make the outer loop spin forever.
        """
        self.verbose("Total training steps {}.".format(
            human_format(self.max_step)))

        # The train loader drops the last partial batch, so a dataset smaller
        # than one batch produces nothing and the step counter never advances.
        if self.step < self.max_step and len(self.train_dataloader) == 0:
            raise RuntimeError(
                "Training dataloader yields no batches; "
                "check the dataset size against the batch size."
            )

        # Set as current time
        self.timer.set()

        while self.step < self.max_step:
            for data in self.train_dataloader:
                # Pre-step: update lr_rate and do zero_grad
                self.optimizer.pre_step(self.step)
                # data to device
                batch = self.fetch_data(data)
                self.timer.cnt("rd")
                mel_loss, stop_loss = self._forward_losses(batch)
                loss = mel_loss + stop_loss

                self.timer.cnt("fw")

                # Back-prop
                grad_norm = self.backward(loss)
                self.step += 1

                # Logger
                if (self.step == 1) or (self.step % self.PROGRESS_STEP == 0):
                    self.progress("Tr|loss:{:.4f},mel-loss:{:.4f},stop-loss:{:.4f}|Grad.Norm-{:.2f}|{}"
                                  .format(loss.cpu().item(), mel_loss.cpu().item(),
                                          stop_loss.cpu().item(), grad_norm, self.timer.show()))
                    self.write_log('loss', {'tr/loss': loss,
                                            'tr/mel-loss': mel_loss,
                                            'tr/stop-loss': stop_loss})

                # Validation
                if (self.step == 1) or (self.step % self.valid_step == 0):
                    self.validate()

                # End of step
                # https://github.com/pytorch/pytorch/issues/13246#issuecomment-529185354
                torch.cuda.empty_cache()
                self.timer.set()
                # Stop as soon as the budget is reached; a strict ">" here
                # would let one extra optimisation step run.
                if self.step >= self.max_step:
                    break
        self.log.close()

    def _save_attention_plot(self, att_ws, fid):
        """Render attention weights to a PNG named after ``fid`` and the current step."""
        att_ws = att_ws.squeeze(0).cpu().numpy()
        att_ws = att_ws[None]
        w, h = plt.figaspect(1.0 / len(att_ws))
        fig = plt.Figure(figsize=(w * 1.3, h * 1.3))
        axes = fig.subplots(1, len(att_ws))
        if len(att_ws) == 1:
            # fig.subplots returns a bare Axes for a 1x1 grid; make it iterable.
            axes = [axes]

        for ax, aw in zip(axes, att_ws):
            ax.imshow(aw.astype(np.float32), aspect="auto")
            ax.set_title(f"{fid}")
            ax.set_xlabel("Input")
            ax.set_ylabel("Output")
            ax.xaxis.set_major_locator(MaxNLocator(integer=True))
            ax.yaxis.set_major_locator(MaxNLocator(integer=True))
        fig_name = f"{self.att_ws_dir}/{fid}_step{self.step}.png"
        fig.savefig(fig_name)

    def validate(self):
        """Evaluate on the dev set, save checkpoints and plot attention maps.

        Averages the losses over all dev batches, always saves a
        ``step_<n>.pth`` checkpoint, saves a ``best_loss_step_<n>.pth``
        checkpoint when the dev loss improves, logs the dev losses, writes up
        to ``self.num_att_plots`` attention plots, and finally puts the model
        back into training mode.

        Raises:
            RuntimeError: If the dev loader yields no batches, since no
                average loss can be computed.
        """
        num_batches = len(self.dev_dataloader)
        if num_batches == 0:
            raise RuntimeError(
                "Validation dataloader yields no batches; "
                "cannot compute the dev loss."
            )

        self.model.eval()
        dev_loss, dev_mel_loss, dev_stop_loss = 0.0, 0.0, 0.0

        batch_idx = 0
        for batch_idx, data in enumerate(self.dev_dataloader, start=1):
            self.progress('Valid step - {}/{}'.format(batch_idx, num_batches))
            # Fetch data
            batch = self.fetch_data(data)
            with torch.no_grad():
                mel_loss, stop_loss = self._forward_losses(batch)
            loss = mel_loss + stop_loss

            dev_loss += loss.cpu().item()
            dev_mel_loss += mel_loss.cpu().item()
            dev_stop_loss += stop_loss.cpu().item()

        # Average over the number of batches actually processed.
        dev_loss = dev_loss / batch_idx
        dev_mel_loss = dev_mel_loss / batch_idx
        dev_stop_loss = dev_stop_loss / batch_idx
        self.save_checkpoint(f'step_{self.step}.pth', 'loss', dev_loss, show_msg=False)
        if dev_loss < self.best_loss:
            self.best_loss = dev_loss
            self.save_checkpoint(f'best_loss_step_{self.step}.pth', 'loss', dev_loss)
        self.write_log('loss', {'dv/loss': dev_loss,
                                'dv/mel-loss': dev_mel_loss,
                                'dv/stop-loss': dev_stop_loss})

        # plot attention
        for plot_idx, data in enumerate(self.plot_dataloader):
            if plot_idx == self.num_att_plots:
                break
            # Fetch data; the last element holds the utterance ids, not a tensor.
            ppgs, lf0_uvs, mels, in_lengths, \
                out_lengths, spk_ids, stop_tokens = self.fetch_data(data[:-1])
            fid = data[-1][0]
            with torch.no_grad():
                _, _, _, att_ws = self.model(
                    ppgs,
                    in_lengths,
                    mels,
                    out_lengths,
                    lf0_uvs,
                    spk_ids,
                    output_att_ws=True
                )
            self._save_attention_plot(att_ws, fid)

        # Resume training
        self.model.train()
|
|
|