mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2024-03-22 13:10:55 +08:00
nothing interesting
This commit is contained in:
parent
4c2cb44643
commit
cbc18b2d41
|
@ -43,7 +43,7 @@ class PoseEstimator(object):
|
|||
mean_t, logvar_t = input
|
||||
return mean_t + K.exp(0.5*logvar_t)*K.random_normal(K.shape(mean_t))
|
||||
|
||||
self.BVAEResampler = Lambda ( lambda x: x[0] + K.exp(0.5*x[1])*K.random_normal(K.shape(x[0])),
|
||||
self.BVAEResampler = Lambda ( lambda x: x[0] + K.random_normal(K.shape(x[0])) * K.sqrt(K.exp(0.5*x[1])),
|
||||
output_shape=K.int_shape(self.encoder.outputs[0])[1:] )
|
||||
|
||||
inp_t = Input (self.input_bgr_shape)
|
||||
|
@ -99,24 +99,21 @@ class PoseEstimator(object):
|
|||
pyr_loss += [ a*K.mean( K.square ( inp_pyrs_t[i] - pyrs_t[i]) ) ]
|
||||
|
||||
def BVAELoss(beta=4):
|
||||
#keep in mind loss per sample, not per minibatch
|
||||
def func(input):
|
||||
mean_t, logvar_t = input
|
||||
return beta * K.mean ( K.sum( -0.5*(1 + logvar_t - K.exp(logvar_t) - K.square(mean_t)), axis=1 ), axis=0, keepdims=True )
|
||||
return beta * K.mean ( K.sum( 0.5*(K.exp(logvar_t)+ K.square(mean_t)-logvar_t-1), axis=1) )
|
||||
return func
|
||||
|
||||
BVAE_loss = BVAELoss(4)([mean_t, logvar_t])#beta * K.mean ( K.sum( -0.5*(1 + logvar_t - K.exp(logvar_t) - K.square(mean_t)), axis=1 ), axis=0, keepdims=True )
|
||||
|
||||
|
||||
bgr_loss = K.mean(K.square(inp_real_t-bgr_t), axis=0, keepdims=True)
|
||||
|
||||
#train_loss = BVAE_loss + bgr_loss
|
||||
BVAE_loss = BVAELoss()([mean_t, logvar_t])
|
||||
|
||||
bgr_loss = K.mean(K.sum(K.abs(inp_real_t-bgr_t), axis=[1,2,3]))
|
||||
|
||||
G_loss = BVAE_loss+bgr_loss
|
||||
pyr_loss = sum(pyr_loss)
|
||||
|
||||
|
||||
self.train = K.function ([inp_t, inp_real_t],
|
||||
[ K.mean (BVAE_loss)+K.mean(bgr_loss) ], Adam(lr=0.0005, beta_1=0.9, beta_2=0.999).get_updates( [BVAE_loss, bgr_loss], self.encoder.trainable_weights+self.decoder.trainable_weights ) )
|
||||
[ G_loss ], Adam(lr=0.0005, beta_1=0.9, beta_2=0.999).get_updates( G_loss, self.encoder.trainable_weights+self.decoder.trainable_weights ) )
|
||||
|
||||
self.train_l = K.function ([inp_t] + inp_pyrs_t,
|
||||
[pyr_loss], Adam(lr=0.0001).get_updates( pyr_loss, self.model_l.trainable_weights) )
|
||||
|
@ -140,7 +137,6 @@ class PoseEstimator(object):
|
|||
Model(inp_t, self.model_l(self.BVAEResampler(self.encoder(inp_t))) ).save_weights (str(self.model_weights_path))
|
||||
|
||||
def train_on_batch(self, warps, imgs, pyr_tanh, skip_bgr_train=False):
|
||||
|
||||
if not skip_bgr_train:
|
||||
bgr_loss, = self.train( [warps, imgs] )
|
||||
pyr_loss = 0
|
||||
|
@ -198,12 +194,9 @@ class PoseEstimator(object):
|
|||
def EncFlow(ae_dims):
|
||||
exec( nnlib.import_all(), locals(), globals() )
|
||||
|
||||
XConv2D = partial(Conv2D, padding='zero')
|
||||
|
||||
|
||||
def downscale (dim, **kwargs):
|
||||
def func(x):
|
||||
return ReLU() ( ( XConv2D(dim, kernel_size=4, strides=2)(x)) )
|
||||
return ReLU() ( Conv2D(dim, kernel_size=5, strides=2, padding='same')(x))
|
||||
return func
|
||||
|
||||
|
||||
|
@ -236,16 +229,14 @@ class PoseEstimator(object):
|
|||
def DecFlow(resolution, ae_dims):
|
||||
exec( nnlib.import_all(), locals(), globals() )
|
||||
|
||||
XConv2D = partial(Conv2D, padding='zero')
|
||||
|
||||
def upscale (dim, strides=2, **kwargs):
|
||||
def func(x):
|
||||
return ReLU()( ( Conv2DTranspose(dim, kernel_size=4, strides=strides, padding='same')(x)) )
|
||||
return ReLU()( ( Conv2DTranspose(dim, kernel_size=3, strides=strides, padding='same')(x)) )
|
||||
return func
|
||||
|
||||
def to_bgr (output_nc, **kwargs):
|
||||
def func(x):
|
||||
return XConv2D(output_nc, kernel_size=5, activation='sigmoid')(x)
|
||||
return Conv2D(output_nc, kernel_size=5, padding='same', activation='sigmoid')(x)
|
||||
return func
|
||||
|
||||
upscale = partial(upscale)
|
||||
|
@ -278,8 +269,6 @@ class PoseEstimator(object):
|
|||
def LatentFlow(class_nums):
|
||||
exec( nnlib.import_all(), locals(), globals() )
|
||||
|
||||
XConv2D = partial(Conv2D, padding='zero')
|
||||
|
||||
def func(latent):
|
||||
x = latent
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user