This commit is contained in:
thomwolf
2019-10-11 15:55:01 +02:00
parent 0f9fc4fbde
commit 1f5d9513d8
2 changed files with 5 additions and 2 deletions

View File

@@ -161,7 +161,8 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):
"outputs": outputs.numpy(), "outputs": outputs.numpy(),
} }
model.config.mem_len = 0 config.mem_len = 0
model = TFXLNetModel(config)
no_mems_outputs = model(inputs) no_mems_outputs = model(inputs)
self.parent.assertEqual(len(no_mems_outputs), 1) self.parent.assertEqual(len(no_mems_outputs), 1)

View File

@@ -150,7 +150,9 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
"outputs": outputs, "outputs": outputs,
} }
model.config.mem_len = 0 config.mem_len = 0
model = XLNetModel(config)
model.eval()
no_mems_outputs = model(input_ids_1) no_mems_outputs = model(input_ids_1)
self.parent.assertEqual(len(no_mems_outputs), 1) self.parent.assertEqual(len(no_mems_outputs), 1)