adding option to desactivate past/memory outputs

This commit is contained in:
thomwolf
2019-10-11 15:47:08 +02:00
parent 2a4fef837a
commit 0f9fc4fbde
8 changed files with 93 additions and 55 deletions

View File

@@ -161,6 +161,10 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):
"outputs": outputs.numpy(),
}
model.config.mem_len = 0
no_mems_outputs = model(inputs)
self.parent.assertEqual(len(no_mems_outputs), 1)
self.parent.assertListEqual(
list(result["outputs"].shape),
[self.batch_size, self.seq_length, self.hidden_size])

View File

@@ -150,6 +150,10 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
"outputs": outputs,
}
model.config.mem_len = 0
no_mems_outputs = model(input_ids_1)
self.parent.assertEqual(len(no_mems_outputs), 1)
self.parent.assertListEqual(
list(result["outputs"].size()),
[self.batch_size, self.seq_length, self.hidden_size])