diff --git a/mkgui/base/ui/streamlit_ui.py b/mkgui/base/ui/streamlit_ui.py
index 2e5159d..479fe1c 100644
--- a/mkgui/base/ui/streamlit_ui.py
+++ b/mkgui/base/ui/streamlit_ui.py
@@ -815,6 +815,9 @@ def getOpyrator(mode: str) -> Opyrator:
+    if mode == None or mode.startswith('模型训练(VC)'):
+        from mkgui.train_vc import train_vc
+        return Opyrator(train_vc)
     if mode == None or mode.startswith('模型训练'):
         from mkgui.train import train
         return Opyrator(train)
     from mkgui.app import synthesize
     return Opyrator(synthesize)
 
@@ -829,7 +832,7 @@ def render_streamlit_ui() -> None:
     with st.spinner("Loading MockingBird GUI. Please wait..."):
         session_state.mode = st.sidebar.selectbox(
             '模式选择',
-            ( "AI拟音", "VC拟音", "预处理", "模型训练")
+            ( "AI拟音", "VC拟音", "预处理", "模型训练", "模型训练(VC)")
         )
     if "mode" in session_state:
         mode = session_state.mode
diff --git a/mkgui/train.py b/mkgui/train.py
index 5cb3455..7104d54 100644
--- a/mkgui/train.py
+++ b/mkgui/train.py
@@ -2,66 +2,55 @@ from pydantic import BaseModel, Field
 import os
 from pathlib import Path
 from enum import Enum
-from typing import Any, Tuple
-import numpy as np
-from utils.load_yaml import HpsYaml
-from utils.util import AttrDict
-import torch
+from typing import Any
+from synthesizer.hparams import hparams
+from synthesizer.train import train as synt_train
 
-# TODO: seperator for *unix systems
 # Constants
-EXT_MODELS_DIRT = f"ppg_extractor{os.sep}saved_models"
-CONV_MODELS_DIRT = f"ppg2mel{os.sep}saved_models"
+SYN_MODELS_DIRT = f"synthesizer{os.sep}saved_models"
 ENC_MODELS_DIRT = f"encoder{os.sep}saved_models"
 
-if os.path.isdir(EXT_MODELS_DIRT):
-    extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt")))
-    print("Loaded extractor models: " + str(len(extractors)))
-else:
-    raise Exception(f"Model folder {EXT_MODELS_DIRT} doesn't exist.")
+# EXT_MODELS_DIRT = f"ppg_extractor{os.sep}saved_models"
+# CONV_MODELS_DIRT = f"ppg2mel{os.sep}saved_models"
+# ENC_MODELS_DIRT = f"encoder{os.sep}saved_models"
 
-if os.path.isdir(CONV_MODELS_DIRT):
-    convertors = Enum('convertors', list((file.name, file) for file in Path(CONV_MODELS_DIRT).glob("**/*.pth")))
-    print("Loaded convertor models: " + str(len(convertors)))
+# Pre-Load models
+if os.path.isdir(SYN_MODELS_DIRT):
+    synthesizers = Enum('synthesizers', list((file.name, file) for file in Path(SYN_MODELS_DIRT).glob("**/*.pt")))
+    print("Loaded synthesizer models: " + str(len(synthesizers)))
 else:
-    raise Exception(f"Model folder {CONV_MODELS_DIRT} doesn't exist.")
+    raise Exception(f"Model folder {SYN_MODELS_DIRT} doesn't exist.")
 
 if os.path.isdir(ENC_MODELS_DIRT):
-    encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
+    encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt")))
     print("Loaded encoders models: " + str(len(encoders)))
 else:
     raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.")
 
 class Model(str, Enum):
-    VC_PPG2MEL = "ppg2mel"
-
-class Dataset(str, Enum):
-    AIDATATANG_200ZH = "aidatatang_200zh"
-    AIDATATANG_200ZH_S = "aidatatang_200zh_s"
+    DEFAULT = "default"
 
 class Input(BaseModel):
-    # def render_input_ui(st, input) -> Dict:
-    #     input["selected_dataset"] = st.selectbox(
-    #             '选择数据集',
-    #             ("aidatatang_200zh", "aidatatang_200zh_s")
-    #         )
-    #     return input
     model: Model = Field(
-        Model.VC_PPG2MEL, title="模型类型",
+        Model.DEFAULT, title="模型类型",
     )
     # datasets_root: str = Field(
     #     ..., alias="预处理数据根目录", description="输入目录(相对/绝对),不适用于ppg2mel模型",
     #     format=True,
     #     example="..\\trainning_data\\"
     # )
-    output_root: str = Field(
-        ..., alias="输出目录(可选)", description="建议不填,保持默认",
+    input_root: str = Field(
+        ..., alias="输入目录", description="预处理数据根目录",
         format=True,
-        example=""
+        example=f"..{os.sep}audiodata{os.sep}SV2TTS{os.sep}synthesizer"
     )
-    continue_mode: bool = Field(
-        True, alias="继续训练模式", description="选择“是”,则从下面选择的模型中继续训练",
+    run_id: str = Field(
+        "", alias="新模型名/运行ID", description="使用新ID进行重新训练,否则选择下面的模型进行继续训练",
+    )
+    synthesizer: synthesizers = Field(
+        ..., alias="已有合成模型",
+        description="选择语音合成模型文件."
     )
     gpu: bool = Field(
         True, alias="GPU训练", description="选择“是”,则使用GPU训练",
@@ -69,32 +58,18 @@ class Input(BaseModel):
     verbose: bool = Field(
         True, alias="打印详情", description="选择“是”,输出更多详情",
     )
-    # TODO: Move to hiden fields by default
-    convertor: convertors = Field(
-        ..., alias="转换模型",
-        description="选择语音转换模型文件."
-    )
-    extractor: extractors = Field(
-        ..., alias="特征提取模型",
-        description="选择PPG特征提取模型文件."
-    )
     encoder: encoders = Field(
         ..., alias="语音编码模型",
         description="选择语音编码模型文件."
     )
-    njobs: int = Field(
-        8, alias="进程数", description="适用于ppg2mel",
+    save_every: int = Field(
+        1000, alias="更新间隔", description="每隔n步则更新一次模型",
     )
-    seed: int = Field(
-        default=0, alias="初始随机数", description="适用于ppg2mel",
+    backup_every: int = Field(
+        10000, alias="保存间隔", description="每隔n步则保存一次模型",
     )
-    model_name: str = Field(
-        ..., alias="新模型名", description="仅在重新训练时生效,选中继续训练时无效",
-        example="test"
-    )
-    model_config: str = Field(
-        ..., alias="新模型配置", description="仅在重新训练时生效,选中继续训练时无效",
-        example=".\\ppg2mel\\saved_models\\seq2seq_mol_ppg2mel_vctk_libri_oneshotvc_r4_normMel_v2"
+    log_every: int = Field(
+        500, alias="打印间隔", description="每隔n步则打印一次训练统计",
     )
 
 class AudioEntity(BaseModel):
@@ -102,55 +77,30 @@
     content: bytes
     mel: Any
 
 class Output(BaseModel):
-    __root__: Tuple[str, int]
+    __root__: int
 
-    def render_output_ui(self, streamlit_app, input) -> None:  # type: ignore
+    def render_output_ui(self, streamlit_app) -> None:  # type: ignore
         """Custom output UI.
         If this method is implmeneted, it will be used instead of the default Output UI renderer.
         """
-        sr, count = self.__root__
-        streamlit_app.subheader(f"Dataset {sr} done processed total of {count}")
+        streamlit_app.subheader(f"Training started with code: {self.__root__}")
 
 def train(input: Input) -> Output:
     """Train(训练)"""
-
-    print(">>> OneShot VC training ...")
-    params = AttrDict()
-    params.update({
-        "gpu": input.gpu,
-        "cpu": not input.gpu,
-        "njobs": input.njobs,
-        "seed": input.seed,
-        "verbose": input.verbose,
-        "load": input.convertor.value,
-        "warm_start": False,
-    })
-    if input.continue_mode:
-        # trace old model and config
-        p = Path(input.convertor.value)
-        params.name = p.parent.name
-        # search a config file
-        model_config_fpaths = list(p.parent.rglob("*.yaml"))
-        if len(model_config_fpaths) == 0:
-            raise "No model yaml config found for convertor"
-        config = HpsYaml(model_config_fpaths[0])
-        params.ckpdir = p.parent.parent
-        params.config = model_config_fpaths[0]
-        params.logdir = os.path.join(p.parent, "log")
-    else:
-        # Make the config dict dot visitable
-        config = HpsYaml(input.config)
-    np.random.seed(input.seed)
-    torch.manual_seed(input.seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed_all(input.seed)
-    mode = "train"
-    from ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver
-    solver = Solver(config, params, mode)
-    solver.load_data()
-    solver.set_model()
-    solver.exec()
-    print(">>> Oneshot VC train finished!")
-
-    # TODO: pass useful return code
-    return Output(__root__=(input.dataset, 0))
\ No newline at end of file
+    print(">>> Start training ...")
+    force_restart = len(input.run_id) > 0
+    if not force_restart:
+        input.run_id = Path(input.synthesizer.value).name.split('.')[0]
+
+    synt_train(
+        input.run_id,
+        input.input_root,
+        f"synthesizer{os.sep}saved_models",
+        input.save_every,
+        input.backup_every,
+        input.log_every,
+        force_restart,
+        hparams
+    )
+    return Output(__root__=0)
\ No newline at end of file
""" - sr, count = self.__root__ - streamlit_app.subheader(f"Dataset {sr} done processed total of {count}") + streamlit_app.subheader(f"Training started with code: {self.__root__}") def train(input: Input) -> Output: """Train(训练)""" - print(">>> OneShot VC training ...") - params = AttrDict() - params.update({ - "gpu": input.gpu, - "cpu": not input.gpu, - "njobs": input.njobs, - "seed": input.seed, - "verbose": input.verbose, - "load": input.convertor.value, - "warm_start": False, - }) - if input.continue_mode: - # trace old model and config - p = Path(input.convertor.value) - params.name = p.parent.name - # search a config file - model_config_fpaths = list(p.parent.rglob("*.yaml")) - if len(model_config_fpaths) == 0: - raise "No model yaml config found for convertor" - config = HpsYaml(model_config_fpaths[0]) - params.ckpdir = p.parent.parent - params.config = model_config_fpaths[0] - params.logdir = os.path.join(p.parent, "log") - else: - # Make the config dict dot visitable - config = HpsYaml(input.config) - np.random.seed(input.seed) - torch.manual_seed(input.seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(input.seed) - mode = "train" - from ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver - solver = Solver(config, params, mode) - solver.load_data() - solver.set_model() - solver.exec() - print(">>> Oneshot VC train finished!") - - # TODO: pass useful return code - return Output(__root__=(input.dataset, 0)) \ No newline at end of file + print(">>> Start training ...") + force_restart = len(input.run_id) > 0 + if not force_restart: + input.run_id = Path(input.synthesizer.value).name.split('.')[0] + + synt_train( + input.run_id, + input.input_root, + f"synthesizer{os.sep}saved_models", + input.save_every, + input.backup_every, + input.log_every, + force_restart, + hparams + ) + return Output(__root__=0) \ No newline at end of file diff --git a/mkgui/train_vc.py b/mkgui/train_vc.py new file mode 100644 index 0000000..8c23372 --- /dev/null +++ b/mkgui/train_vc.py @@ -0,0 +1,155 @@ +from pydantic import BaseModel, Field +import os +from pathlib import Path +from enum import Enum +from typing import Any, Tuple +import numpy as np +from utils.load_yaml import HpsYaml +from utils.util import AttrDict +import torch + +# Constants +EXT_MODELS_DIRT = f"ppg_extractor{os.sep}saved_models" +CONV_MODELS_DIRT = f"ppg2mel{os.sep}saved_models" +ENC_MODELS_DIRT = f"encoder{os.sep}saved_models" + + +if os.path.isdir(EXT_MODELS_DIRT): + extractors = Enum('extractors', list((file.name, file) for file in Path(EXT_MODELS_DIRT).glob("**/*.pt"))) + print("Loaded extractor models: " + str(len(extractors))) +else: + raise Exception(f"Model folder {EXT_MODELS_DIRT} doesn't exist.") + +if os.path.isdir(CONV_MODELS_DIRT): + convertors = Enum('convertors', list((file.name, file) for file in Path(CONV_MODELS_DIRT).glob("**/*.pth"))) + print("Loaded convertor models: " + str(len(convertors))) +else: + raise Exception(f"Model folder {CONV_MODELS_DIRT} doesn't exist.") + +if os.path.isdir(ENC_MODELS_DIRT): + encoders = Enum('encoders', list((file.name, file) for file in Path(ENC_MODELS_DIRT).glob("**/*.pt"))) + print("Loaded encoders models: " + str(len(encoders))) +else: + raise Exception(f"Model folder {ENC_MODELS_DIRT} doesn't exist.") + +class Model(str, Enum): + VC_PPG2MEL = "ppg2mel" + +class Dataset(str, Enum): + AIDATATANG_200ZH = "aidatatang_200zh" + AIDATATANG_200ZH_S = "aidatatang_200zh_s" + +class Input(BaseModel): + # def render_input_ui(st, input) -> Dict: + # 
input["selected_dataset"] = st.selectbox( + # '选择数据集', + # ("aidatatang_200zh", "aidatatang_200zh_s") + # ) + # return input + model: Model = Field( + Model.VC_PPG2MEL, title="模型类型", + ) + # datasets_root: str = Field( + # ..., alias="预处理数据根目录", description="输入目录(相对/绝对),不适用于ppg2mel模型", + # format=True, + # example="..\\trainning_data\\" + # ) + output_root: str = Field( + ..., alias="输出目录(可选)", description="建议不填,保持默认", + format=True, + example="" + ) + continue_mode: bool = Field( + True, alias="继续训练模式", description="选择“是”,则从下面选择的模型中继续训练", + ) + gpu: bool = Field( + True, alias="GPU训练", description="选择“是”,则使用GPU训练", + ) + verbose: bool = Field( + True, alias="打印详情", description="选择“是”,输出更多详情", + ) + # TODO: Move to hiden fields by default + convertor: convertors = Field( + ..., alias="转换模型", + description="选择语音转换模型文件." + ) + extractor: extractors = Field( + ..., alias="特征提取模型", + description="选择PPG特征提取模型文件." + ) + encoder: encoders = Field( + ..., alias="语音编码模型", + description="选择语音编码模型文件." + ) + njobs: int = Field( + 8, alias="进程数", description="适用于ppg2mel", + ) + seed: int = Field( + default=0, alias="初始随机数", description="适用于ppg2mel", + ) + model_name: str = Field( + ..., alias="新模型名", description="仅在重新训练时生效,选中继续训练时无效", + example="test" + ) + model_config: str = Field( + ..., alias="新模型配置", description="仅在重新训练时生效,选中继续训练时无效", + example=".\\ppg2mel\\saved_models\\seq2seq_mol_ppg2mel_vctk_libri_oneshotvc_r4_normMel_v2" + ) + +class AudioEntity(BaseModel): + content: bytes + mel: Any + +class Output(BaseModel): + __root__: Tuple[str, int] + + def render_output_ui(self, streamlit_app, input) -> None: # type: ignore + """Custom output UI. + If this method is implmeneted, it will be used instead of the default Output UI renderer. + """ + sr, count = self.__root__ + streamlit_app.subheader(f"Dataset {sr} done processed total of {count}") + +def train_vc(input: Input) -> Output: + """Train VC(训练 VC)""" + + print(">>> OneShot VC training ...") + params = AttrDict() + params.update({ + "gpu": input.gpu, + "cpu": not input.gpu, + "njobs": input.njobs, + "seed": input.seed, + "verbose": input.verbose, + "load": input.convertor.value, + "warm_start": False, + }) + if input.continue_mode: + # trace old model and config + p = Path(input.convertor.value) + params.name = p.parent.name + # search a config file + model_config_fpaths = list(p.parent.rglob("*.yaml")) + if len(model_config_fpaths) == 0: + raise "No model yaml config found for convertor" + config = HpsYaml(model_config_fpaths[0]) + params.ckpdir = p.parent.parent + params.config = model_config_fpaths[0] + params.logdir = os.path.join(p.parent, "log") + else: + # Make the config dict dot visitable + config = HpsYaml(input.config) + np.random.seed(input.seed) + torch.manual_seed(input.seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(input.seed) + mode = "train" + from ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver + solver = Solver(config, params, mode) + solver.load_data() + solver.set_model() + solver.exec() + print(">>> Oneshot VC train finished!") + + # TODO: pass useful return code + return Output(__root__=(input.dataset, 0)) \ No newline at end of file diff --git a/synthesizer/models/tacotron.py b/synthesizer/models/tacotron.py index 1fdc064..9cfabf7 100644 --- a/synthesizer/models/tacotron.py +++ b/synthesizer/models/tacotron.py @@ -1,4 +1,5 @@ import os +from matplotlib.pyplot import step import numpy as np import torch import torch.nn as nn @@ -297,7 +298,7 @@ class Decoder(nn.Module): x = 
diff --git a/synthesizer/models/tacotron.py b/synthesizer/models/tacotron.py
index 1fdc064..9cfabf7 100644
--- a/synthesizer/models/tacotron.py
+++ b/synthesizer/models/tacotron.py
@@ -1,4 +1,5 @@
 import os
+from matplotlib.pyplot import step
 import numpy as np
 import torch
 import torch.nn as nn
@@ -297,7 +298,7 @@ class Decoder(nn.Module):
         x = torch.cat([context_vec, attn_hidden], dim=1)
         x = self.rnn_input(x)
 
-        # Compute first Residual RNN
+        # Compute first Residual RNN, training with fixed zoneout rate 0.1
         rnn1_hidden_next, rnn1_cell = self.res_rnn1(x, (rnn1_hidden, rnn1_cell))
         if self.training:
             rnn1_hidden = self.zoneout(rnn1_hidden, rnn1_hidden_next, device=device)
@@ -372,11 +373,15 @@ class Tacotron(nn.Module):
         outputs = torch.cat([outputs, speaker_embeddings_], dim=-1)
         return outputs
 
-    def forward(self, texts, mels, speaker_embedding):
+    def forward(self, texts, mels, speaker_embedding=None, steps=2000, style_idx=0, min_stop_token=5):
+        device = texts.device  # use same device as parameters
 
-        self.step += 1
-        batch_size, _, steps = mels.size()
+        if self.training:
+            self.step += 1
+            batch_size, _, steps = mels.size()
+        else:
+            batch_size, _ = texts.size()
 
         # Initialise all hidden states and pack into tuple
         attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
@@ -401,11 +406,22 @@
         # SV2TTS: Run the encoder with the speaker embedding
         # The projection avoids unnecessary matmuls in the decoder loop
         encoder_seq = self.encoder(texts, speaker_embedding)
-        # put after encoder
+
         if hparams.use_gst and self.gst is not None:
-            style_embed = self.gst(speaker_embedding, speaker_embedding) # for training, speaker embedding can represent both style inputs and referenced
-            # style_embed = style_embed.expand_as(encoder_seq)
-            # encoder_seq = torch.cat((encoder_seq, style_embed), 2)
+            if self.training:
+                style_embed = self.gst(speaker_embedding, speaker_embedding) # for training, speaker embedding can represent both style inputs and referenced
+                # style_embed = style_embed.expand_as(encoder_seq)
+                # encoder_seq = torch.cat((encoder_seq, style_embed), 2)
+            elif style_idx >= 0 and style_idx < 10:
+                query = torch.zeros(1, 1, self.gst.stl.attention.num_units)
+                if device.type == 'cuda':
+                    query = query.cuda()
+                gst_embed = torch.tanh(self.gst.stl.embed)
+                key = gst_embed[style_idx].unsqueeze(0).expand(1, -1, -1)
+                style_embed = self.gst.stl.attention(query, key)
+            else:
+                speaker_embedding_style = torch.zeros(speaker_embedding.size()[0], 1, self.speaker_embedding_size).to(device)
+                style_embed = self.gst(speaker_embedding_style, speaker_embedding)
             encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed)
 
         encoder_seq_proj = self.encoder_proj(encoder_seq)
@@ -414,13 +430,17 @@
 
         # Run the decoder loop
         for t in range(0, steps, self.r):
-            prenet_in = mels[:, :, t - 1] if t > 0 else go_frame
+            if self.training:
+                prenet_in = mels[:, :, t - 1] if t > 0 else go_frame
+            else:
+                prenet_in = mel_outputs[-1][:, :, -1] if t > 0 else go_frame
             mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
                 self.decoder(encoder_seq, encoder_seq_proj, prenet_in, hidden_states,
                              cell_states, context_vec, t, texts)
             mel_outputs.append(mel_frames)
             attn_scores.append(scores)
             stop_outputs.extend([stop_tokens] * self.r)
+            if not self.training and (stop_tokens * 10 > min_stop_token).all() and t > 10: break
 
         # Concat the mel outputs into sequence
         mel_outputs = torch.cat(mel_outputs, dim=2)
@@ -435,87 +455,93 @@
         # attn_scores = attn_scores.cpu().data.numpy()
         stop_outputs = torch.cat(stop_outputs, 1)
 
+
+        if self.training:
+            self.train()
+
         return mel_outputs, linear, attn_scores, stop_outputs
 
     def generate(self, x, speaker_embedding=None, steps=2000, style_idx=0, min_stop_token=5):
         self.eval()
-        device = x.device  # use same device as parameters
-
-        batch_size, _ = x.size()
-
-        # Need to initialise all hidden states and pack into tuple for tidyness
-        attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
-        rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
-        rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
-        hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
-
-        # Need to initialise all lstm cell states and pack into tuple for tidyness
-        rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
-        rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
-        cell_states = (rnn1_cell, rnn2_cell)
-
-        # Need a Frame for start of decoder loop
-        go_frame = torch.zeros(batch_size, self.n_mels, device=device)
-
-        # Need an initial context vector
-        size = self.encoder_dims + self.speaker_embedding_size
-        if hparams.use_gst:
-            size += gst_hp.E
-        context_vec = torch.zeros(batch_size, size, device=device)
-
-        # SV2TTS: Run the encoder with the speaker embedding
-        # The projection avoids unnecessary matmuls in the decoder loop
-        encoder_seq = self.encoder(x, speaker_embedding)
-
-        # put after encoder
-        if hparams.use_gst and self.gst is not None:
-            if style_idx >= 0 and style_idx < 10:
-                query = torch.zeros(1, 1, self.gst.stl.attention.num_units)
-                if device.type == 'cuda':
-                    query = query.cuda()
-                gst_embed = torch.tanh(self.gst.stl.embed)
-                key = gst_embed[style_idx].unsqueeze(0).expand(1, -1, -1)
-                style_embed = self.gst.stl.attention(query, key)
-            else:
-                speaker_embedding_style = torch.zeros(speaker_embedding.size()[0], 1, self.speaker_embedding_size).to(device)
-                style_embed = self.gst(speaker_embedding_style, speaker_embedding)
-            encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed)
-            # style_embed = style_embed.expand_as(encoder_seq)
-            # encoder_seq = torch.cat((encoder_seq, style_embed), 2)
-        encoder_seq_proj = self.encoder_proj(encoder_seq)
-
-        # Need a couple of lists for outputs
-        mel_outputs, attn_scores, stop_outputs = [], [], []
-
-        # Run the decoder loop
-        for t in range(0, steps, self.r):
-            prenet_in = mel_outputs[-1][:, :, -1] if t > 0 else go_frame
-            mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
-                self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
-                             hidden_states, cell_states, context_vec, t, x)
-            mel_outputs.append(mel_frames)
-            attn_scores.append(scores)
-            stop_outputs.extend([stop_tokens] * self.r)
-            # Stop the loop when all stop tokens in batch exceed threshold
-            if (stop_tokens * 10 > min_stop_token).all() and t > 10: break
-
-        # Concat the mel outputs into sequence
-        mel_outputs = torch.cat(mel_outputs, dim=2)
-
-        # Post-Process for Linear Spectrograms
-        postnet_out = self.postnet(mel_outputs)
-        linear = self.post_proj(postnet_out)
-
-
-        linear = linear.transpose(1, 2)
-
-        # For easy visualisation
-        attn_scores = torch.cat(attn_scores, 1)
-        stop_outputs = torch.cat(stop_outputs, 1)
-
-        self.train()
-
+        mel_outputs, linear, attn_scores, _ = self.forward(x, None, speaker_embedding, steps, style_idx, min_stop_token)
         return mel_outputs, linear, attn_scores
+        # device = x.device  # use same device as parameters
+
+        # batch_size, _ = x.size()
+
+        # # Need to initialise all hidden states and pack into tuple for tidyness
+        # attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
+        # rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
+        # rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
+        # hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
+
+        # # Need to initialise all lstm cell states and pack into tuple for tidyness
+        # rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
+        # rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
+        # cell_states = (rnn1_cell, rnn2_cell)
+
+        # # Need a Frame for start of decoder loop
+        # go_frame = torch.zeros(batch_size, self.n_mels, device=device)
+
+        # # Need an initial context vector
+        # size = self.encoder_dims + self.speaker_embedding_size
+        # if hparams.use_gst:
+        #     size += gst_hp.E
+        # context_vec = torch.zeros(batch_size, size, device=device)
+
+        # # SV2TTS: Run the encoder with the speaker embedding
+        # # The projection avoids unnecessary matmuls in the decoder loop
+        # encoder_seq = self.encoder(x, speaker_embedding)
+
+        # # put after encoder
+        # if hparams.use_gst and self.gst is not None:
+        #     if style_idx >= 0 and style_idx < 10:
+        #         query = torch.zeros(1, 1, self.gst.stl.attention.num_units)
+        #         if device.type == 'cuda':
+        #             query = query.cuda()
+        #         gst_embed = torch.tanh(self.gst.stl.embed)
+        #         key = gst_embed[style_idx].unsqueeze(0).expand(1, -1, -1)
+        #         style_embed = self.gst.stl.attention(query, key)
+        #     else:
+        #         speaker_embedding_style = torch.zeros(speaker_embedding.size()[0], 1, self.speaker_embedding_size).to(device)
+        #         style_embed = self.gst(speaker_embedding_style, speaker_embedding)
+        #     encoder_seq = self._concat_speaker_embedding(encoder_seq, style_embed)
+        #     # style_embed = style_embed.expand_as(encoder_seq)
+        #     # encoder_seq = torch.cat((encoder_seq, style_embed), 2)
+        # encoder_seq_proj = self.encoder_proj(encoder_seq)
+
+        # # Need a couple of lists for outputs
+        # mel_outputs, attn_scores, stop_outputs = [], [], []
+
+        # # Run the decoder loop
+        # for t in range(0, steps, self.r):
+        #     prenet_in = mel_outputs[-1][:, :, -1] if t > 0 else go_frame
+        #     mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
+        #         self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
+        #                      hidden_states, cell_states, context_vec, t, x)
+        #     mel_outputs.append(mel_frames)
+        #     attn_scores.append(scores)
+        #     stop_outputs.extend([stop_tokens] * self.r)
+        #     # Stop the loop when all stop tokens in batch exceed threshold
+        #     if (stop_tokens * 10 > min_stop_token).all() and t > 10: break
+
+        # # Concat the mel outputs into sequence
+        # mel_outputs = torch.cat(mel_outputs, dim=2)
+
+        # # Post-Process for Linear Spectrograms
+        # postnet_out = self.postnet(mel_outputs)
+        # linear = self.post_proj(postnet_out)
+
+
+        # linear = linear.transpose(1, 2)
+
+        # # For easy visualisation
+        # attn_scores = torch.cat(attn_scores, 1)
+        # stop_outputs = torch.cat(stop_outputs, 1)
+
+        # self.train()
+
+        # return mel_outputs, linear, attn_scores
 
     def init_model(self):
         for p in self.parameters():
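Note: with generate() now delegating to forward(), inference runs entirely through the eval-mode branches added above. A minimal usage sketch (the model instance and input tensors are assumed, not defined in this patch):

```python
import torch

# model: a constructed Tacotron; texts: LongTensor [batch, n_chars];
# speaker_embedding: FloatTensor [batch, speaker_embedding_size]
model.eval()
with torch.no_grad():
    mels, linear, attn_scores = model.generate(
        texts, speaker_embedding, steps=2000, style_idx=0, min_stop_token=5
    )
# Decoding stops early once every stop token in the batch exceeds
# min_stop_token / 10 (the break added to the decoder loop).
```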
diff --git a/synthesizer/train.py b/synthesizer/train.py
index 2532348..8799e84 100644
--- a/synthesizer/train.py
+++ b/synthesizer/train.py
@@ -15,9 +15,8 @@ from datetime import datetime
 import json
 import numpy as np
 from pathlib import Path
-import sys
 import time
-
+import os
 
 def np_now(x: torch.Tensor): return x.detach().cpu().numpy()
 
@@ -265,7 +264,19 @@ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
                                    loss=loss,
                                    hparams=hparams,
                                    sw=sw)
-
+                MAX_SAVED_COUNT = 20
+                if (step // hparams.tts_eval_interval) % MAX_SAVED_COUNT == 0:
+                    # clean up and keep only the last MAX_SAVED_COUNT outputs
+                    plots = sorted(next(os.walk(plot_dir), (None, None, []))[2])
+                    for plot in plots[:-MAX_SAVED_COUNT]:
+                        os.remove(plot_dir.joinpath(plot))
+                    mel_files = sorted(next(os.walk(mel_output_dir), (None, None, []))[2])
+                    for mel_file in mel_files[:-MAX_SAVED_COUNT]:
+                        os.remove(mel_output_dir.joinpath(mel_file))
+                    wavs = sorted(next(os.walk(wav_dir), (None, None, []))[2])
+                    for w in wavs[:-MAX_SAVED_COUNT]:
+                        os.remove(wav_dir.joinpath(w))
+
             # Break out of loop to update training schedule
             if step >= max_step:
                 break
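Note: the cleanup added to synthesizer/train.py is meant to cap the number of eval artifacts kept on disk. A standalone sketch of the same keep-newest-N idea, pruning by modification time (the function name and the mtime criterion are illustrative, not part of this patch):

```python
from pathlib import Path

def prune_old_files(directory: Path, keep: int = 20) -> None:
    """Delete all but the `keep` most recently modified files in `directory`."""
    files = sorted(
        (p for p in directory.iterdir() if p.is_file()),
        key=lambda p: p.stat().st_mtime,
    )
    for stale in files[:-keep]:  # everything except the newest `keep`
        stale.unlink()

# e.g. prune_old_files(plot_dir, keep=MAX_SAVED_COUNT) after each eval step
```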