diff --git a/core/leras/archis/DeepFakeArchi.py b/core/leras/archis/DeepFakeArchi.py
index 5dfd293..316e1f9 100644
--- a/core/leras/archis/DeepFakeArchi.py
+++ b/core/leras/archis/DeepFakeArchi.py
@@ -20,7 +20,14 @@ class DeepFakeArchi(nn.ArchiBase):
 
         conv_dtype = tf.float16 if use_fp16 else tf.float32
-
+
+        if 'c' in opts:
+            def act(x, alpha=0.1):
+                return tf.nn.relu(x)
+        else:
+            def act(x, alpha=0.1):
+                return tf.nn.leaky_relu(x, alpha)
+
         if mod is None:
             class Downscale(nn.ModelBase):
                 def __init__(self, in_ch, out_ch, kernel_size=5, *kwargs ):
 
@@ -34,7 +41,7 @@ class DeepFakeArchi(nn.ArchiBase):
 
                 def forward(self, x):
                     x = self.conv1(x)
-                    x = tf.nn.leaky_relu(x, 0.1)
+                    x = act(x, 0.1)
                     return x
 
                 def get_out_ch(self):
@@ -62,7 +69,7 @@ class DeepFakeArchi(nn.ArchiBase):
 
                 def forward(self, x):
                     x = self.conv1(x)
-                    x = tf.nn.leaky_relu(x, 0.1)
+                    x = act(x, 0.1)
                     x = nn.depth_to_space(x, 2)
                     return x
 
@@ -73,9 +80,9 @@ class DeepFakeArchi(nn.ArchiBase):
 
                 def forward(self, inp):
                     x = self.conv1(inp)
-                    x = tf.nn.leaky_relu(x, 0.2)
+                    x = act(x, 0.2)
                     x = self.conv2(x)
-                    x = tf.nn.leaky_relu(inp + x, 0.2)
+                    x = act(inp + x, 0.2)
                     return x
 
             class Encoder(nn.ModelBase):
diff --git a/models/Model_AMP/Model.py b/models/Model_AMP/Model.py
index 575b318..ae8eb61 100644
--- a/models/Model_AMP/Model.py
+++ b/models/Model_AMP/Model.py
@@ -581,9 +581,6 @@ class AMPModel(ModelBase):
                               generators_count=dst_generators_count )
                              ])
 
-        self.last_src_samples_loss = []
-        self.last_dst_samples_loss = []
-
     def export_dfm (self):
         output_path=self.get_strpath_storage_for_file('model.dfm')
 
@@ -653,26 +650,6 @@ class AMPModel(ModelBase):
 
         src_loss, dst_loss = self.train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
 
-        for i in range(bs):
-            self.last_src_samples_loss.append ( (src_loss[i], target_src[i], target_srcm[i], target_srcm_em[i]) )
-            self.last_dst_samples_loss.append ( (dst_loss[i], target_dst[i], target_dstm[i], target_dstm_em[i]) )
-
-        if len(self.last_src_samples_loss) >= bs*16:
-            src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(0), reverse=True)
-            dst_samples_loss = sorted(self.last_dst_samples_loss, key=operator.itemgetter(0), reverse=True)
-
-            target_src     = np.stack( [ x[1] for x in src_samples_loss[:bs] ] )
-            target_srcm    = np.stack( [ x[2] for x in src_samples_loss[:bs] ] )
-            target_srcm_em = np.stack( [ x[3] for x in src_samples_loss[:bs] ] )
-
-            target_dst     = np.stack( [ x[1] for x in dst_samples_loss[:bs] ] )
-            target_dstm    = np.stack( [ x[2] for x in dst_samples_loss[:bs] ] )
-            target_dstm_em = np.stack( [ x[3] for x in dst_samples_loss[:bs] ] )
-
-            src_loss, dst_loss = self.train (target_src, target_src, target_srcm, target_srcm_em, target_dst, target_dst, target_dstm, target_dstm_em)
-            self.last_src_samples_loss = []
-            self.last_dst_samples_loss = []
-
         if self.gan_power != 0:
             self.GAN_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
 
diff --git a/models/Model_SAEHD/Model.py b/models/Model_SAEHD/Model.py
index 3cc01a1..e0829ec 100644
--- a/models/Model_SAEHD/Model.py
+++ b/models/Model_SAEHD/Model.py
@@ -104,7 +104,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
                     if archi_opts is not None:
                         if len(archi_opts) == 0:
                             continue
-                        if len([ 1 for opt in archi_opts if opt not in ['u','d','t'] ]) != 0:
+                        if len([ 1 for opt in archi_opts if opt not in ['u','d','t','c'] ]) != 0:
                             continue
 
                     if 'd' in archi_opts:
@@ -688,9 +688,6 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
 
                               generators_count=dst_generators_count )
                              ])
-        self.last_src_samples_loss = []
-        self.last_dst_samples_loss = []
-
         if self.pretrain_just_disabled:
             self.update_sample_for_preview(force_new=True)
 
@@ -765,33 +762,11 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
         if self.get_iter() == 0 and not self.pretrain and not self.pretrain_just_disabled:
             io.log_info('You are training the model from scratch. It is strongly recommended to use a pretrained model to speed up the training and improve the quality.\n')
 
-        bs = self.get_batch_size()
-
         ( (warped_src, target_src, target_srcm, target_srcm_em), \
           (warped_dst, target_dst, target_dstm, target_dstm_em) ) = self.generate_next_samples()
 
         src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
 
-        for i in range(bs):
-            self.last_src_samples_loss.append ( (src_loss[i], target_src[i], target_srcm[i], target_srcm_em[i],) )
-            self.last_dst_samples_loss.append ( (dst_loss[i], target_dst[i], target_dstm[i], target_dstm_em[i],) )
-
-        if len(self.last_src_samples_loss) >= bs*16:
-            src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(0), reverse=True)
-            dst_samples_loss = sorted(self.last_dst_samples_loss, key=operator.itemgetter(0), reverse=True)
-
-            target_src      = np.stack( [ x[1] for x in src_samples_loss[:bs] ] )
-            target_srcm     = np.stack( [ x[2] for x in src_samples_loss[:bs] ] )
-            target_srcm_em  = np.stack( [ x[3] for x in src_samples_loss[:bs] ] )
-
-            target_dst      = np.stack( [ x[1] for x in dst_samples_loss[:bs] ] )
-            target_dstm     = np.stack( [ x[2] for x in dst_samples_loss[:bs] ] )
-            target_dstm_em  = np.stack( [ x[3] for x in dst_samples_loss[:bs] ] )
-
-            src_loss, dst_loss = self.src_dst_train (target_src, target_src, target_srcm, target_srcm_em, target_dst, target_dst, target_dstm, target_dstm_em)
-            self.last_src_samples_loss = []
-            self.last_dst_samples_loss = []
-
         if self.options['true_face_power'] != 0 and not self.pretrain:
             self.D_train (warped_src, warped_dst)
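
Note: a minimal standalone sketch (not part of the patch) of the activation dispatch the new 'c' archi option enables. `make_act` is a hypothetical wrapper for illustration; the in-tree code defines `act` inline inside `DeepFakeArchi.__init__`, and `opts` is assumed to be the option string parsed from an archi name such as 'df-uc'.

```python
import tensorflow as tf

def make_act(opts):
    # Hypothetical helper mirroring the patched DeepFakeArchi.__init__:
    # 'c' switches every act() call from leaky_relu to plain relu.
    # alpha is kept only for signature compatibility; relu ignores it.
    if 'c' in opts:
        def act(x, alpha=0.1):
            return tf.nn.relu(x)
    else:
        def act(x, alpha=0.1):
            return tf.nn.leaky_relu(x, alpha)
    return act

x = tf.constant([-1.0, 0.5])
print(make_act('uc')(x).numpy())  # [ 0.   0.5] -> relu
print(make_act('ud')(x).numpy())  # [-0.1  0.5] -> leaky_relu, alpha=0.1
```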