mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2024-03-22 13:10:55 +08:00
fix for amd
This commit is contained in:
parent
0448a461f5
commit
75eeef0a96
|
@ -7,7 +7,7 @@ import numpy as np
|
|||
from nnlib import nnlib
|
||||
|
||||
class S3FDExtractor(object):
|
||||
def __init__(self):
|
||||
def __init__(self, do_dummy_predict=False):
|
||||
exec( nnlib.import_all(), locals(), globals() )
|
||||
|
||||
model_path = Path(__file__).parent / "S3FD.h5"
|
||||
|
@ -16,7 +16,8 @@ class S3FDExtractor(object):
|
|||
|
||||
self.model = nnlib.keras.models.load_model ( str(model_path) )
|
||||
|
||||
self.extract ( np.zeros( (1080,1920,3), dtype=np.uint8) )
|
||||
if do_dummy_predict:
|
||||
self.extract ( np.zeros( (640,640,3), dtype=np.uint8) )
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
|
|
@ -76,7 +76,7 @@ class ExtractSubprocessor(Subprocessor):
|
|||
self.e = facelib.DLIBExtractor(nnlib.dlib)
|
||||
elif self.type == 'rects-s3fd':
|
||||
nnlib.import_all (device_config)
|
||||
self.e = facelib.S3FDExtractor()
|
||||
self.e = facelib.S3FDExtractor(do_dummy_predict=True)
|
||||
else:
|
||||
raise ValueError ("Wrong type.")
|
||||
|
||||
|
@ -88,7 +88,7 @@ class ExtractSubprocessor(Subprocessor):
|
|||
self.e = facelib.FANExtractor()
|
||||
self.e.__enter__()
|
||||
if self.device_vram >= 2:
|
||||
self.second_pass_e = facelib.S3FDExtractor()
|
||||
self.second_pass_e = facelib.S3FDExtractor(do_dummy_predict=False)
|
||||
self.second_pass_e.__enter__()
|
||||
else:
|
||||
self.second_pass_e = None
|
||||
|
|
Loading…
Reference in New Issue
Block a user