SAEHD, AMP: removed the implicit function of periodically retraining last 16 “high-loss” samples

This commit is contained in:
iperov 2021-09-29 16:48:54 +04:00
parent 33ff0be722
commit 9e0079c6a0
3 changed files with 13 additions and 54 deletions

View File

@ -21,6 +21,13 @@ class DeepFakeArchi(nn.ArchiBase):
conv_dtype = tf.float16 if use_fp16 else tf.float32
if 'c' in opts:
def act(x, alpha=0.1):
return tf.nn.relu(x)
else:
def act(x, alpha=0.1):
return tf.nn.leaky_relu(x, alpha)
if mod is None:
class Downscale(nn.ModelBase):
def __init__(self, in_ch, out_ch, kernel_size=5, *kwargs ):
@ -34,7 +41,7 @@ class DeepFakeArchi(nn.ArchiBase):
def forward(self, x):
x = self.conv1(x)
x = tf.nn.leaky_relu(x, 0.1)
x = act(x, 0.1)
return x
def get_out_ch(self):
@ -62,7 +69,7 @@ class DeepFakeArchi(nn.ArchiBase):
def forward(self, x):
x = self.conv1(x)
x = tf.nn.leaky_relu(x, 0.1)
x = act(x, 0.1)
x = nn.depth_to_space(x, 2)
return x
@ -73,9 +80,9 @@ class DeepFakeArchi(nn.ArchiBase):
def forward(self, inp):
x = self.conv1(inp)
x = tf.nn.leaky_relu(x, 0.2)
x = act(x, 0.2)
x = self.conv2(x)
x = tf.nn.leaky_relu(inp + x, 0.2)
x = act(inp + x, 0.2)
return x
class Encoder(nn.ModelBase):

View File

@ -581,9 +581,6 @@ class AMPModel(ModelBase):
generators_count=dst_generators_count )
])
self.last_src_samples_loss = []
self.last_dst_samples_loss = []
def export_dfm (self):
output_path=self.get_strpath_storage_for_file('model.dfm')
@ -653,26 +650,6 @@ class AMPModel(ModelBase):
src_loss, dst_loss = self.train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
for i in range(bs):
self.last_src_samples_loss.append ( (src_loss[i], target_src[i], target_srcm[i], target_srcm_em[i]) )
self.last_dst_samples_loss.append ( (dst_loss[i], target_dst[i], target_dstm[i], target_dstm_em[i]) )
if len(self.last_src_samples_loss) >= bs*16:
src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(0), reverse=True)
dst_samples_loss = sorted(self.last_dst_samples_loss, key=operator.itemgetter(0), reverse=True)
target_src = np.stack( [ x[1] for x in src_samples_loss[:bs] ] )
target_srcm = np.stack( [ x[2] for x in src_samples_loss[:bs] ] )
target_srcm_em = np.stack( [ x[3] for x in src_samples_loss[:bs] ] )
target_dst = np.stack( [ x[1] for x in dst_samples_loss[:bs] ] )
target_dstm = np.stack( [ x[2] for x in dst_samples_loss[:bs] ] )
target_dstm_em = np.stack( [ x[3] for x in dst_samples_loss[:bs] ] )
src_loss, dst_loss = self.train (target_src, target_src, target_srcm, target_srcm_em, target_dst, target_dst, target_dstm, target_dstm_em)
self.last_src_samples_loss = []
self.last_dst_samples_loss = []
if self.gan_power != 0:
self.GAN_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)

View File

@ -104,7 +104,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
if archi_opts is not None:
if len(archi_opts) == 0:
continue
if len([ 1 for opt in archi_opts if opt not in ['u','d','t'] ]) != 0:
if len([ 1 for opt in archi_opts if opt not in ['u','d','t','c'] ]) != 0:
continue
if 'd' in archi_opts:
@ -688,9 +688,6 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
generators_count=dst_generators_count )
])
self.last_src_samples_loss = []
self.last_dst_samples_loss = []
if self.pretrain_just_disabled:
self.update_sample_for_preview(force_new=True)
@ -765,33 +762,11 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
if self.get_iter() == 0 and not self.pretrain and not self.pretrain_just_disabled:
io.log_info('You are training the model from scratch. It is strongly recommended to use a pretrained model to speed up the training and improve the quality.\n')
bs = self.get_batch_size()
( (warped_src, target_src, target_srcm, target_srcm_em), \
(warped_dst, target_dst, target_dstm, target_dstm_em) ) = self.generate_next_samples()
src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
for i in range(bs):
self.last_src_samples_loss.append ( (src_loss[i], target_src[i], target_srcm[i], target_srcm_em[i],) )
self.last_dst_samples_loss.append ( (dst_loss[i], target_dst[i], target_dstm[i], target_dstm_em[i],) )
if len(self.last_src_samples_loss) >= bs*16:
src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(0), reverse=True)
dst_samples_loss = sorted(self.last_dst_samples_loss, key=operator.itemgetter(0), reverse=True)
target_src = np.stack( [ x[1] for x in src_samples_loss[:bs] ] )
target_srcm = np.stack( [ x[2] for x in src_samples_loss[:bs] ] )
target_srcm_em = np.stack( [ x[3] for x in src_samples_loss[:bs] ] )
target_dst = np.stack( [ x[1] for x in dst_samples_loss[:bs] ] )
target_dstm = np.stack( [ x[2] for x in dst_samples_loss[:bs] ] )
target_dstm_em = np.stack( [ x[3] for x in dst_samples_loss[:bs] ] )
src_loss, dst_loss = self.src_dst_train (target_src, target_src, target_srcm, target_srcm_em, target_dst, target_dst, target_dstm, target_dstm_em)
self.last_src_samples_loss = []
self.last_dst_samples_loss = []
if self.options['true_face_power'] != 0 and not self.pretrain:
self.D_train (warped_src, warped_dst)