fixing optimization

This commit is contained in:
thomwolf
2018-11-03 17:38:15 +01:00
parent 852e4b3c00
commit 088ad45888
4 changed files with 85 additions and 49 deletions

View File

@@ -38,10 +38,16 @@ class OptimizationTest(tf.test.TestCase):
init_op = tf.group(tf.global_variables_initializer(),
tf.local_variables_initializer())
sess.run(init_op)
for _ in range(100):
np_w = sess.run(w)
np_loss = sess.run(loss)
np_grad = sess.run(grads)[0]
for i in range(100):
print(i)
sess.run(train_op)
w_np = sess.run(w)
self.assertAllClose(w_np.flat, [0.4, 0.2, -0.5], rtol=1e-2, atol=1e-2)
np_w = sess.run(w)
np_loss = sess.run(loss)
np_grad = sess.run(grads)[0]
self.assertAllClose(np_w.flat, [0.4, 0.2, -0.5], rtol=1e-2, atol=1e-2)
if __name__ == "__main__":