mirror of
https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00
103 lines
3.0 KiB
Python
import argparse
from pathlib import Path

import torch
import yaml

from .frontend import DefaultFrontend
from .utterance_mvn import UtteranceMVN
from .encoder.conformer_encoder import ConformerEncoder

_model = None  # type: PPGModel
_device = None


class PPGModel(torch.nn.Module):
    def __init__(
        self,
        frontend,
        normalizer,
        encoder,
    ):
        super().__init__()
        self.frontend = frontend
        self.normalize = normalizer
        self.encoder = encoder

    def forward(self, speech, speech_lengths):
        """Extract bottleneck (PPG) features from a batch of waveforms.

        Args:
            speech (tensor): (B, L)
            speech_lengths (tensor): (B, )

        Returns:
            bottle_neck_feats (tensor): (B, L//hop_size, 144)
        """
        feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        feats, feats_lengths = self.normalize(feats, feats_lengths)
        encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
        return encoder_out

    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ):
        assert speech_lengths.dim() == 1, speech_lengths.shape

        # For data-parallel: trim padding beyond the longest utterance.
        speech = speech[:, : speech_lengths.max()]

        if self.frontend is not None:
            # Frontend, e.g. STFT and feature extraction.
            # The data loader may send a time-domain signal in this case:
            # speech (Batch, NSamples) -> feats (Batch, NFrames, Dim)
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            # No frontend and no feature extraction.
            feats, feats_lengths = speech, speech_lengths
        return feats, feats_lengths

    def extract_from_wav(self, src_wav):
        # Wrap a single mono waveform (numpy array) into a batch of one.
        src_wav_tensor = torch.from_numpy(src_wav).unsqueeze(0).float().to(_device)
        src_wav_lengths = torch.LongTensor([len(src_wav)]).to(_device)
        return self(src_wav_tensor, src_wav_lengths)


def build_model(args):
    normalizer = UtteranceMVN(**args.normalize_conf)
    frontend = DefaultFrontend(**args.frontend_conf)
    encoder = ConformerEncoder(input_size=80, **args.encoder_conf)
    model = PPGModel(frontend, normalizer, encoder)

    return model


def load_model(model_file, device=None):
    global _model, _device

    if device is None:
        _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        _device = device

    # Search for a config file next to the checkpoint.
    model_config_fpaths = list(model_file.parent.rglob("*.yaml"))
    config_file = model_config_fpaths[0]
    with config_file.open("r", encoding="utf-8") as f:
        args = yaml.safe_load(f)

    args = argparse.Namespace(**args)

    model = build_model(args)
    model_state_dict = model.state_dict()

    # Load only the encoder weights from the checkpoint; the frontend and
    # normalizer keep their freshly built parameters.
    ckpt_state_dict = torch.load(model_file, map_location=_device)
    ckpt_state_dict = {k: v for k, v in ckpt_state_dict.items() if "encoder" in k}

    model_state_dict.update(ckpt_state_dict)
    model.load_state_dict(model_state_dict)

    _model = model.eval().to(_device)
    return _model
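

# --- Usage sketch (not part of the upstream module) -----------------------
# A minimal example of loading a checkpoint and extracting PPG features
# from a waveform. The checkpoint and audio paths are hypothetical, and
# librosa is assumed to be installed; note that load_model() expects a
# *.yaml config file somewhere under the checkpoint's parent directory.
if __name__ == "__main__":
    import librosa

    ckpt_path = Path("saved_models/ppg_extractor.pt")  # hypothetical checkpoint
    extractor = load_model(ckpt_path)

    # The frontend defaults are assumed to expect 16 kHz mono input.
    wav, _ = librosa.load("example.wav", sr=16000)
    with torch.no_grad():
        ppg = extractor.extract_from_wav(wav)
    print(ppg.shape)  # expected: (1, n_frames, 144)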