This commit is contained in:
iperov 2021-09-29 16:41:43 +04:00
parent 8a897f236f
commit 33ff0be722
2 changed files with 10 additions and 4 deletions

View File

@ -8,6 +8,7 @@ from .Dense import *
from .BlurPool import *
from .BatchNorm2D import *
from .InstanceNorm2D import *
from .FRNorm2D import *
from .TLU import *

View File

@ -108,10 +108,15 @@ nn.gelu = gelu
def upsample2d(x, size=2):
if nn.data_format == "NCHW":
b,c,h,w = x.shape.as_list()
x = tf.reshape (x, (-1,c,h,1,w,1) )
x = tf.tile(x, (1,1,1,size,1,size) )
x = tf.reshape (x, (-1,c,h*size,w*size) )
x = tf.transpose(x, (0,2,3,1))
x = tf.image.resize_nearest_neighbor(x, (x.shape[1]*size, x.shape[2]*size) )
x = tf.transpose(x, (0,3,1,2))
# b,c,h,w = x.shape.as_list()
# x = tf.reshape (x, (-1,c,h,1,w,1) )
# x = tf.tile(x, (1,1,1,size,1,size) )
# x = tf.reshape (x, (-1,c,h*size,w*size) )
return x
else:
return tf.image.resize_nearest_neighbor(x, (x.shape[1]*size, x.shape[2]*size) )