From 8b90ca0dacec05d95f4df8a75b85cbc9e701fe2b Mon Sep 17 00:00:00 2001 From: iperov Date: Wed, 12 May 2021 09:41:53 +0400 Subject: [PATCH] XSegUtil apply xseg now checks model face type --- mainscripts/XSegUtil.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/mainscripts/XSegUtil.py b/mainscripts/XSegUtil.py index b95bb2c..c75a14a 100644 --- a/mainscripts/XSegUtil.py +++ b/mainscripts/XSegUtil.py @@ -11,7 +11,7 @@ from core.interact import interact as io from core.leras import nn from DFLIMG import * from facelib import XSegNet, LandmarksProcessor, FaceType - +import pickle def apply_xseg(input_path, model_path): if not input_path.exists(): @@ -19,21 +19,37 @@ def apply_xseg(input_path, model_path): if not model_path.exists(): raise ValueError(f'{model_path} not found. Please ensure it exists.') - - face_type = io.input_str ("XSeg model face type", 'same', ['h','mf','f','wf','head','same'], help_message="Specify face type of trained XSeg model. For example if XSeg model trained as WF, but faceset is HEAD, specify WF to apply xseg only on WF part of HEAD. Default is 'same'").lower() - if face_type == 'same': - face_type = None - else: + + face_type = None + + model_dat = model_path / 'XSeg_data.dat' + if model_dat.exists(): + dat = pickle.loads( model_dat.read_bytes() ) + dat_options = dat.get('options', None) + if dat_options is not None: + face_type = dat_options.get('face_type', None) + + + + if face_type is None: + face_type = io.input_str ("XSeg model face type", 'same', ['h','mf','f','wf','head','same'], help_message="Specify face type of trained XSeg model. For example if XSeg model trained as WF, but faceset is HEAD, specify WF to apply xseg only on WF part of HEAD. Default is 'same'").lower() + if face_type == 'same': + face_type = None + + if face_type is not None: face_type = {'h' : FaceType.HALF, 'mf' : FaceType.MID_FULL, 'f' : FaceType.FULL, 'wf' : FaceType.WHOLE_FACE, 'head' : FaceType.HEAD}[face_type] + io.log_info(f'Applying trained XSeg model to {input_path.name}/ folder.') device_config = nn.DeviceConfig.ask_choose_device(choose_only_one=True) nn.initialize(device_config) + + xseg = XSegNet(name='XSeg', load_weights=True, weights_file_root=model_path,