diff --git a/.vscode/launch.json b/.vscode/launch.json
index 4cc1daf..b2ab7f8 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -15,7 +15,8 @@
             "name": "Python: Vocoder Preprocess",
             "type": "python",
             "request": "launch",
-            "program": "vocoder_preprocess.py",
+            "program": "control\\cli\\vocoder_preprocess.py",
+            "cwd": "${workspaceFolder}",
             "console": "integratedTerminal",
             "args": ["..\\audiodata"]
         },
@@ -23,7 +24,8 @@
             "name": "Python: Vocoder Train",
             "type": "python",
             "request": "launch",
-            "program": "vocoder_train.py",
+            "program": "control\\cli\\vocoder_train.py",
+            "cwd": "${workspaceFolder}",
             "console": "integratedTerminal",
             "args": ["dev", "..\\audiodata"]
         },
@@ -32,6 +34,7 @@
             "type": "python",
             "request": "launch",
             "program": "demo_toolbox.py",
+            "cwd": "${workspaceFolder}",
             "console": "integratedTerminal",
             "args": ["-d","..\\audiodata"]
         },
@@ -40,6 +43,7 @@
             "type": "python",
             "request": "launch",
             "program": "demo_toolbox.py",
+            "cwd": "${workspaceFolder}",
             "console": "integratedTerminal",
             "args": ["-d","..\\audiodata","-vc"]
         },
@@ -47,9 +51,9 @@
             "name": "Python: Synth Train",
             "type": "python",
             "request": "launch",
-            "program": "synthesizer_train.py",
+            "program": "train.py",
             "console": "integratedTerminal",
-            "args": ["my_run", "..\\"]
+            "args": ["--type", "synth", "..\\audiodata\\SV2TTS\\synthesizer"]
         },
         {
             "name": "Python: PPG Convert",
diff --git a/control/cli/__init__.py b/control/cli/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/control/cli/synthesizer_preprocess_embeds.py b/control/cli/synthesizer_preprocess_embeds.py
index 1fea8bc..38c9e2f 100644
--- a/control/cli/synthesizer_preprocess_embeds.py
+++ b/control/cli/synthesizer_preprocess_embeds.py
@@ -14,7 +14,7 @@ if __name__ == "__main__":
         "Path to the synthesizer training data that contains the audios and the train.txt file. "
         "If you let everything as default, it should be /SV2TTS/synthesizer/.")
     parser.add_argument("-e", "--encoder_model_fpath", type=Path,
-                        default="encoder/saved_models/pretrained.pt", help=\
+                        default="data/ckpt/encoder/pretrained.pt", help=\
         "Path your trained encoder model.")
     parser.add_argument("-n", "--n_processes", type=int, default=4, help= \
         "Number of parallel processes. An encoder is created for each, so you may need to lower "
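
Throughout this patch the default checkpoint locations move from <module>/saved_models/ to data/ckpt/<module>/. A minimal pre-flight check for the relocated layout, assuming the default encoder checkpoint name used above (other checkpoint names would be analogous):

    from pathlib import Path

    # Hypothetical sanity check before preprocessing; adjust to the checkpoints you actually have.
    expected = [Path("data/ckpt/encoder/pretrained.pt")]
    missing = [p for p in expected if not p.exists()]
    if missing:
        raise FileNotFoundError(f"Missing checkpoints under data/ckpt/: {missing}")
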
diff --git a/control/cli/synthesizer_train.py b/control/cli/synthesizer_train.py
index 8fb41ae..423cc3b 100644
--- a/control/cli/synthesizer_train.py
+++ b/control/cli/synthesizer_train.py
@@ -3,8 +3,7 @@ from models.synthesizer.train import train
 from utils.argutils import print_args
 import argparse
 
-
-if __name__ == "__main__":
+def new_train():
     parser = argparse.ArgumentParser()
     parser.add_argument("run_id", type=str, help= \
         "Name for this model instance. If a model state from the same run ID was previously "
@@ -13,7 +12,7 @@
     parser.add_argument("syn_dir", type=str, default=argparse.SUPPRESS, help= \
         "Path to the synthesizer directory that contains the ground truth mel spectrograms, "
         "the wavs and the embeds.")
-    parser.add_argument("-m", "--models_dir", type=str, default="synthesizer/saved_models/", help=\
+    parser.add_argument("-m", "--models_dir", type=str, default=f"data/ckpt/synthesizer/", help=\
         "Path to the output directory that will contain the saved model weights and the logs.")
     parser.add_argument("-s", "--save_every", type=int, default=1000, help= \
         "Number of steps between updates of the model on the disk. Set to 0 to never save the "
@@ -28,10 +27,14 @@
     parser.add_argument("--hparams", default="",
                         help="Hyperparameter overrides as a comma-separated list of name=value "
                              "pairs")
-    args = parser.parse_args()
+    args, _ = parser.parse_known_args()
     print_args(args, parser)
     args.hparams = hparams.parse(args.hparams)
 
     # Run the training
     train(**vars(args))
+
+
+if __name__ == "__main__":
+    new_train()
\ No newline at end of file
diff --git a/control/cli/train_ppg2mel.py b/control/cli/train_ppg2mel.py
new file mode 100644
index 0000000..0a94e84
--- /dev/null
+++ b/control/cli/train_ppg2mel.py
@@ -0,0 +1,66 @@
+import sys
+import torch
+import argparse
+import numpy as np
+from utils.load_yaml import HpsYaml
+from models.ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver
+
+# For reproducibility, comment these may speed up training
+torch.backends.cudnn.deterministic = True
+torch.backends.cudnn.benchmark = False
+
+def main():
+    # Arguments
+    parser = argparse.ArgumentParser(description=
+        'Training PPG2Mel VC model.')
+    parser.add_argument('--config', type=str,
+                        help='Path to experiment config, e.g., config/vc.yaml')
+    parser.add_argument('--name', default=None, type=str, help='Name for logging.')
+    parser.add_argument('--logdir', default='log/', type=str,
+                        help='Logging path.', required=False)
+    parser.add_argument('--ckpdir', default='ppg2mel/saved_models/', type=str,
+                        help='Checkpoint path.', required=False)
+    parser.add_argument('--outdir', default='result/', type=str,
+                        help='Decode output path.', required=False)
+    parser.add_argument('--load', default=None, type=str,
+                        help='Load pre-trained model (for training only)', required=False)
+    parser.add_argument('--warm_start', action='store_true',
+                        help='Load model weights only, ignore specified layers.')
+    parser.add_argument('--seed', default=0, type=int,
+                        help='Random seed for reproducable results.', required=False)
+    parser.add_argument('--njobs', default=8, type=int,
+                        help='Number of threads for dataloader/decoding.', required=False)
+    parser.add_argument('--cpu', action='store_true', help='Disable GPU training.')
+    parser.add_argument('--no-pin', action='store_true',
+                        help='Disable pin-memory for dataloader')
+    parser.add_argument('--test', action='store_true', help='Test the model.')
+    parser.add_argument('--no-msg', action='store_true', help='Hide all messages.')
+    parser.add_argument('--finetune', action='store_true', help='Finetune model')
+    parser.add_argument('--oneshotvc', action='store_true', help='Oneshot VC model')
+    parser.add_argument('--bilstm', action='store_true', help='BiLSTM VC model')
+    parser.add_argument('--lsa', action='store_true', help='Use location-sensitive attention (LSA)')
+
+    ###
+    paras = parser.parse_args()
+    setattr(paras, 'gpu', not paras.cpu)
+    setattr(paras, 'pin_memory', not paras.no_pin)
+    setattr(paras, 'verbose', not paras.no_msg)
+    # Make the config dict dot visitable
+    config = HpsYaml(paras.config)
+
+    np.random.seed(paras.seed)
+    torch.manual_seed(paras.seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(paras.seed)
+
+    print(">>> OneShot VC training ...")
+    mode = "train"
+    solver = Solver(config, paras, mode)
+    solver.load_data()
+    solver.set_model()
+    solver.exec()
+    print(">>> Oneshot VC train finished!")
+    sys.exit(0)
+
+if __name__ == "__main__":
+    main()
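
control/cli/train_ppg2mel.py keeps the original PPG2Mel training entry point (the code train.py used to contain). Its comment "Make the config dict dot visitable" refers to HpsYaml turning the YAML config into attribute-style access; a rough sketch of that idea with a simplified stand-in class (the real implementation lives in utils/load_yaml.py and may differ):

    import yaml

    # Simplified stand-in for HpsYaml: nested YAML keys become attribute lookups.
    class DotDict(dict):
        def __getattr__(self, name):
            value = self[name]
            return DotDict(value) if isinstance(value, dict) else value

    def load_hparams(path):
        with open(path, "r", encoding="utf-8") as f:
            return DotDict(yaml.safe_load(f))

    # config = load_hparams("config/vc.yaml")
    # config.data, config.model, ...   (keys shown are illustrative)
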
Exception(f"Model folder {VOC_MODELS_DIRT} doesn't exist.") - class Input(BaseModel): message: str = Field( ..., example="欢迎使用工具箱, 现已支持中文输入!", alias="文本内容" ) local_audio_file: audio_input_selection = Field( - ..., alias="输入语音(本地wav)", + ..., alias="选择语音(本地wav)", description="选择本地语音文件." ) + record_audio_file: FileContent = Field(default=None, alias="录制语音", + description="录音.", is_recorder=True, mime_type="audio/wav") upload_audio_file: FileContent = Field(default=None, alias="或上传语音", description="拖拽或点击上传.", mime_type="audio/wav") encoder: encoders = Field( @@ -104,7 +105,12 @@ def synthesize(input: Input) -> Output: gan_vocoder.load_model(Path(input.vocoder.value)) # load file - if input.upload_audio_file != None: + if input.record_audio_file != None: + with open(TEMP_SOURCE_AUDIO, "w+b") as f: + f.write(input.record_audio_file.as_bytes()) + f.seek(0) + wav, sample_rate = librosa.load(TEMP_SOURCE_AUDIO) + elif input.upload_audio_file != None: with open(TEMP_SOURCE_AUDIO, "w+b") as f: f.write(input.upload_audio_file.as_bytes()) f.seek(0) diff --git a/control/mkgui/base/ui/schema_utils.py b/control/mkgui/base/ui/schema_utils.py index a2be43c..a0ccf20 100644 --- a/control/mkgui/base/ui/schema_utils.py +++ b/control/mkgui/base/ui/schema_utils.py @@ -37,6 +37,12 @@ def is_single_file_property(property: Dict) -> bool: # TODO: binary? return property.get("format") == "byte" +def is_single_autio_property(property: Dict) -> bool: + if property.get("type") != "string": + return False + # TODO: binary? + return property.get("format") == "bytes" + def is_single_directory_property(property: Dict) -> bool: if property.get("type") != "string": diff --git a/control/mkgui/base/ui/streamlit_ui.py b/control/mkgui/base/ui/streamlit_ui.py index b3a18b4..60151fa 100644 --- a/control/mkgui/base/ui/streamlit_ui.py +++ b/control/mkgui/base/ui/streamlit_ui.py @@ -242,7 +242,14 @@ class InputUI: file_extension = None if "mime_type" in property: file_extension = mimetypes.guess_extension(property["mime_type"]) - + + if "is_recorder" in property: + from audio_recorder_streamlit import audio_recorder + audio_bytes = audio_recorder() + if audio_bytes: + streamlit_app.audio(audio_bytes, format="audio/wav") + return audio_bytes + uploaded_file = streamlit_app.file_uploader( **streamlit_kwargs, accept_multiple_files=False, type=file_extension ) @@ -262,6 +269,39 @@ class InputUI: streamlit_app.video(bytes, format=property.get("mime_type")) return bytes + def _render_single_audio_input( + self, streamlit_app: st, key: str, property: Dict + ) -> Any: + # streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property) + from audio_recorder_streamlit import audio_recorder + audio_bytes = audio_recorder() + if audio_bytes: + streamlit_app.audio(audio_bytes, format="audio/wav") + return audio_bytes + + # file_extension = None + # if "mime_type" in property: + # file_extension = mimetypes.guess_extension(property["mime_type"]) + + # uploaded_file = streamlit_app.file_uploader( + # **streamlit_kwargs, accept_multiple_files=False, type=file_extension + # ) + # if uploaded_file is None: + # return None + + # bytes = uploaded_file.getvalue() + # if property.get("mime_type"): + # if is_compatible_audio(property["mime_type"]): + # # Show audio + # streamlit_app.audio(bytes, format=property.get("mime_type")) + # if is_compatible_image(property["mime_type"]): + # # Show image + # streamlit_app.image(bytes) + # if is_compatible_video(property["mime_type"]): + # # Show video + # streamlit_app.video(bytes, 
diff --git a/control/mkgui/base/ui/streamlit_ui.py b/control/mkgui/base/ui/streamlit_ui.py
index b3a18b4..60151fa 100644
--- a/control/mkgui/base/ui/streamlit_ui.py
+++ b/control/mkgui/base/ui/streamlit_ui.py
@@ -242,7 +242,14 @@ class InputUI:
         file_extension = None
         if "mime_type" in property:
             file_extension = mimetypes.guess_extension(property["mime_type"])
-
+
+        if "is_recorder" in property:
+            from audio_recorder_streamlit import audio_recorder
+            audio_bytes = audio_recorder()
+            if audio_bytes:
+                streamlit_app.audio(audio_bytes, format="audio/wav")
+            return audio_bytes
+
         uploaded_file = streamlit_app.file_uploader(
             **streamlit_kwargs, accept_multiple_files=False, type=file_extension
         )
@@ -262,6 +269,39 @@ class InputUI:
                 streamlit_app.video(bytes, format=property.get("mime_type"))
         return bytes
 
+    def _render_single_audio_input(
+        self, streamlit_app: st, key: str, property: Dict
+    ) -> Any:
+        # streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
+        from audio_recorder_streamlit import audio_recorder
+        audio_bytes = audio_recorder()
+        if audio_bytes:
+            streamlit_app.audio(audio_bytes, format="audio/wav")
+        return audio_bytes
+
+        # file_extension = None
+        # if "mime_type" in property:
+        #     file_extension = mimetypes.guess_extension(property["mime_type"])
+
+        # uploaded_file = streamlit_app.file_uploader(
+        #     **streamlit_kwargs, accept_multiple_files=False, type=file_extension
+        # )
+        # if uploaded_file is None:
+        #     return None
+
+        # bytes = uploaded_file.getvalue()
+        # if property.get("mime_type"):
+        #     if is_compatible_audio(property["mime_type"]):
+        #         # Show audio
+        #         streamlit_app.audio(bytes, format=property.get("mime_type"))
+        #     if is_compatible_image(property["mime_type"]):
+        #         # Show image
+        #         streamlit_app.image(bytes)
+        #     if is_compatible_video(property["mime_type"]):
+        #         # Show video
+        #         streamlit_app.video(bytes, format=property.get("mime_type"))
+        # return bytes
+
     def _render_single_string_input(
         self, streamlit_app: st, key: str, property: Dict
     ) -> Any:
@@ -820,7 +860,6 @@ def getOpyrator(mode: str) -> Opyrator:
         from control.mkgui.app import synthesize
         return Opyrator(synthesize)
 
-
 def render_streamlit_ui() -> None:
     # init
     session_state = st.session_state
@@ -852,6 +891,13 @@ def render_streamlit_ui() -> None:
 
     with left:
         st.header("Control 控制")
+        # if session_state.mode in ["AI拟音", "VC拟音"] :
+        #     from audiorecorder import audiorecorder
+        #     audio = audiorecorder("Click to record", "Recording...")
+        #     if len(audio) > 0:
+        #         # To play audio in frontend:
+        #         st.audio(audio.tobytes())
+
         InputUI(session_state=session_state, input_class=opyrator.input_type).render_ui(st)
         execute_selected = st.button(opyrator.action)
         if execute_selected:
diff --git a/control/toolbox/__init__.py b/control/toolbox/__init__.py
index 9bd8f23..a8b7a10 100644
--- a/control/toolbox/__init__.py
+++ b/control/toolbox/__init__.py
@@ -38,7 +38,8 @@ recognized_datasets = [
     "VoxCeleb2/dev/aac",
     "VoxCeleb2/test/aac",
     "VCTK-Corpus/wav48",
-    "aidatatang_200zh/corpus",
+    "aidatatang_200zh/corpus/test",
+    "aidatatang_200zh/corpus/train",
     "aishell3/test/wav",
     "magicdata/train",
 ]
diff --git a/control/toolbox/ui.py b/control/toolbox/ui.py
index c7b6223..e60d514 100644
--- a/control/toolbox/ui.py
+++ b/control/toolbox/ui.py
@@ -3,7 +3,6 @@ from PyQt5 import QtGui
 from PyQt5.QtWidgets import *
 import matplotlib.pyplot as plt
 from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
-from matplotlib.figure import Figure
 from models.encoder.inference import plot_embedding_as_heatmap
 from control.toolbox.utterance import Utterance
 from pathlib import Path
diff --git a/demo_toolbox.py b/demo_toolbox.py
index 3304e23..f24cc3c 100644
--- a/demo_toolbox.py
+++ b/demo_toolbox.py
@@ -17,15 +17,15 @@ if __name__ == '__main__':
         "supported datasets.", default=None)
     parser.add_argument("-vc", "--vc_mode", action="store_true",
                         help="Voice Conversion Mode(PPG based)")
-    parser.add_argument("-e", "--enc_models_dir", type=Path, default="encoder/saved_models",
+    parser.add_argument("-e", "--enc_models_dir", type=Path, default=f"data{os.sep}ckpt{os.sep}encoder",
                         help="Directory containing saved encoder models")
-    parser.add_argument("-s", "--syn_models_dir", type=Path, default="synthesizer/saved_models",
+    parser.add_argument("-s", "--syn_models_dir", type=Path, default=f"data{os.sep}ckpt{os.sep}synthesizer",
                         help="Directory containing saved synthesizer models")
-    parser.add_argument("-v", "--voc_models_dir", type=Path, default="vocoder/saved_models",
+    parser.add_argument("-v", "--voc_models_dir", type=Path, default=f"data{os.sep}ckpt{os.sep}vocoder",
                         help="Directory containing saved vocoder models")
-    parser.add_argument("-ex", "--extractor_models_dir", type=Path, default="ppg_extractor/saved_models",
+    parser.add_argument("-ex", "--extractor_models_dir", type=Path, default=f"data{os.sep}ckpt{os.sep}ppg_extractor",
                         help="Directory containing saved extrator models")
-    parser.add_argument("-cv", "--convertor_models_dir", type=Path, default="ppg2mel/saved_models",
+    parser.add_argument("-cv", "--convertor_models_dir", type=Path, default=f"data{os.sep}ckpt{os.sep}ppg2mel",
                         help="Directory containing saved convert models")
     parser.add_argument("--cpu", action="store_true", help=\
         "If True, processing is done on CPU, even when a GPU is available.")
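
The new demo_toolbox.py defaults build the checkpoint directories with os.sep so they render natively on Windows; because the arguments are declared with type=Path, equivalent spellings resolve to the same location. A small, purely illustrative equivalence check:

    import os
    from pathlib import Path

    # All three spellings point at the same directory on Windows and POSIX;
    # the f-string with os.sep is what the patch uses, the Path form is an alternative.
    a = f"data{os.sep}ckpt{os.sep}encoder"
    b = os.path.join("data", "ckpt", "encoder")
    c = Path("data") / "ckpt" / "encoder"
    assert Path(a) == Path(b) == c
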
diff --git a/models/vocoder/fregan/generator.py b/models/vocoder/fregan/generator.py
index 8f8eedf..73c4c0b 100644
--- a/models/vocoder/fregan/generator.py
+++ b/models/vocoder/fregan/generator.py
@@ -3,7 +3,7 @@ import torch.nn.functional as F
 import torch.nn as nn
 from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
 from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
-from models.vocoder.fregan.utils import init_weights, get_padding
+from utils.util import init_weights, get_padding
 
 LRELU_SLOPE = 0.1
diff --git a/models/vocoder/fregan/utils.py b/models/vocoder/fregan/utils.py
index 45161b1..9970f00 100644
--- a/models/vocoder/fregan/utils.py
+++ b/models/vocoder/fregan/utils.py
@@ -27,21 +27,12 @@ def plot_spectrogram(spectrogram):
     return fig
 
 
-def init_weights(m, mean=0.0, std=0.01):
-    classname = m.__class__.__name__
-    if classname.find("Conv") != -1:
-        m.weight.data.normal_(mean, std)
-
-
 def apply_weight_norm(m):
     classname = m.__class__.__name__
     if classname.find("Conv") != -1:
         weight_norm(m)
 
 
-def get_padding(kernel_size, dilation=1):
-    return int((kernel_size*dilation - dilation)/2)
-
-
 def load_checkpoint(filepath, device):
     assert os.path.isfile(filepath)
diff --git a/models/vocoder/hifigan/models.py b/models/vocoder/hifigan/models.py
index fc46164..6da66ee 100644
--- a/models/vocoder/hifigan/models.py
+++ b/models/vocoder/hifigan/models.py
@@ -3,7 +3,7 @@ import torch.nn.functional as F
 import torch.nn as nn
 from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
 from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
-from models.vocoder.hifigan.utils import init_weights, get_padding
+from utils.util import init_weights, get_padding
 
 LRELU_SLOPE = 0.1
diff --git a/models/vocoder/hifigan/utils.py b/models/vocoder/hifigan/utils.py
index e67cbcd..a34ca42 100644
--- a/models/vocoder/hifigan/utils.py
+++ b/models/vocoder/hifigan/utils.py
@@ -6,7 +6,6 @@ from torch.nn.utils import weight_norm
 matplotlib.use("Agg")
 import matplotlib.pylab as plt
 
-
 def plot_spectrogram(spectrogram):
     fig, ax = plt.subplots(figsize=(10, 2))
     im = ax.imshow(spectrogram, aspect="auto", origin="lower",
@@ -19,12 +18,6 @@ def plot_spectrogram(spectrogram):
     return fig
 
 
-def init_weights(m, mean=0.0, std=0.01):
-    classname = m.__class__.__name__
-    if classname.find("Conv") != -1:
-        m.weight.data.normal_(mean, std)
-
-
 def apply_weight_norm(m):
     classname = m.__class__.__name__
     if classname.find("Conv") != -1:
@@ -55,4 +48,3 @@ def scan_checkpoint(cp_dir, prefix):
     if len(cp_list) == 0:
         return None
     return sorted(cp_list)[-1]
-
diff --git a/models/vocoder/vocoder_dataset.py b/models/vocoder/vocoder_dataset.py
index f79f3bf..89a3965 100644
--- a/models/vocoder/vocoder_dataset.py
+++ b/models/vocoder/vocoder_dataset.py
@@ -1,7 +1,7 @@
 from torch.utils.data import Dataset
 from pathlib import Path
 from models.vocoder.wavernn import audio
-import vocoder.wavernn.hparams as hp
+import models.vocoder.wavernn.hparams as hp
 import numpy as np
 import torch
diff --git a/models/vocoder/wavernn/audio.py b/models/vocoder/wavernn/audio.py
index bec9768..738a374 100644
--- a/models/vocoder/wavernn/audio.py
+++ b/models/vocoder/wavernn/audio.py
@@ -1,7 +1,7 @@
 import math
 import numpy as np
 import librosa
-import vocoder.wavernn.hparams as hp
+import models.vocoder.wavernn.hparams as hp
 from scipy.signal import lfilter
 import soundfile as sf
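
The duplicated init_weights/get_padding helpers in the FreGAN and HiFi-GAN utils are removed and both generators now import them from utils.util (added to that module later in this patch). A short usage sketch of the shared helpers with an illustrative module, not code from the patch:

    import torch.nn as nn
    from utils.util import init_weights, get_padding

    # get_padding keeps the sequence length unchanged for odd kernels;
    # init_weights is applied recursively to every Conv* submodule.
    kernel_size, dilation = 7, 2
    conv = nn.Conv1d(80, 256, kernel_size, dilation=dilation,
                     padding=get_padding(kernel_size, dilation))
    conv.apply(init_weights)
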
diff --git a/models/vocoder/wavernn/train.py b/models/vocoder/wavernn/train.py
index e0fdb1b..d26347e 100644
--- a/models/vocoder/wavernn/train.py
+++ b/models/vocoder/wavernn/train.py
@@ -7,7 +7,7 @@ from torch.utils.data import DataLoader
 from pathlib import Path
 from torch import optim
 import torch.nn.functional as F
-import vocoder.wavernn.hparams as hp
+import models.vocoder.wavernn.hparams as hp
 import numpy as np
 import time
 import torch
diff --git a/models/wav2emo/__init__.py b/models/wav2emo/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/pre.py b/pre.py
index dd9a2b5..1723430 100644
--- a/pre.py
+++ b/pre.py
@@ -16,6 +16,7 @@ recognized_datasets = [
     "data_aishell"
 ]
 
+#TODO: add for emotional data
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Preprocesses audio files from datasets, encodes them as mel spectrograms "
@@ -42,7 +43,7 @@ if __name__ == "__main__":
         (these are used to split long audio files into sub-utterances.)")
     parser.add_argument("-d", "--dataset", type=str, default="aidatatang_200zh", help=\
         "Name of the dataset to process, allowing values: magicdata, aidatatang_200zh, aishell3, data_aishell.")
-    parser.add_argument("-e", "--encoder_model_fpath", type=Path, default="encoder/saved_models/pretrained.pt", help=\
+    parser.add_argument("-e", "--encoder_model_fpath", type=Path, default="data/ckpt/encoder/pretrained.pt", help=\
         "Path your trained encoder model.")
     parser.add_argument("-ne", "--n_processes_embed", type=int, default=1, help=\
         "Number of processes in parallel.An encoder is created for each, so you may need to lower "
diff --git a/requirements.txt b/requirements.txt
index 119814d..459b0fa 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -25,4 +25,6 @@ streamlit==1.8.0
 PyYAML==5.4.1
 torch_complex
 espnet
-PyWavelets
\ No newline at end of file
+PyWavelets
+monotonic-align==0.0.3
+transformers==4.26.0
\ No newline at end of file
diff --git a/train.py b/train.py
index 557bc26..b5499bb 100644
--- a/train.py
+++ b/train.py
@@ -5,63 +5,18 @@ import numpy as np
 from utils.load_yaml import HpsYaml
 from models.ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver
 
-# For reproducibility, comment these may speed up training
-torch.backends.cudnn.deterministic = True
-torch.backends.cudnn.benchmark = False
-
 def main():
     # Arguments
-    parser = argparse.ArgumentParser(description=
-        'Training PPG2Mel VC model.')
-    parser.add_argument('--config', type=str,
-                        help='Path to experiment config, e.g., config/vc.yaml')
-    parser.add_argument('--name', default=None, type=str, help='Name for logging.')
-    parser.add_argument('--logdir', default='log/', type=str,
-                        help='Logging path.', required=False)
-    parser.add_argument('--ckpdir', default='ppg2mel/saved_models/', type=str,
-                        help='Checkpoint path.', required=False)
-    parser.add_argument('--outdir', default='result/', type=str,
-                        help='Decode output path.', required=False)
-    parser.add_argument('--load', default=None, type=str,
-                        help='Load pre-trained model (for training only)', required=False)
-    parser.add_argument('--warm_start', action='store_true',
-                        help='Load model weights only, ignore specified layers.')
-    parser.add_argument('--seed', default=0, type=int,
-                        help='Random seed for reproducable results.', required=False)
-    parser.add_argument('--njobs', default=8, type=int,
-                        help='Number of threads for dataloader/decoding.', required=False)
-    parser.add_argument('--cpu', action='store_true', help='Disable GPU training.')
-    parser.add_argument('--no-pin', action='store_true',
-                        help='Disable pin-memory for dataloader')
-    parser.add_argument('--test', action='store_true', help='Test the model.')
-    parser.add_argument('--no-msg', action='store_true', help='Hide all messages.')
-    parser.add_argument('--finetune', action='store_true', help='Finetune model')
-    parser.add_argument('--oneshotvc', action='store_true', help='Oneshot VC model')
-    parser.add_argument('--bilstm', action='store_true', help='BiLSTM VC model')
-    parser.add_argument('--lsa', action='store_true', help='Use location-sensitive attention (LSA)')
+    preparser = argparse.ArgumentParser(description=
+        'Training model.')
+    preparser.add_argument('--type', type=str,
+                        help='type of training ')
 
     ###
-
-    paras = parser.parse_args()
-    setattr(paras, 'gpu', not paras.cpu)
-    setattr(paras, 'pin_memory', not paras.no_pin)
-    setattr(paras, 'verbose', not paras.no_msg)
-    # Make the config dict dot visitable
-    config = HpsYaml(paras.config)
-
-    np.random.seed(paras.seed)
-    torch.manual_seed(paras.seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed_all(paras.seed)
-
-    print(">>> OneShot VC training ...")
-    mode = "train"
-    solver = Solver(config, paras, mode)
-    solver.load_data()
-    solver.set_model()
-    solver.exec()
-    print(">>> Oneshot VC train finished!")
-    sys.exit(0)
+    paras, _ = preparser.parse_known_args()
+    if paras.type == "synth":
+        from control.cli.synthesizer_train import new_train
+        new_train()
 
 if __name__ == "__main__":
     main()
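
train.py now acts as a thin dispatcher: a pre-parser pulls out --type with parse_known_args, and new_train() in control/cli/synthesizer_train.py was switched from parse_args to parse_known_args so that the --type flag still present on sys.argv does not trigger an argparse error. A sketch of how the pre-parsing splits the arguments (paths and values are illustrative):

    import argparse

    # The pre-parser only consumes --type; everything else is left for new_train().
    argv = ["--type", "synth", "my_run", "../audiodata/SV2TTS/synthesizer"]

    preparser = argparse.ArgumentParser(description='Training model.')
    preparser.add_argument('--type', type=str, help='type of training ')
    paras, rest = preparser.parse_known_args(argv)

    print(paras.type)  # "synth"
    print(rest)        # ['my_run', '../audiodata/SV2TTS/synthesizer']
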
diff --git a/utils/util.py b/utils/util.py
index 34bcffd..3ec9190 100644
--- a/utils/util.py
+++ b/utils/util.py
@@ -1,4 +1,7 @@
 import matplotlib
+from torch.nn import functional as F
+
+import torch
 matplotlib.use('Agg')
 
 import time
@@ -48,3 +51,78 @@ class AttrDict(dict):
     def __init__(self, *args, **kwargs):
         super(AttrDict, self).__init__(*args, **kwargs)
         self.__dict__ = self
+
+def init_weights(m, mean=0.0, std=0.01):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        m.weight.data.normal_(mean, std)
+
+def get_padding(kernel_size, dilation=1):
+    return int((kernel_size*dilation - dilation)/2)
+
+
+def sequence_mask(length, max_length=None):
+    if max_length is None:
+        max_length = length.max()
+    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+    return x.unsqueeze(0) < length.unsqueeze(1)
+
+def slice_segments(x, ids_str, segment_size=4):
+    ret = torch.zeros_like(x[:, :, :segment_size])
+    for i in range(x.size(0)):
+        idx_str = ids_str[i]
+        idx_end = idx_str + segment_size
+        ret[i] = x[i, :, idx_str:idx_end]
+    return ret
+
+def rand_slice_segments(x, x_lengths=None, segment_size=4):
+    b, d, t = x.size()
+    if x_lengths is None:
+        x_lengths = t
+    ids_str_max = x_lengths - segment_size + 1
+    ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
+    ret = slice_segments(x, ids_str, segment_size)
+    return ret, ids_str
+
+
+def convert_pad_shape(pad_shape):
+    l = pad_shape[::-1]
+    pad_shape = [item for sublist in l for item in sublist]
+    return pad_shape
+
+def generate_path(duration, mask):
+    """
+    duration: [b, 1, t_x]
+    mask: [b, 1, t_y, t_x]
+    """
+    device = duration.device
+
+    b, _, t_y, t_x = mask.shape
+    cum_duration = torch.cumsum(duration, -1)
+
+    cum_duration_flat = cum_duration.view(b * t_x)
+    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
+    path = path.view(b, t_x, t_y)
+    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
+    path = path.unsqueeze(1).transpose(2,3) * mask
+    return path
+
+@torch.jit.script
+def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
+    n_channels_int = n_channels[0]
+    in_act = input_a + input_b
+    t_act = torch.tanh(in_act[:, :n_channels_int, :])
+    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+    acts = t_act * s_act
+    return acts
+
+
+def subsequent_mask(length):
+    mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
+    return mask
+
+
+def intersperse(lst, item):
+    result = [item] * (len(lst) * 2 + 1)
+    result[1::2] = lst
+    return result
\ No newline at end of file
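
The utilities added to utils/util.py above mirror the masking, random-windowing, and hard-alignment helpers found in VITS/Glow-TTS-style models. A small shape check, assuming they are imported from utils.util as added here (tensor values are illustrative):

    import torch
    from utils.util import sequence_mask, rand_slice_segments, generate_path

    lengths = torch.tensor([3, 5])
    mask = sequence_mask(lengths)                  # (2, 5) bool, True for valid frames

    feats = torch.randn(2, 80, 100)                # (batch, mels, frames)
    chunk, starts = rand_slice_segments(feats, torch.tensor([100, 80]), segment_size=32)
    # chunk: (2, 80, 32) random training windows; starts: per-item frame offsets

    duration = torch.tensor([[[2., 1., 3.]]])      # (b=1, 1, t_x) frames per input token
    attn_mask = torch.ones(1, 1, 6, 3)             # (b, 1, t_y=6, t_x=3)
    path = generate_path(duration, attn_mask)      # (1, 1, 6, 3) hard monotonic alignment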