Use Python 3.9 syntax in tests (#37343)
Signed-off-by: cyy <cyyever@outlook.com>
This commit is contained in:
@@ -179,7 +179,7 @@ class FlaxBartModelTester:
|
||||
|
||||
outputs = model.decode(decoder_input_ids, encoder_outputs)
|
||||
|
||||
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
|
||||
diff = np.max(np.abs(outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5]))
|
||||
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
|
||||
|
||||
def check_use_cache_forward_with_attn_mask(self, model_class_name, config, inputs_dict):
|
||||
@@ -225,7 +225,7 @@ class FlaxBartModelTester:
|
||||
|
||||
outputs = model.decode(decoder_input_ids, encoder_outputs, decoder_attention_mask=decoder_attention_mask)
|
||||
|
||||
diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
|
||||
diff = np.max(np.abs(outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5]))
|
||||
self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user