From 1f5d9513d862bce15c17f88580ea316ea1bc0545 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Fri, 11 Oct 2019 15:55:01 +0200 Subject: [PATCH] fix test --- transformers/tests/modeling_tf_xlnet_test.py | 3 ++- transformers/tests/modeling_xlnet_test.py | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/transformers/tests/modeling_tf_xlnet_test.py b/transformers/tests/modeling_tf_xlnet_test.py index 2c80c4fedb..12a8fbe36f 100644 --- a/transformers/tests/modeling_tf_xlnet_test.py +++ b/transformers/tests/modeling_tf_xlnet_test.py @@ -161,7 +161,8 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester): "outputs": outputs.numpy(), } - model.config.mem_len = 0 + config.mem_len = 0 + model = TFXLNetModel(config) no_mems_outputs = model(inputs) self.parent.assertEqual(len(no_mems_outputs), 1) diff --git a/transformers/tests/modeling_xlnet_test.py b/transformers/tests/modeling_xlnet_test.py index 293dffabf6..d97ea6a425 100644 --- a/transformers/tests/modeling_xlnet_test.py +++ b/transformers/tests/modeling_xlnet_test.py @@ -150,7 +150,9 @@ class XLNetModelTest(CommonTestCases.CommonModelTester): "outputs": outputs, } - model.config.mem_len = 0 + config.mem_len = 0 + model = XLNetModel(config) + model.eval() no_mems_outputs = model(input_ids_1) self.parent.assertEqual(len(no_mems_outputs), 1)