From e4682de6358f9b9cefb73683588e588e4d9154f7 Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Mon, 15 Jul 2024 18:49:37 +0100 Subject: [PATCH] Masking: remove flakiness from test (#31939) --- tests/models/whisper/test_modeling_whisper.py | 3 --- tests/test_modeling_common.py | 3 --- 2 files changed, 6 deletions(-) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index dcb495d95a..5fc66f9a20 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -1571,9 +1571,6 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi out_last_tokens = logits[:, -1, :] # last tokens in each batch line out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens - # comparing greedily-chosen tokens: - assert torch.equal(out_last_tokens.max(axis=1).indices, out_shared_prefix_last_tokens.max(axis=1).indices) - # comparing softmax-normalized logits: normalized_0 = torch.nn.functional.softmax(out_last_tokens) normalized_1 = torch.nn.functional.softmax(out_shared_prefix_last_tokens) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 0ed3cee3c5..a73417e416 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4486,9 +4486,6 @@ class ModelTesterMixin: out_last_tokens = logits[:, -1, :] # last tokens in each batch line out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens - # comparing greedily-chosen tokens: - assert torch.equal(out_last_tokens.max(axis=1).indices, out_shared_prefix_last_tokens.max(axis=1).indices) - # comparing softmax-normalized logits: normalized_0 = F.softmax(out_last_tokens) normalized_1 = F.softmax(out_shared_prefix_last_tokens)