fix support for v1/v2

This commit is contained in:
iperov 2021-01-01 17:59:57 +04:00
parent 8ff34be5e4
commit 4f2efd7985

View File

@ -77,7 +77,14 @@ class nn():
io.log_info("Caching GPU kernels...")
import tensorflow
if tensorflow.VERSION[0] == '2':
tf_version = getattr(tensorflow,'VERSION', None)
if tf_version is None:
tf_version = tensorflow.version.GIT_VERSION
if tf_version[0] == 'v':
tf_version = tf_version[1:]
if tf_version[0] == '2':
tf = tensorflow.compat.v1
else:
tf = tensorflow
@ -87,7 +94,7 @@ class nn():
tf_logger = logging.getLogger('tensorflow')
tf_logger.setLevel(logging.ERROR)
if tensorflow.VERSION[0] == '2':
if tf_version[0] == '2':
tf.disable_v2_behavior()
nn.tf = tf