fix test
This commit is contained in:
@@ -161,7 +161,8 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):
|
|||||||
"outputs": outputs.numpy(),
|
"outputs": outputs.numpy(),
|
||||||
}
|
}
|
||||||
|
|
||||||
model.config.mem_len = 0
|
config.mem_len = 0
|
||||||
|
model = TFXLNetModel(config)
|
||||||
no_mems_outputs = model(inputs)
|
no_mems_outputs = model(inputs)
|
||||||
self.parent.assertEqual(len(no_mems_outputs), 1)
|
self.parent.assertEqual(len(no_mems_outputs), 1)
|
||||||
|
|
||||||
|
|||||||
@@ -150,7 +150,9 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
|
|||||||
"outputs": outputs,
|
"outputs": outputs,
|
||||||
}
|
}
|
||||||
|
|
||||||
model.config.mem_len = 0
|
config.mem_len = 0
|
||||||
|
model = XLNetModel(config)
|
||||||
|
model.eval()
|
||||||
no_mems_outputs = model(input_ids_1)
|
no_mems_outputs = model(input_ids_1)
|
||||||
self.parent.assertEqual(len(no_mems_outputs), 1)
|
self.parent.assertEqual(len(no_mems_outputs), 1)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user