Higher tolerance for past testing in TF T5 (#3844)

This commit is contained in:
Patrick von Platen
2020-04-17 17:26:16 +02:00
committed by GitHub
parent d13eca11e2
commit 1d4a35b396

View File

@@ -177,7 +177,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
output_from_past_slice = output_from_past[:, 0, random_slice_idx]
# test that outputs are equal for slice
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6)
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)
def create_and_check_t5_decoder_model_attention_mask_past(
self, config, input_ids, decoder_input_ids, attention_mask
@@ -221,7 +221,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
output_from_past_slice = output_from_past[:, 0, random_slice_idx]
# test that outputs are equal for slice
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6)
tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-3)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()