mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2024-03-22 13:10:55 +08:00
_
This commit is contained in:
parent
8a897f236f
commit
33ff0be722
|
@ -8,6 +8,7 @@ from .Dense import *
|
|||
from .BlurPool import *
|
||||
|
||||
from .BatchNorm2D import *
|
||||
from .InstanceNorm2D import *
|
||||
from .FRNorm2D import *
|
||||
|
||||
from .TLU import *
|
||||
|
|
|
@ -108,10 +108,15 @@ nn.gelu = gelu
|
|||
|
||||
def upsample2d(x, size=2):
|
||||
if nn.data_format == "NCHW":
|
||||
b,c,h,w = x.shape.as_list()
|
||||
x = tf.reshape (x, (-1,c,h,1,w,1) )
|
||||
x = tf.tile(x, (1,1,1,size,1,size) )
|
||||
x = tf.reshape (x, (-1,c,h*size,w*size) )
|
||||
x = tf.transpose(x, (0,2,3,1))
|
||||
x = tf.image.resize_nearest_neighbor(x, (x.shape[1]*size, x.shape[2]*size) )
|
||||
x = tf.transpose(x, (0,3,1,2))
|
||||
|
||||
|
||||
# b,c,h,w = x.shape.as_list()
|
||||
# x = tf.reshape (x, (-1,c,h,1,w,1) )
|
||||
# x = tf.tile(x, (1,1,1,size,1,size) )
|
||||
# x = tf.reshape (x, (-1,c,h*size,w*size) )
|
||||
return x
|
||||
else:
|
||||
return tf.image.resize_nearest_neighbor(x, (x.shape[1]*size, x.shape[2]*size) )
|
||||
|
|
Loading…
Reference in New Issue
Block a user