mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2024-03-22 13:10:55 +08:00
XSeg: changed pretrain mode
This commit is contained in:
parent
623eb3856d
commit
c5584fbda0
|
@ -34,7 +34,7 @@ class XSegModel(ModelBase):
|
|||
self.ask_batch_size(4, range=[2,16])
|
||||
self.options['pretrain'] = io.input_bool ("Enable pretraining mode", default_pretrain)
|
||||
|
||||
if self.options['pretrain'] and self.get_pretraining_data_path() is None:
|
||||
if not self.is_exporting and (self.options['pretrain'] and self.get_pretraining_data_path() is None):
|
||||
raise Exception("pretraining_data_path is not defined")
|
||||
|
||||
self.pretrain_just_disabled = (default_pretrain == True and self.options['pretrain'] == False)
|
||||
|
@ -42,7 +42,7 @@ class XSegModel(ModelBase):
|
|||
#override
|
||||
def on_initialize(self):
|
||||
device_config = nn.getCurrentDeviceConfig()
|
||||
self.model_data_format = "NCHW" if len(device_config.devices) != 0 and not self.is_debug() else "NHWC"
|
||||
self.model_data_format = "NCHW" if self.is_exporting or (len(device_config.devices) != 0 and not self.is_debug()) else "NHWC"
|
||||
nn.initialize(data_format=self.model_data_format)
|
||||
tf = nn.tf
|
||||
|
||||
|
@ -85,8 +85,6 @@ class XSegModel(ModelBase):
|
|||
bs_per_gpu = max(1, self.get_batch_size() // gpu_count)
|
||||
self.set_batch_size( gpu_count*bs_per_gpu)
|
||||
|
||||
targetm_t = tf.placeholder (nn.floatx, mask_shape)
|
||||
|
||||
# Compute losses per GPU
|
||||
gpu_pred_list = []
|
||||
|
||||
|
@ -100,7 +98,6 @@ class XSegModel(ModelBase):
|
|||
batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu )
|
||||
gpu_input_t = self.model.input_t [batch_slice,:,:,:]
|
||||
gpu_target_t = self.model.target_t [batch_slice,:,:,:]
|
||||
gpu_targetm_t = targetm_t [batch_slice,:,:,:]
|
||||
|
||||
# process model tensors
|
||||
gpu_pred_logits_t, gpu_pred_t = self.model.flow(gpu_input_t, pretrain=self.pretrain)
|
||||
|
@ -108,17 +105,11 @@ class XSegModel(ModelBase):
|
|||
|
||||
|
||||
if self.pretrain:
|
||||
gpu_targetm_blur = nn.gaussian_blur(gpu_targetm_t, max(1, resolution // 32) )
|
||||
gpu_targetm_blur = tf.clip_by_value(gpu_targetm_blur, 0, 0.5) * 2
|
||||
|
||||
gpu_target_t_blur = gpu_target_t*gpu_targetm_blur
|
||||
gpu_pred_t_blur = gpu_pred_t*gpu_targetm_t
|
||||
|
||||
# Structural loss
|
||||
gpu_loss = tf.reduce_mean (5*nn.dssim(gpu_target_t_blur, gpu_pred_t_blur, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
|
||||
gpu_loss += tf.reduce_mean (5*nn.dssim(gpu_target_t_blur, gpu_pred_t_blur, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1])
|
||||
gpu_loss = tf.reduce_mean (5*nn.dssim(gpu_target_t, gpu_pred_t, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
|
||||
gpu_loss += tf.reduce_mean (5*nn.dssim(gpu_target_t, gpu_pred_t, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1])
|
||||
# Pixel loss
|
||||
gpu_loss += tf.reduce_mean (10*tf.square(gpu_target_t_blur-gpu_pred_t_blur), axis=[1,2,3])
|
||||
gpu_loss += tf.reduce_mean (10*tf.square(gpu_target_t-gpu_pred_t), axis=[1,2,3])
|
||||
else:
|
||||
gpu_loss = tf.reduce_mean( tf.nn.sigmoid_cross_entropy_with_logits(labels=gpu_target_t, logits=gpu_pred_logits_t), axis=[1,2,3])
|
||||
|
||||
|
@ -137,8 +128,8 @@ class XSegModel(ModelBase):
|
|||
|
||||
# Initializing training and view functions
|
||||
if self.pretrain:
|
||||
def train(input_np, target_np, targetm_np):
|
||||
l, _ = nn.tf_sess.run ( [loss, loss_gv_op], feed_dict={self.model.input_t :input_np, self.model.target_t :target_np, targetm_t :targetm_np })
|
||||
def train(input_np, target_np):
|
||||
l, _ = nn.tf_sess.run ( [loss, loss_gv_op], feed_dict={self.model.input_t :input_np, self.model.target_t :target_np})
|
||||
return l
|
||||
else:
|
||||
def train(input_np, target_np):
|
||||
|
@ -160,8 +151,7 @@ class XSegModel(ModelBase):
|
|||
pretrain_gen = SampleGeneratorFace(self.get_pretraining_data_path(), debug=self.is_debug(), batch_size=self.get_batch_size(),
|
||||
sample_process_options=SampleProcessor.Options(random_flip=True),
|
||||
output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
{'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.FULL_FACE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
{'sample_type': SampleProcessor.SampleType.FACE_IMAGE,'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
|
||||
],
|
||||
uniform_yaw_distribution=False,
|
||||
generators_count=cpu_count )
|
||||
|
@ -200,13 +190,9 @@ class XSegModel(ModelBase):
|
|||
|
||||
#override
|
||||
def onTrainOneIter(self):
|
||||
if self.pretrain:
|
||||
image_np, target_np, targetm_np = self.generate_next_samples()[0]
|
||||
loss = self.train (image_np, target_np, targetm_np)
|
||||
else:
|
||||
image_np, mask_np = self.generate_next_samples()[0]
|
||||
loss = self.train (image_np, mask_np)
|
||||
|
||||
image_np, target_np = self.generate_next_samples()[0]
|
||||
loss = self.train (image_np, target_np)
|
||||
|
||||
return ( ('loss', np.mean(loss) ), )
|
||||
|
||||
#override
|
||||
|
@ -215,7 +201,7 @@ class XSegModel(ModelBase):
|
|||
|
||||
if self.pretrain:
|
||||
srcdst_samples, = samples
|
||||
image_np, mask_np, _ = srcdst_samples
|
||||
image_np, mask_np = srcdst_samples
|
||||
else:
|
||||
srcdst_samples, src_samples, dst_samples = samples
|
||||
image_np, mask_np = srcdst_samples
|
||||
|
@ -264,5 +250,34 @@ class XSegModel(ModelBase):
|
|||
result += [ ('XSeg dst faces', np.concatenate (st, axis=0 )), ]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def export_dfm(self):
    """
    Dump the XSeg network to an ONNX file.

    Steps performed:
      1. Resolve the output path for 'model.onnx' through
         self.get_strpath_storage_for_file and log it.
      2. Build a graph with a placeholder named 'in_face' of shape
         (None, resolution, resolution, 3), permute it to channels-first,
         run self.model.flow on it, and permute the prediction back to
         channels-last.
      3. Expose the prediction under the node name 'out_mask'.
      4. Freeze variables into constants and convert the frozen graph with
         tf2onnx (opset 13), writing to the resolved path.

    Returns nothing; the result is the file written by the converter.
    """
    output_path = self.get_strpath_storage_for_file(f'model.onnx')
    io.log_info(f'Dumping .onnx to {output_path}')
    tf = nn.tf

    with tf.device(nn.tf_default_device_name):
        face_shape = (None, self.resolution, self.resolution, 3)
        face_channels_last = tf.placeholder(nn.floatx, face_shape, name='in_face')
        # The exported input is channels-last; the network is fed channels-first.
        face_channels_first = tf.transpose(face_channels_last, (0, 3, 1, 2))
        # flow returns a pair; only the second element is exported.
        _, mask_channels_first = self.model.flow(face_channels_first)
        mask_channels_last = tf.transpose(mask_channels_first, (0, 2, 3, 1))

    # NOTE(review): indentation was lost in the view this was restored from;
    # this op is placed outside the device scope based on the blank line that
    # separated it from the block above - confirm against the repository.
    # The identity op only gives the output a stable name for export.
    tf.identity(mask_channels_last, name='out_mask')

    graph_def = tf.get_default_graph().as_graph_def()
    frozen_graph_def = tf.graph_util.convert_variables_to_constants(
        nn.tf_sess,
        graph_def,
        ['out_mask'],
    )

    # Imported here rather than at module level, as in the original.
    import tf2onnx

    with tf.device("/CPU:0"):
        # _convert_common is underscore-prefixed in tf2onnx, i.e. not part of
        # its documented public surface.
        _model_proto, _ = tf2onnx.convert._convert_common(
            frozen_graph_def,
            name='XSeg',
            input_names=['in_face:0'],
            output_names=['out_mask:0'],
            opset=13,
            output_path=output_path,
        )
|
||||
|
||||
# Bind the generic module-level name `Model` to this file's model class
# (presumably so external code can look the class up by a fixed name -
# confirm against whatever imports this module).
Model = XSegModel
|
Loading…
Reference in New Issue
Block a user