mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2024-03-22 13:10:55 +08:00
dump_ckpt
This commit is contained in:
parent
3d0e18b0ad
commit
b333fcea4b
2
main.py
2
main.py
|
@ -127,6 +127,7 @@ if __name__ == "__main__":
|
|||
'silent_start' : arguments.silent_start,
|
||||
'execute_programs' : [ [int(x[0]), x[1] ] for x in arguments.execute_program ],
|
||||
'debug' : arguments.debug,
|
||||
'dump_ckpt' : arguments.dump_ckpt,
|
||||
}
|
||||
from mainscripts import Trainer
|
||||
Trainer.main(**kwargs)
|
||||
|
@ -144,6 +145,7 @@ if __name__ == "__main__":
|
|||
p.add_argument('--cpu-only', action="store_true", dest="cpu_only", default=False, help="Train on CPU.")
|
||||
p.add_argument('--force-gpu-idxs', dest="force_gpu_idxs", default=None, help="Force to choose GPU indexes separated by comma.")
|
||||
p.add_argument('--silent-start', action="store_true", dest="silent_start", default=False, help="Silent start. Automatically chooses Best GPU and last used model.")
|
||||
p.add_argument('--dump-ckpt', action="store_true", dest="dump_ckpt", default=False, help="Dump the model to ckpt format.")
|
||||
|
||||
|
||||
p.add_argument('--execute-program', dest="execute_program", default=[], action='append', nargs='+')
|
||||
|
|
|
@ -27,6 +27,7 @@ def trainerThread (s2c, c2s, e,
|
|||
silent_start=False,
|
||||
execute_programs = None,
|
||||
debug=False,
|
||||
dump_ckpt=False,
|
||||
**kwargs):
|
||||
while True:
|
||||
try:
|
||||
|
@ -44,7 +45,7 @@ def trainerThread (s2c, c2s, e,
|
|||
saved_models_path.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
model = models.import_model(model_class_name)(
|
||||
is_training=True,
|
||||
is_training=not dump_ckpt,
|
||||
saved_models_path=saved_models_path,
|
||||
training_data_src_path=training_data_src_path,
|
||||
training_data_dst_path=training_data_dst_path,
|
||||
|
@ -55,9 +56,13 @@ def trainerThread (s2c, c2s, e,
|
|||
force_gpu_idxs=force_gpu_idxs,
|
||||
cpu_only=cpu_only,
|
||||
silent_start=silent_start,
|
||||
debug=debug,
|
||||
)
|
||||
debug=debug)
|
||||
|
||||
if dump_ckpt:
|
||||
e.set()
|
||||
model.dump_ckpt()
|
||||
break
|
||||
|
||||
is_reached_goal = model.is_reached_iter_goal()
|
||||
|
||||
shared_state = { 'after_save' : False }
|
||||
|
|
|
@ -204,6 +204,8 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
archi_type, archi_opts = archi_split
|
||||
elif len(archi_split) == 1:
|
||||
archi_type, archi_opts = archi_split[0], None
|
||||
|
||||
self.archi_type = archi_type
|
||||
|
||||
ae_dims = self.options['ae_dims']
|
||||
e_dims = self.options['e_dims']
|
||||
|
@ -236,22 +238,22 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
optimizer_vars_on_cpu = models_opt_device=='/CPU:0'
|
||||
|
||||
input_ch=3
|
||||
bgr_shape = nn.get4Dshape(resolution,resolution,input_ch)
|
||||
bgr_shape = self.bgr_shape = nn.get4Dshape(resolution,resolution,input_ch)
|
||||
mask_shape = nn.get4Dshape(resolution,resolution,1)
|
||||
self.model_filename_list = []
|
||||
|
||||
with tf.device ('/CPU:0'):
|
||||
#Place holders on CPU
|
||||
self.warped_src = tf.placeholder (nn.floatx, bgr_shape)
|
||||
self.warped_dst = tf.placeholder (nn.floatx, bgr_shape)
|
||||
self.warped_src = tf.placeholder (nn.floatx, bgr_shape, name='warped_src')
|
||||
self.warped_dst = tf.placeholder (nn.floatx, bgr_shape, name='warped_dst')
|
||||
|
||||
self.target_src = tf.placeholder (nn.floatx, bgr_shape)
|
||||
self.target_dst = tf.placeholder (nn.floatx, bgr_shape)
|
||||
self.target_src = tf.placeholder (nn.floatx, bgr_shape, name='target_src')
|
||||
self.target_dst = tf.placeholder (nn.floatx, bgr_shape, name='target_dst')
|
||||
|
||||
self.target_srcm = tf.placeholder (nn.floatx, mask_shape)
|
||||
self.target_srcm_em = tf.placeholder (nn.floatx, mask_shape)
|
||||
self.target_dstm = tf.placeholder (nn.floatx, mask_shape)
|
||||
self.target_dstm_em = tf.placeholder (nn.floatx, mask_shape)
|
||||
self.target_srcm = tf.placeholder (nn.floatx, mask_shape, name='target_srcm')
|
||||
self.target_srcm_em = tf.placeholder (nn.floatx, mask_shape, name='target_srcm_em')
|
||||
self.target_dstm = tf.placeholder (nn.floatx, mask_shape, name='target_dstm')
|
||||
self.target_dstm_em = tf.placeholder (nn.floatx, mask_shape, name='target_dstm_em')
|
||||
|
||||
# Initializing model classes
|
||||
model_archi = nn.DeepFakeArchi(resolution, opts=archi_opts)
|
||||
|
@ -609,7 +611,10 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
|
||||
if do_init:
|
||||
model.init_weights()
|
||||
|
||||
|
||||
|
||||
###############
|
||||
|
||||
# initializing sample generators
|
||||
if self.is_training:
|
||||
training_data_src_path = self.training_data_src_path if not self.pretrain else self.get_pretraining_data_path()
|
||||
|
@ -650,7 +655,44 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
|
|||
|
||||
if self.pretrain_just_disabled:
|
||||
self.update_sample_for_preview(force_new=True)
|
||||
|
||||
def dump_ckpt(self):
|
||||
tf = nn.tf
|
||||
|
||||
|
||||
with tf.device ('/CPU:0'):
|
||||
warped_dst = tf.placeholder (nn.floatx, (None, self.resolution, self.resolution, 3), name='in_face')
|
||||
warped_dst = tf.transpose(warped_dst, (0,3,1,2))
|
||||
|
||||
|
||||
if 'df' in self.archi_type:
|
||||
gpu_dst_code = self.inter(self.encoder(warped_dst))
|
||||
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)
|
||||
_, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)
|
||||
|
||||
elif 'liae' in self.archi_type:
|
||||
gpu_dst_code = self.encoder (warped_dst)
|
||||
gpu_dst_inter_B_code = self.inter_B (gpu_dst_code)
|
||||
gpu_dst_inter_AB_code = self.inter_AB (gpu_dst_code)
|
||||
gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis)
|
||||
gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code], nn.conv2d_ch_axis)
|
||||
|
||||
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
|
||||
_, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)
|
||||
|
||||
gpu_pred_src_dst = tf.transpose(gpu_pred_src_dst, (0,2,3,1))
|
||||
gpu_pred_dst_dstm = tf.transpose(gpu_pred_dst_dstm, (0,2,3,1))
|
||||
gpu_pred_src_dstm = tf.transpose(gpu_pred_src_dstm, (0,2,3,1))
|
||||
|
||||
|
||||
saver = tf.train.Saver()
|
||||
tf.identity(gpu_pred_dst_dstm, name='out_face_mask')
|
||||
tf.identity(gpu_pred_src_dst, name='out_celeb_face')
|
||||
tf.identity(gpu_pred_src_dstm, name='out_celeb_face_mask')
|
||||
|
||||
saver.save(nn.tf_sess, self.get_strpath_storage_for_file('.ckpt') )
|
||||
|
||||
|
||||
#override
|
||||
def get_model_filename_list(self):
|
||||
return self.model_filename_list
|
||||
|
|
Loading…
Reference in New Issue
Block a user