mirror of
https://github.com/babysor/MockingBird.git
synced 2024-03-22 13:11:31 +08:00
34 lines
874 B
Python
34 lines
874 B
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
|
|
# Copyright 2019 Shigeki Karita
|
|
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
|
|
|
"""Layer normalization module."""
|
|
|
|
import torch
|
|
|
|
|
|
class LayerNorm(torch.nn.LayerNorm):
    """Layer normalization module that can normalize along an arbitrary dimension.

    ``torch.nn.LayerNorm`` always normalizes over the trailing dimension. This
    subclass moves the requested dimension to the end, normalizes, and moves
    it back.

    :param int nout: size of the dimension being normalized (passed to
        ``torch.nn.LayerNorm`` as ``normalized_shape``)
    :param int dim: dimension to be normalized (default: last dimension)
    :param float eps: value added to the variance for numerical stability
        (default 1e-12, matching the previous hard-coded value)
    """

    def __init__(self, nout, dim=-1, eps=1e-12):
        """Construct an LayerNorm object."""
        super().__init__(nout, eps=eps)
        self.dim = dim

    def forward(self, x):
        """Apply layer normalization along ``self.dim``.

        :param torch.Tensor x: input tensor; ``x.size(self.dim)`` must equal
            ``nout`` (otherwise ``torch.nn.LayerNorm`` raises)
        :return: layer normalized tensor with the same shape as ``x``
        :rtype torch.Tensor
        """
        if self.dim == -1:
            return super().forward(x)
        # Swap the target axis with the last one (the axis torch.nn.LayerNorm
        # normalizes), normalize, then swap back; transpose is its own inverse.
        # Previously this hard-coded axis 1 and silently ignored self.dim.
        return super().forward(x.transpose(self.dim, -1)).transpose(self.dim, -1)
|