diff --git a/transformers/tests/modeling_tf_xlnet_test.py b/transformers/tests/modeling_tf_xlnet_test.py index 2c80c4fedb..12a8fbe36f 100644 --- a/transformers/tests/modeling_tf_xlnet_test.py +++ b/transformers/tests/modeling_tf_xlnet_test.py @@ -161,7 +161,8 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester): "outputs": outputs.numpy(), } - model.config.mem_len = 0 + config.mem_len = 0 + model = TFXLNetModel(config) no_mems_outputs = model(inputs) self.parent.assertEqual(len(no_mems_outputs), 1) diff --git a/transformers/tests/modeling_xlnet_test.py b/transformers/tests/modeling_xlnet_test.py index 293dffabf6..d97ea6a425 100644 --- a/transformers/tests/modeling_xlnet_test.py +++ b/transformers/tests/modeling_xlnet_test.py @@ -150,7 +150,9 @@ class XLNetModelTest(CommonTestCases.CommonModelTester): "outputs": outputs, } - model.config.mem_len = 0 + config.mem_len = 0 + model = XLNetModel(config) + model.eval() no_mems_outputs = model(input_ids_1) self.parent.assertEqual(len(no_mems_outputs), 1)