fix error in model saving

This commit is contained in:
iperov 2021-08-19 23:18:04 +04:00
parent 26c83f6e35
commit 56e70edc46

View File

@ -46,7 +46,9 @@ class Saveable():
raise Exception("name must be defined.")
name = self.name
for w, w_val in zip(weights, nn.tf_sess.run (weights)):
for w in weights:
w_val = nn.tf_sess.run (w).copy()
w_name_split = w.name.split('/', 1)
if name != w_name_split[0]:
raise Exception("weight first name != Saveable.name")
@ -97,10 +99,10 @@ class Saveable():
nn.batch_set_value(tuples)
except:
return False
return True
def init_weights(self):
nn.init_weights(self.get_weights())
nn.Saveable = Saveable