# MockingBird/synthesizer/models/sublayer/pre_net.py
#
# Original listing: 19 lines, 535 B, Python.

import torch.nn as nn
import torch.nn.functional as F
class PreNet(nn.Module):
    """Two-layer fully connected pre-net with ReLU and always-on dropout.

    The input is passed through ``Linear -> ReLU -> Dropout`` twice.

    Dropout is applied with ``training=True`` hard-coded, so it stays active
    even after ``module.eval()``; consecutive forward passes on the same
    input therefore generally differ unless ``dropout`` is 0.
    NOTE(review): this looks like the deliberate Tacotron-style pre-net
    behaviour (dropout kept on at inference) -- confirm before changing it.

    Args:
        in_dims: Size of the last dimension of the input.
        fc1_dims: Output size of the first linear layer.
        fc2_dims: Output size of the second linear layer, i.e. the size of
            the last dimension of the module's output.
        dropout: Drop probability used after each ReLU.
    """

    def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5):
        super().__init__()
        self.fc1 = nn.Linear(in_dims, fc1_dims)
        self.fc2 = nn.Linear(fc1_dims, fc2_dims)
        # Kept as a plain float: dropout is applied functionally in forward().
        self.p = dropout

    def forward(self, x):
        """Apply both ``Linear -> ReLU -> Dropout`` stages to ``x``.

        Args:
            x: Tensor whose last dimension equals ``in_dims``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``fc2_dims``. Values are non-negative because
            dropout only zeroes or rescales the ReLU output.
        """
        x = self.fc1(x)
        x = F.relu(x)
        # training=True is explicit: dropout ignores self.training on purpose.
        x = F.dropout(x, self.p, training=True)
        x = self.fc2(x)
        x = F.relu(x)
        x = F.dropout(x, self.p, training=True)
        return x