mirror of https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00

init

This commit is contained in:
parent 74a3fc97d0
commit e469bd06ae
12 .vscode/launch.json (vendored)

@@ -15,7 +15,8 @@
             "name": "Python: Vocoder Preprocess",
             "type": "python",
             "request": "launch",
-            "program": "vocoder_preprocess.py",
+            "program": "control\\cli\\vocoder_preprocess.py",
+            "cwd": "${workspaceFolder}",
             "console": "integratedTerminal",
             "args": ["..\\audiodata"]
         },
@@ -23,7 +24,8 @@
             "name": "Python: Vocoder Train",
             "type": "python",
             "request": "launch",
-            "program": "vocoder_train.py",
+            "program": "control\\cli\\vocoder_train.py",
+            "cwd": "${workspaceFolder}",
             "console": "integratedTerminal",
             "args": ["dev", "..\\audiodata"]
         },
@@ -32,6 +34,7 @@
             "type": "python",
             "request": "launch",
             "program": "demo_toolbox.py",
+            "cwd": "${workspaceFolder}",
             "console": "integratedTerminal",
             "args": ["-d","..\\audiodata"]
         },
@@ -40,6 +43,7 @@
             "type": "python",
             "request": "launch",
             "program": "demo_toolbox.py",
+            "cwd": "${workspaceFolder}",
             "console": "integratedTerminal",
             "args": ["-d","..\\audiodata","-vc"]
         },
@@ -47,9 +51,9 @@
             "name": "Python: Synth Train",
             "type": "python",
             "request": "launch",
-            "program": "synthesizer_train.py",
+            "program": "train.py",
             "console": "integratedTerminal",
-            "args": ["my_run", "..\\"]
+            "args": ["--type", "synth", "..\\audiodata\\SV2TTS\\synthesizer"]
         },
         {
             "name": "Python: PPG Convert",

0 control/cli/__init__.py (new file)

@@ -14,7 +14,7 @@ if __name__ == "__main__":
         "Path to the synthesizer training data that contains the audios and the train.txt file. "
         "If you let everything as default, it should be <datasets_root>/SV2TTS/synthesizer/.")
     parser.add_argument("-e", "--encoder_model_fpath", type=Path,
-                        default="encoder/saved_models/pretrained.pt", help=\
+                        default="data/ckpt/encoder/pretrained.pt", help=\
         "Path your trained encoder model.")
     parser.add_argument("-n", "--n_processes", type=int, default=4, help= \
         "Number of parallel processes. An encoder is created for each, so you may need to lower "

@@ -3,8 +3,7 @@ from models.synthesizer.train import train
 from utils.argutils import print_args
 import argparse
 
-
-if __name__ == "__main__":
+def new_train():
     parser = argparse.ArgumentParser()
     parser.add_argument("run_id", type=str, help= \
         "Name for this model instance. If a model state from the same run ID was previously "
@@ -13,7 +12,7 @@ if __name__ == "__main__":
     parser.add_argument("syn_dir", type=str, default=argparse.SUPPRESS, help= \
         "Path to the synthesizer directory that contains the ground truth mel spectrograms, "
         "the wavs and the embeds.")
-    parser.add_argument("-m", "--models_dir", type=str, default="synthesizer/saved_models/", help=\
+    parser.add_argument("-m", "--models_dir", type=str, default=f"data/ckpt/synthesizer/", help=\
         "Path to the output directory that will contain the saved model weights and the logs.")
     parser.add_argument("-s", "--save_every", type=int, default=1000, help= \
         "Number of steps between updates of the model on the disk. Set to 0 to never save the "
@@ -28,10 +27,14 @@ if __name__ == "__main__":
     parser.add_argument("--hparams", default="",
                         help="Hyperparameter overrides as a comma-separated list of name=value "
                         "pairs")
-    args = parser.parse_args()
+    args, _ = parser.parse_known_args()
     print_args(args, parser)
 
     args.hparams = hparams.parse(args.hparams)
 
     # Run the training
     train(**vars(args))
+
+
+if __name__ == "__main__":
+    new_train()

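The refactor above turns the script body into an importable new_train() entry point (train.py calls it below), while the trailing __main__ guard keeps direct execution working and parse_known_args() lets it ignore flags that belong to the caller. A minimal sketch of that pattern, with generic names rather than the repo's actual arguments:

    import argparse

    def new_train(argv=None):
        parser = argparse.ArgumentParser()
        parser.add_argument("run_id")
        # parse_known_args() tolerates extra flags such as a dispatcher's --type
        args, _ = parser.parse_known_args(argv)
        print(f"training run {args.run_id}")

    if __name__ == "__main__":
        new_train()   # still usable as a standalone script
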
66 control/cli/train_ppg2mel.py (new file)

@@ -0,0 +1,66 @@
+import sys
+import torch
+import argparse
+import numpy as np
+from utils.load_yaml import HpsYaml
+from models.ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver
+
+# For reproducibility, comment these may speed up training
+torch.backends.cudnn.deterministic = True
+torch.backends.cudnn.benchmark = False
+
+def main():
+    # Arguments
+    parser = argparse.ArgumentParser(description=
+        'Training PPG2Mel VC model.')
+    parser.add_argument('--config', type=str,
+                        help='Path to experiment config, e.g., config/vc.yaml')
+    parser.add_argument('--name', default=None, type=str, help='Name for logging.')
+    parser.add_argument('--logdir', default='log/', type=str,
+                        help='Logging path.', required=False)
+    parser.add_argument('--ckpdir', default='ppg2mel/saved_models/', type=str,
+                        help='Checkpoint path.', required=False)
+    parser.add_argument('--outdir', default='result/', type=str,
+                        help='Decode output path.', required=False)
+    parser.add_argument('--load', default=None, type=str,
+                        help='Load pre-trained model (for training only)', required=False)
+    parser.add_argument('--warm_start', action='store_true',
+                        help='Load model weights only, ignore specified layers.')
+    parser.add_argument('--seed', default=0, type=int,
+                        help='Random seed for reproducable results.', required=False)
+    parser.add_argument('--njobs', default=8, type=int,
+                        help='Number of threads for dataloader/decoding.', required=False)
+    parser.add_argument('--cpu', action='store_true', help='Disable GPU training.')
+    parser.add_argument('--no-pin', action='store_true',
+                        help='Disable pin-memory for dataloader')
+    parser.add_argument('--test', action='store_true', help='Test the model.')
+    parser.add_argument('--no-msg', action='store_true', help='Hide all messages.')
+    parser.add_argument('--finetune', action='store_true', help='Finetune model')
+    parser.add_argument('--oneshotvc', action='store_true', help='Oneshot VC model')
+    parser.add_argument('--bilstm', action='store_true', help='BiLSTM VC model')
+    parser.add_argument('--lsa', action='store_true', help='Use location-sensitive attention (LSA)')
+
+    ###
+    paras = parser.parse_args()
+    setattr(paras, 'gpu', not paras.cpu)
+    setattr(paras, 'pin_memory', not paras.no_pin)
+    setattr(paras, 'verbose', not paras.no_msg)
+    # Make the config dict dot visitable
+    config = HpsYaml(paras.config)
+
+    np.random.seed(paras.seed)
+    torch.manual_seed(paras.seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed_all(paras.seed)
+
+    print(">>> OneShot VC training ...")
+    mode = "train"
+    solver = Solver(config, paras, mode)
+    solver.load_data()
+    solver.set_model()
+    solver.exec()
+    print(">>> Oneshot VC train finished!")
+    sys.exit(0)
+
+if __name__ == "__main__":
+    main()

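HpsYaml is the repository's YAML loader; per the comment above it makes the parsed config "dot visitable", i.e. nested keys readable as attributes. A rough standalone approximation of that idea (not the actual HpsYaml implementation; it assumes PyYAML is available):

    import yaml

    class DotDict(dict):
        """Read nested dict keys as attributes, e.g. cfg.data.train_dir."""
        def __getattr__(self, name):
            value = self[name]
            return DotDict(value) if isinstance(value, dict) else value

    cfg = DotDict(yaml.safe_load("data:\n  train_dir: foo\nhparams:\n  lr: 0.001\n"))
    print(cfg.data.train_dir, cfg.hparams.lr)   # -> foo 0.001
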
@@ -46,15 +46,16 @@ else:
     raise Exception(f"Model folder {VOC_MODELS_DIRT} doesn't exist.")
 
 
 class Input(BaseModel):
     message: str = Field(
         ..., example="欢迎使用工具箱, 现已支持中文输入!", alias="文本内容"
     )
     local_audio_file: audio_input_selection = Field(
-        ..., alias="输入语音(本地wav)",
+        ..., alias="选择语音(本地wav)",
         description="选择本地语音文件."
     )
+    record_audio_file: FileContent = Field(default=None, alias="录制语音",
+                        description="录音.", is_recorder=True, mime_type="audio/wav")
     upload_audio_file: FileContent = Field(default=None, alias="或上传语音",
                         description="拖拽或点击上传.", mime_type="audio/wav")
     encoder: encoders = Field(
@@ -104,7 +105,12 @@ def synthesize(input: Input) -> Output:
     gan_vocoder.load_model(Path(input.vocoder.value))
 
     # load file
-    if input.upload_audio_file != None:
+    if input.record_audio_file != None:
+        with open(TEMP_SOURCE_AUDIO, "w+b") as f:
+            f.write(input.record_audio_file.as_bytes())
+            f.seek(0)
+        wav, sample_rate = librosa.load(TEMP_SOURCE_AUDIO)
+    elif input.upload_audio_file != None:
         with open(TEMP_SOURCE_AUDIO, "w+b") as f:
             f.write(input.upload_audio_file.as_bytes())
             f.seek(0)

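The new branch gives a freshly recorded clip priority over an uploaded one; in both cases the raw bytes are written to TEMP_SOURCE_AUDIO and read back with librosa. A stripped-down illustration of that selection logic (hypothetical helper, not the app's code):

    import librosa

    TEMP_SOURCE_AUDIO = "/tmp/source.wav"

    def load_source(record_bytes, upload_bytes):
        chosen = record_bytes if record_bytes is not None else upload_bytes
        if chosen is None:
            return None, None
        with open(TEMP_SOURCE_AUDIO, "w+b") as f:
            f.write(chosen)                      # persist the clip so later stages can reuse the file
        return librosa.load(TEMP_SOURCE_AUDIO)   # -> (wav, sample_rate)
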
@@ -37,6 +37,12 @@ def is_single_file_property(property: Dict) -> bool:
     # TODO: binary?
     return property.get("format") == "byte"
 
+def is_single_autio_property(property: Dict) -> bool:
+    if property.get("type") != "string":
+        return False
+    # TODO: binary?
+    return property.get("format") == "bytes"
+
 
 def is_single_directory_property(property: Dict) -> bool:
     if property.get("type") != "string":

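The added helper mirrors is_single_file_property but keys on format "bytes" rather than "byte", which is how the app marks audio payloads (the function name keeps the repo's "autio" spelling). A condensed restatement with a sample property dict, for illustration only:

    def is_single_file_property(prop):    # behaves like the existing helper above
        return prop.get("type") == "string" and prop.get("format") == "byte"

    def is_single_autio_property(prop):   # behaves like the new helper above
        return prop.get("type") == "string" and prop.get("format") == "bytes"

    prop = {"type": "string", "format": "bytes", "mime_type": "audio/wav"}
    print(is_single_autio_property(prop), is_single_file_property(prop))   # True False
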
@@ -242,7 +242,14 @@ class InputUI:
         file_extension = None
         if "mime_type" in property:
             file_extension = mimetypes.guess_extension(property["mime_type"])
 
+        if "is_recorder" in property:
+            from audio_recorder_streamlit import audio_recorder
+            audio_bytes = audio_recorder()
+            if audio_bytes:
+                streamlit_app.audio(audio_bytes, format="audio/wav")
+            return audio_bytes
+
         uploaded_file = streamlit_app.file_uploader(
             **streamlit_kwargs, accept_multiple_files=False, type=file_extension
         )
@@ -262,6 +269,39 @@ class InputUI:
             streamlit_app.video(bytes, format=property.get("mime_type"))
         return bytes
 
+    def _render_single_audio_input(
+        self, streamlit_app: st, key: str, property: Dict
+    ) -> Any:
+        # streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
+        from audio_recorder_streamlit import audio_recorder
+        audio_bytes = audio_recorder()
+        if audio_bytes:
+            streamlit_app.audio(audio_bytes, format="audio/wav")
+        return audio_bytes
+
+        # file_extension = None
+        # if "mime_type" in property:
+        #     file_extension = mimetypes.guess_extension(property["mime_type"])
+
+        # uploaded_file = streamlit_app.file_uploader(
+        #     **streamlit_kwargs, accept_multiple_files=False, type=file_extension
+        # )
+        # if uploaded_file is None:
+        #     return None
+
+        # bytes = uploaded_file.getvalue()
+        # if property.get("mime_type"):
+        #     if is_compatible_audio(property["mime_type"]):
+        #         # Show audio
+        #         streamlit_app.audio(bytes, format=property.get("mime_type"))
+        #     if is_compatible_image(property["mime_type"]):
+        #         # Show image
+        #         streamlit_app.image(bytes)
+        #     if is_compatible_video(property["mime_type"]):
+        #         # Show video
+        #         streamlit_app.video(bytes, format=property.get("mime_type"))
+        #     return bytes
+
     def _render_single_string_input(
         self, streamlit_app: st, key: str, property: Dict
     ) -> Any:

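Both new code paths rely on audio_recorder() from the audio-recorder-streamlit package, which renders a record button and returns WAV-encoded bytes once something has been recorded (a falsy value before that). A minimal standalone sketch of that flow, assuming streamlit and that package are installed:

    import streamlit as st
    from audio_recorder_streamlit import audio_recorder

    audio_bytes = audio_recorder()                    # record button; WAV bytes or None
    if audio_bytes:
        st.audio(audio_bytes, format="audio/wav")     # preview the recording
        st.download_button("Save recording", data=audio_bytes, file_name="recording.wav")
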
@@ -820,7 +860,6 @@ def getOpyrator(mode: str) -> Opyrator:
         from control.mkgui.app import synthesize
         return Opyrator(synthesize)
 
-
 def render_streamlit_ui() -> None:
     # init
     session_state = st.session_state
@@ -852,6 +891,13 @@ def render_streamlit_ui() -> None:
 
     with left:
         st.header("Control 控制")
+        # if session_state.mode in ["AI拟音", "VC拟音"] :
+        #     from audiorecorder import audiorecorder
+        #     audio = audiorecorder("Click to record", "Recording...")
+        #     if len(audio) > 0:
+        #         # To play audio in frontend:
+        #         st.audio(audio.tobytes())
+
         InputUI(session_state=session_state, input_class=opyrator.input_type).render_ui(st)
         execute_selected = st.button(opyrator.action)
         if execute_selected:

@@ -38,7 +38,8 @@ recognized_datasets = [
     "VoxCeleb2/dev/aac",
     "VoxCeleb2/test/aac",
     "VCTK-Corpus/wav48",
-    "aidatatang_200zh/corpus",
+    "aidatatang_200zh/corpus/test",
+    "aidatatang_200zh/corpus/train",
     "aishell3/test/wav",
     "magicdata/train",
 ]

@@ -3,7 +3,6 @@ from PyQt5 import QtGui
 from PyQt5.QtWidgets import *
 import matplotlib.pyplot as plt
 from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
-from matplotlib.figure import Figure
 from models.encoder.inference import plot_embedding_as_heatmap
 from control.toolbox.utterance import Utterance
 from pathlib import Path

@@ -17,15 +17,15 @@ if __name__ == '__main__':
         "supported datasets.", default=None)
     parser.add_argument("-vc", "--vc_mode", action="store_true",
                         help="Voice Conversion Mode(PPG based)")
-    parser.add_argument("-e", "--enc_models_dir", type=Path, default="encoder/saved_models",
+    parser.add_argument("-e", "--enc_models_dir", type=Path, default=f"data{os.sep}ckpt{os.sep}encoder",
                         help="Directory containing saved encoder models")
-    parser.add_argument("-s", "--syn_models_dir", type=Path, default="synthesizer/saved_models",
+    parser.add_argument("-s", "--syn_models_dir", type=Path, default=f"data{os.sep}ckpt{os.sep}synthesizer",
                         help="Directory containing saved synthesizer models")
-    parser.add_argument("-v", "--voc_models_dir", type=Path, default="vocoder/saved_models",
+    parser.add_argument("-v", "--voc_models_dir", type=Path, default=f"data{os.sep}ckpt{os.sep}vocoder",
                         help="Directory containing saved vocoder models")
-    parser.add_argument("-ex", "--extractor_models_dir", type=Path, default="ppg_extractor/saved_models",
+    parser.add_argument("-ex", "--extractor_models_dir", type=Path, default=f"data{os.sep}ckpt{os.sep}ppg_extractor",
                         help="Directory containing saved extrator models")
-    parser.add_argument("-cv", "--convertor_models_dir", type=Path, default="ppg2mel/saved_models",
+    parser.add_argument("-cv", "--convertor_models_dir", type=Path, default=f"data{os.sep}ckpt{os.sep}ppg2mel",
                         help="Directory containing saved convert models")
     parser.add_argument("--cpu", action="store_true", help=\
         "If True, processing is done on CPU, even when a GPU is available.")

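The default model directories now live under data/ckpt/, with the literals assembled via os.sep so the separator matches the host OS (this assumes demo_toolbox.py imports os). A quick standalone check of what such a default expands to:

    import os
    from pathlib import Path

    enc_default = Path(f"data{os.sep}ckpt{os.sep}encoder")
    print(enc_default)   # data/ckpt/encoder on POSIX, data\ckpt\encoder on Windows

Since pathlib.Path also accepts forward slashes on Windows, the os.sep spelling is mainly cosmetic.
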
@@ -3,7 +3,7 @@ import torch.nn.functional as F
 import torch.nn as nn
 from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
 from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
-from models.vocoder.fregan.utils import init_weights, get_padding
+from utils.util import init_weights, get_padding
 
 LRELU_SLOPE = 0.1
 

@@ -27,21 +27,12 @@ def plot_spectrogram(spectrogram):
     return fig
 
 
-def init_weights(m, mean=0.0, std=0.01):
-    classname = m.__class__.__name__
-    if classname.find("Conv") != -1:
-        m.weight.data.normal_(mean, std)
-
-
 def apply_weight_norm(m):
     classname = m.__class__.__name__
     if classname.find("Conv") != -1:
         weight_norm(m)
 
 
-def get_padding(kernel_size, dilation=1):
-    return int((kernel_size*dilation - dilation)/2)
-
 
 def load_checkpoint(filepath, device):
     assert os.path.isfile(filepath)

@@ -3,7 +3,7 @@ import torch.nn.functional as F
 import torch.nn as nn
 from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
 from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
-from models.vocoder.hifigan.utils import init_weights, get_padding
+from utils.util import init_weights, get_padding
 
 LRELU_SLOPE = 0.1
 

@@ -6,7 +6,6 @@ from torch.nn.utils import weight_norm
 matplotlib.use("Agg")
 import matplotlib.pylab as plt
 
-
 def plot_spectrogram(spectrogram):
     fig, ax = plt.subplots(figsize=(10, 2))
     im = ax.imshow(spectrogram, aspect="auto", origin="lower",
@@ -19,12 +18,6 @@ def plot_spectrogram(spectrogram):
     return fig
 
 
-def init_weights(m, mean=0.0, std=0.01):
-    classname = m.__class__.__name__
-    if classname.find("Conv") != -1:
-        m.weight.data.normal_(mean, std)
-
-
 def apply_weight_norm(m):
     classname = m.__class__.__name__
     if classname.find("Conv") != -1:
@@ -55,4 +48,3 @@ def scan_checkpoint(cp_dir, prefix):
     if len(cp_list) == 0:
         return None
     return sorted(cp_list)[-1]
-

@@ -1,7 +1,7 @@
 from torch.utils.data import Dataset
 from pathlib import Path
 from models.vocoder.wavernn import audio
-import vocoder.wavernn.hparams as hp
+import models.vocoder.wavernn.hparams as hp
 import numpy as np
 import torch
 

@@ -1,7 +1,7 @@
 import math
 import numpy as np
 import librosa
-import vocoder.wavernn.hparams as hp
+import models.vocoder.wavernn.hparams as hp
 from scipy.signal import lfilter
 import soundfile as sf
 

@@ -7,7 +7,7 @@ from torch.utils.data import DataLoader
 from pathlib import Path
 from torch import optim
 import torch.nn.functional as F
-import vocoder.wavernn.hparams as hp
+import models.vocoder.wavernn.hparams as hp
 import numpy as np
 import time
 import torch

0 models/wav2emo/__init__.py (new file)

3 pre.py

@@ -16,6 +16,7 @@ recognized_datasets = [
     "data_aishell"
 ]
 
+#TODO: add for emotional data
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Preprocesses audio files from datasets, encodes them as mel spectrograms "
@@ -42,7 +43,7 @@ if __name__ == "__main__":
         (these are used to split long audio files into sub-utterances.)")
     parser.add_argument("-d", "--dataset", type=str, default="aidatatang_200zh", help=\
         "Name of the dataset to process, allowing values: magicdata, aidatatang_200zh, aishell3, data_aishell.")
-    parser.add_argument("-e", "--encoder_model_fpath", type=Path, default="encoder/saved_models/pretrained.pt", help=\
+    parser.add_argument("-e", "--encoder_model_fpath", type=Path, default="data/ckpt/encoder/pretrained.pt", help=\
         "Path your trained encoder model.")
     parser.add_argument("-ne", "--n_processes_embed", type=int, default=1, help=\
        "Number of processes in parallel.An encoder is created for each, so you may need to lower "

@@ -25,4 +25,6 @@ streamlit==1.8.0
 PyYAML==5.4.1
 torch_complex
 espnet
-PyWavelets
+PyWavelets
+monotonic-align==0.0.3
+transformers==4.26.0

61 train.py

@@ -5,63 +5,18 @@ import numpy as np
 from utils.load_yaml import HpsYaml
 from models.ppg2mel.train.train_linglf02mel_seq2seq_oneshotvc import Solver
 
-# For reproducibility, comment these may speed up training
-torch.backends.cudnn.deterministic = True
-torch.backends.cudnn.benchmark = False
-
 def main():
     # Arguments
-    parser = argparse.ArgumentParser(description=
-        'Training PPG2Mel VC model.')
-    parser.add_argument('--config', type=str,
-                        help='Path to experiment config, e.g., config/vc.yaml')
-    parser.add_argument('--name', default=None, type=str, help='Name for logging.')
-    parser.add_argument('--logdir', default='log/', type=str,
-                        help='Logging path.', required=False)
-    parser.add_argument('--ckpdir', default='ppg2mel/saved_models/', type=str,
-                        help='Checkpoint path.', required=False)
-    parser.add_argument('--outdir', default='result/', type=str,
-                        help='Decode output path.', required=False)
-    parser.add_argument('--load', default=None, type=str,
-                        help='Load pre-trained model (for training only)', required=False)
-    parser.add_argument('--warm_start', action='store_true',
-                        help='Load model weights only, ignore specified layers.')
-    parser.add_argument('--seed', default=0, type=int,
-                        help='Random seed for reproducable results.', required=False)
-    parser.add_argument('--njobs', default=8, type=int,
-                        help='Number of threads for dataloader/decoding.', required=False)
-    parser.add_argument('--cpu', action='store_true', help='Disable GPU training.')
-    parser.add_argument('--no-pin', action='store_true',
-                        help='Disable pin-memory for dataloader')
-    parser.add_argument('--test', action='store_true', help='Test the model.')
-    parser.add_argument('--no-msg', action='store_true', help='Hide all messages.')
-    parser.add_argument('--finetune', action='store_true', help='Finetune model')
-    parser.add_argument('--oneshotvc', action='store_true', help='Oneshot VC model')
-    parser.add_argument('--bilstm', action='store_true', help='BiLSTM VC model')
-    parser.add_argument('--lsa', action='store_true', help='Use location-sensitive attention (LSA)')
+    preparser = argparse.ArgumentParser(description=
+        'Training model.')
+    preparser.add_argument('--type', type=str,
+                        help='type of training ')
 
     ###
-
-    paras = parser.parse_args()
-    setattr(paras, 'gpu', not paras.cpu)
-    setattr(paras, 'pin_memory', not paras.no_pin)
-    setattr(paras, 'verbose', not paras.no_msg)
-    # Make the config dict dot visitable
-    config = HpsYaml(paras.config)
-
-    np.random.seed(paras.seed)
-    torch.manual_seed(paras.seed)
-    if torch.cuda.is_available():
-        torch.cuda.manual_seed_all(paras.seed)
-
-    print(">>> OneShot VC training ...")
-    mode = "train"
-    solver = Solver(config, paras, mode)
-    solver.load_data()
-    solver.set_model()
-    solver.exec()
-    print(">>> Oneshot VC train finished!")
-    sys.exit(0)
+    paras, _ = preparser.parse_known_args()
+    if paras.type == "synth":
+        from control.cli.synthesizer_train import new_train
+        new_train()
 
 if __name__ == "__main__":
     main()

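train.py is now a thin dispatcher: a pre-parser reads only --type via parse_known_args() and, for "synth", hands control to control.cli.synthesizer_train.new_train(), which re-parses the remaining command line itself. A self-contained sketch of that two-stage parsing pattern (hypothetical sub-arguments, not the repo's exact CLI):

    import argparse

    def dispatch(argv):
        pre = argparse.ArgumentParser(add_help=False)
        pre.add_argument("--type", type=str)
        known, rest = pre.parse_known_args(argv)   # unknown args are kept, not rejected
        if known.type == "synth":
            sub = argparse.ArgumentParser()
            sub.add_argument("run_id")
            sub.add_argument("syn_dir")
            return sub.parse_args(rest)
        raise SystemExit(f"unsupported training type: {known.type}")

    print(dispatch(["--type", "synth", "my_run", "audiodata/SV2TTS/synthesizer"]))
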
@@ -1,4 +1,7 @@
 import matplotlib
+from torch.nn import functional as F
+
+import torch
 matplotlib.use('Agg')
 import time
 

@@ -48,3 +51,78 @@ class AttrDict(dict):
     def __init__(self, *args, **kwargs):
         super(AttrDict, self).__init__(*args, **kwargs)
         self.__dict__ = self
+
+def init_weights(m, mean=0.0, std=0.01):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        m.weight.data.normal_(mean, std)
+
+def get_padding(kernel_size, dilation=1):
+    return int((kernel_size*dilation - dilation)/2)
+
+
+def sequence_mask(length, max_length=None):
+    if max_length is None:
+        max_length = length.max()
+    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
+    return x.unsqueeze(0) < length.unsqueeze(1)
+
+def slice_segments(x, ids_str, segment_size=4):
+    ret = torch.zeros_like(x[:, :, :segment_size])
+    for i in range(x.size(0)):
+        idx_str = ids_str[i]
+        idx_end = idx_str + segment_size
+        ret[i] = x[i, :, idx_str:idx_end]
+    return ret
+
+def rand_slice_segments(x, x_lengths=None, segment_size=4):
+    b, d, t = x.size()
+    if x_lengths is None:
+        x_lengths = t
+    ids_str_max = x_lengths - segment_size + 1
+    ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
+    ret = slice_segments(x, ids_str, segment_size)
+    return ret, ids_str
+
+
+def convert_pad_shape(pad_shape):
+    l = pad_shape[::-1]
+    pad_shape = [item for sublist in l for item in sublist]
+    return pad_shape
+
+def generate_path(duration, mask):
+    """
+    duration: [b, 1, t_x]
+    mask: [b, 1, t_y, t_x]
+    """
+    device = duration.device
+
+    b, _, t_y, t_x = mask.shape
+    cum_duration = torch.cumsum(duration, -1)
+
+    cum_duration_flat = cum_duration.view(b * t_x)
+    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
+    path = path.view(b, t_x, t_y)
+    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
+    path = path.unsqueeze(1).transpose(2,3) * mask
+    return path
+
+@torch.jit.script
+def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
+    n_channels_int = n_channels[0]
+    in_act = input_a + input_b
+    t_act = torch.tanh(in_act[:, :n_channels_int, :])
+    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
+    acts = t_act * s_act
+    return acts
+
+
+def subsequent_mask(length):
+    mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
+    return mask
+
+
+def intersperse(lst, item):
+    result = [item] * (len(lst) * 2 + 1)
+    result[1::2] = lst
+    return result

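The block above consolidates two groups of helpers: init_weights/get_padding, now imported by both the fregan and hifigan vocoder models, and masking/slicing utilities for sequence work. A small, self-contained sanity check of the shape conventions those utilities assume (batch-first tensors, channels second):

    import torch

    # sequence_mask: lengths -> boolean mask of shape [batch, max_length]
    lengths = torch.tensor([2, 4])
    mask = torch.arange(4).unsqueeze(0) < lengths.unsqueeze(1)
    print(mask)           # [[True, True, False, False], [True, True, True, True]]

    # rand_slice_segments crops segment_size frames per sample from [batch, channels, time]
    x = torch.randn(2, 80, 100)                       # e.g. a batch of 80-bin mel spectrograms
    starts = torch.randint(0, 100 - 4 + 1, (2,))
    crops = torch.stack([x[i, :, s:s + 4] for i, s in enumerate(starts.tolist())])
    print(crops.shape)    # torch.Size([2, 80, 4])

    # intersperse([1, 2, 3], 0) from the block above would return [0, 1, 0, 2, 0, 3, 0]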