adding option to desactivate past/memory outputs
This commit is contained in:
@@ -161,6 +161,10 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):
|
||||
"outputs": outputs.numpy(),
|
||||
}
|
||||
|
||||
model.config.mem_len = 0
|
||||
no_mems_outputs = model(inputs)
|
||||
self.parent.assertEqual(len(no_mems_outputs), 1)
|
||||
|
||||
self.parent.assertListEqual(
|
||||
list(result["outputs"].shape),
|
||||
[self.batch_size, self.seq_length, self.hidden_size])
|
||||
|
||||
@@ -150,6 +150,10 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
|
||||
"outputs": outputs,
|
||||
}
|
||||
|
||||
model.config.mem_len = 0
|
||||
no_mems_outputs = model(input_ids_1)
|
||||
self.parent.assertEqual(len(no_mems_outputs), 1)
|
||||
|
||||
self.parent.assertListEqual(
|
||||
list(result["outputs"].size()),
|
||||
[self.batch_size, self.seq_length, self.hidden_size])
|
||||
|
||||
Reference in New Issue
Block a user