diff --git a/mainscripts/Merger.py b/mainscripts/Merger.py index 1739c9f..fba37f1 100644 --- a/mainscripts/Merger.py +++ b/mainscripts/Merger.py @@ -1,4 +1,5 @@ import math +import multiprocessing import traceback from pathlib import Path @@ -13,7 +14,8 @@ from core.joblib import MPClassFuncOnDemand, MPFunc from core.leras import nn from DFLIMG import DFLIMG from facelib import FaceEnhancer, FaceType, LandmarksProcessor, XSegNet -from merger import FrameInfo, MergerConfig, InteractiveMergerSubprocessor +from merger import FrameInfo, InteractiveMergerSubprocessor, MergerConfig + def main (model_class_name=None, saved_models_path=None, @@ -70,6 +72,9 @@ def main (model_class_name=None, if not is_interactive: cfg.ask_settings() + + subprocess_count = io.input_int("Number of workers?", max(8, multiprocessing.cpu_count()), + valid_range=[1, multiprocessing.cpu_count()], help_message="Specify the number of threads to process. A low value may affect performance. A high value may result in memory error. The value may not be greater than CPU cores." ) input_path_image_paths = pathex.get_image_paths(input_path) @@ -199,7 +204,8 @@ def main (model_class_name=None, frames_root_path = input_path, output_path = output_path, output_mask_path = output_mask_path, - model_iter = model.get_iter() + model_iter = model.get_iter(), + subprocess_count = subprocess_count, ).run() model.finalize() diff --git a/merger/InteractiveMergerSubprocessor.py b/merger/InteractiveMergerSubprocessor.py index 47c8a38..58db0c1 100644 --- a/merger/InteractiveMergerSubprocessor.py +++ b/merger/InteractiveMergerSubprocessor.py @@ -140,7 +140,7 @@ class InteractiveMergerSubprocessor(Subprocessor): #override - def __init__(self, is_interactive, merger_session_filepath, predictor_func, predictor_input_shape, face_enhancer_func, xseg_256_extract_func, merger_config, frames, frames_root_path, output_path, output_mask_path, model_iter): + def __init__(self, is_interactive, merger_session_filepath, predictor_func, predictor_input_shape, face_enhancer_func, xseg_256_extract_func, merger_config, frames, frames_root_path, output_path, output_mask_path, model_iter, subprocess_count=4): if len (frames) == 0: raise ValueError ("len (frames) == 0") @@ -161,7 +161,7 @@ class InteractiveMergerSubprocessor(Subprocessor): self.output_mask_path = output_mask_path self.model_iter = model_iter - self.prefetch_frame_count = self.process_count = multiprocessing.cpu_count() + self.prefetch_frame_count = self.process_count = subprocess_count session_data = None if self.is_interactive and self.merger_session_filepath.exists():