import argparse
import torch
from pathlib import Path
import yaml

from .frontend import DefaultFrontend
from .utterance_mvn import UtteranceMVN
from .encoder.conformer_encoder import ConformerEncoder

# Module-level cache of the most recently loaded model and the device it was
# placed on. Both are (re)assigned by ``load_model``.
_model = None  # type: PPGModel
_device = None


class PPGModel(torch.nn.Module):
    """Frontend -> normalizer -> encoder pipeline that returns encoder output.

    Args:
        frontend: Callable ``(speech, speech_lengths) -> (feats, feats_lengths)``
            or ``None`` to feed ``speech`` to the normalizer unchanged.
        normalizer: Callable ``(feats, feats_lengths) -> (feats, feats_lengths)``.
        encoder: Callable ``(feats, feats_lengths)`` returning a 3-tuple whose
            first element is the encoder output.
    """

    def __init__(
        self,
        frontend,
        normalizer,
        encoder,
    ):
        super().__init__()
        self.frontend = frontend
        self.normalize = normalizer
        self.encoder = encoder

    def forward(self, speech, speech_lengths):
        """Run the full pipeline and return the encoder output.

        Args:
            speech (tensor): (B, L)
            speech_lengths (tensor): (B, )
        Returns:
            bottle_neck_feats (tensor): (B, L//hop_size, 144)
                (the feature size depends on the encoder configuration;
                144 is the value noted by the original author)
        """
        feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        feats, feats_lengths = self.normalize(feats, feats_lengths)
        # The encoder also returns output lengths and a third value; only the
        # encoder output itself is exposed to callers.
        encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
        return encoder_out

    def _extract_feats(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ):
        """Trim ``speech`` to the longest length and apply the frontend.

        Args:
            speech: Batched input; indexed as ``speech[:, :max_len]``.
            speech_lengths: 1-D tensor of per-item lengths.
        Returns:
            ``(feats, feats_lengths)`` from the frontend, or the trimmed
            ``(speech, speech_lengths)`` when no frontend is configured.
        """
        assert speech_lengths.dim() == 1, speech_lengths.shape

        # for data-parallel
        speech = speech[:, : speech_lengths.max()]

        if self.frontend is not None:
            # Frontend
            #  e.g. STFT and Feature extract
            #       data_loader may send time-domain signal in this case
            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
            feats, feats_lengths = self.frontend(speech, speech_lengths)
        else:
            # No frontend and no feature extract
            feats, feats_lengths = speech, speech_lengths
        return feats, feats_lengths

    def extract_from_wav(self, src_wav):
        """Run the model on a single waveform given as a NumPy array.

        Args:
            src_wav: NumPy array accepted by ``torch.from_numpy``; a batch
                dimension is added in front and the data is cast to float.
        Returns:
            The result of calling the model on the one-item batch.
        """
        # Place the inputs on the device the model's own weights live on, so
        # the method keeps working when the model was built without
        # ``load_model`` or when ``load_model`` was later called with another
        # device. Fall back to the module-level device if there are no
        # parameters to inspect.
        first_param = next(self.parameters(), None)
        device = first_param.device if first_param is not None else _device

        src_wav_tensor = torch.from_numpy(src_wav).unsqueeze(0).float().to(device)
        src_wav_lengths = torch.LongTensor([len(src_wav)]).to(device)
        return self(src_wav_tensor, src_wav_lengths)


def build_model(args):
    """Construct a ``PPGModel`` from a parsed configuration.

    Args:
        args: Object exposing ``normalize_conf``, ``frontend_conf`` and
            ``encoder_conf`` mappings, each expanded as keyword arguments.
    Returns:
        A freshly constructed ``PPGModel``.
    """
    normalizer = UtteranceMVN(**args.normalize_conf)
    frontend = DefaultFrontend(**args.frontend_conf)
    # NOTE(review): input_size is fixed at 80 here; it presumably has to match
    # the frontend's output feature dimension - confirm against frontend_conf.
    encoder = ConformerEncoder(input_size=80, **args.encoder_conf)
    model = PPGModel(frontend, normalizer, encoder)

    return model


def load_model(model_file, device=None):
    """Build the model, load encoder weights from a checkpoint, and cache it.

    The configuration is the first ``*.yaml`` file found by a recursive search
    of the checkpoint's parent directory.

    Args:
        model_file: Path to the checkpoint, as ``pathlib.Path`` or ``str``.
        device: Target device. When ``None``, CUDA is used if available,
            otherwise CPU.
    Returns:
        The model in eval mode on the selected device. It is also stored in the
        module-level ``_model``, and the device in ``_device``.
    Raises:
        FileNotFoundError: If no ``*.yaml`` file exists under the checkpoint's
            parent directory.
    """
    global _model, _device

    # Accept plain string paths as well as Path objects.
    model_file = Path(model_file)

    if device is None:
        _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        _device = device

    # search a config file
    # rglob order is filesystem dependent; the first hit is used, as before.
    config_file = next(model_file.parent.rglob("*.yaml"), None)
    if config_file is None:
        raise FileNotFoundError(
            f"No *.yaml config file found under {model_file.parent}"
        )
    with config_file.open("r", encoding="utf-8") as f:
        args = yaml.safe_load(f)

    args = argparse.Namespace(**args)

    model = build_model(args)
    model_state_dict = model.state_dict()

    # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
    ckpt_state_dict = torch.load(model_file, map_location=_device)
    # Only entries whose key contains 'encoder' are taken from the checkpoint;
    # every other entry keeps its freshly initialised value.
    ckpt_state_dict = {k: v for k, v in ckpt_state_dict.items() if 'encoder' in k}
    model_state_dict.update(ckpt_state_dict)
    model.load_state_dict(model_state_dict)

    _model = model.eval().to(_device)
    return _model