mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2024-03-22 13:10:55 +08:00
_
This commit is contained in:
parent
24ba84d4a5
commit
5dc027a8b0
|
@ -23,28 +23,13 @@ class Conv2D(nn.LayerBase):
|
|||
if padding == "SAME":
|
||||
padding = ( (kernel_size - 1) * dilations + 1 ) // 2
|
||||
elif padding == "VALID":
|
||||
padding = 0
|
||||
padding = None
|
||||
else:
|
||||
raise ValueError ("Wrong padding type. Should be VALID SAME or INT or 4x INTs")
|
||||
|
||||
if isinstance(padding, int):
|
||||
if padding != 0:
|
||||
if nn.data_format == "NHWC":
|
||||
padding = [ [0,0], [padding,padding], [padding,padding], [0,0] ]
|
||||
else:
|
||||
padding = [ [0,0], [0,0], [padding,padding], [padding,padding] ]
|
||||
else:
|
||||
padding = None
|
||||
|
||||
if nn.data_format == "NHWC":
|
||||
strides = [1,strides,strides,1]
|
||||
else:
|
||||
strides = [1,1,strides,strides]
|
||||
|
||||
if nn.data_format == "NHWC":
|
||||
dilations = [1,dilations,dilations,1]
|
||||
else:
|
||||
dilations = [1,1,dilations,dilations]
|
||||
padding = int(padding)
|
||||
|
||||
|
||||
|
||||
self.in_ch = in_ch
|
||||
self.out_ch = out_ch
|
||||
|
@ -93,9 +78,26 @@ class Conv2D(nn.LayerBase):
|
|||
if self.use_wscale:
|
||||
weight = weight * self.wscale
|
||||
|
||||
if self.padding is not None:
|
||||
x = tf.pad (x, self.padding, mode='CONSTANT')
|
||||
padding = self.padding
|
||||
if padding is not None:
|
||||
if nn.data_format == "NHWC":
|
||||
padding = [ [0,0], [padding,padding], [padding,padding], [0,0] ]
|
||||
else:
|
||||
padding = [ [0,0], [0,0], [padding,padding], [padding,padding] ]
|
||||
x = tf.pad (x, padding, mode='CONSTANT')
|
||||
|
||||
strides = self.strides
|
||||
if nn.data_format == "NHWC":
|
||||
strides = [1,strides,strides,1]
|
||||
else:
|
||||
strides = [1,1,strides,strides]
|
||||
|
||||
dilations = self.dilations
|
||||
if nn.data_format == "NHWC":
|
||||
dilations = [1,dilations,dilations,1]
|
||||
else:
|
||||
dilations = [1,1,dilations,dilations]
|
||||
|
||||
x = tf.nn.conv2d(x, weight, self.strides, 'VALID', dilations=self.dilations, data_format=nn.data_format)
|
||||
if self.use_bias:
|
||||
if nn.data_format == "NHWC":
|
||||
|
|
Loading…
Reference in New Issue
Block a user