From 4692d2619433f1eb064f3da4f3573f060a115eac Mon Sep 17 00:00:00 2001 From: Joao Gante Date: Fri, 11 Aug 2023 13:16:01 +0100 Subject: [PATCH] Switch Transformers: remove overwritten beam sample test (#25458) --- .../test_modeling_switch_transformers.py | 96 ------------------- 1 file changed, 96 deletions(-) diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py index ae9966a1db..54e17b91b7 100644 --- a/tests/models/switch_transformers/test_modeling_switch_transformers.py +++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py @@ -37,7 +37,6 @@ if is_torch_available(): SwitchTransformersModel, SwitchTransformersTop1Router, ) - from transformers.generation import BeamSampleDecoderOnlyOutput, BeamSampleEncoderDecoderOutput from transformers.models.switch_transformers.modeling_switch_transformers import ( SWITCH_TRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST, load_balancing_loss_func, @@ -613,101 +612,6 @@ class SwitchTransformersModelTest(ModelTesterMixin, GenerationTesterMixin, Pipel config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs) - @slow - def test_beam_sample_generate_dict_output(self): - r""" - This test needs to be overriden with a larger model since it fails for very small models due to precision issues. - """ - for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() - - # disable cache - config.use_cache = False - - # It is important set set the eos_token_id to None to ensure that no sequences - # shorter than `max_length` can be generated which could lead to flaky circle ci - # failures if the top `num_return_sequences` beams are all shorter than the longest beam - config.eos_token_id = None - config.forced_eos_token_id = None - - model = model_class.from_pretrained("google/switch-base-8").to(torch_device).eval() - logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=2) - - num_return_sequences = 2 - if model.config.is_encoder_decoder: - max_length = 4 - beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs( - input_ids.shape[0] * num_return_sequences, max_length - ) - beam_kwargs["num_return_sequences"] = num_return_sequences - - output_beam_sample, output_generate = self._beam_sample_generate( - model=model, - input_ids=input_ids, - attention_mask=attention_mask, - max_length=max_length, - num_return_sequences=num_return_sequences, - beam_scorer=beam_scorer, - beam_kwargs=beam_kwargs, - logits_warper=logits_warper, - logits_warper_kwargs=logits_warper_kwargs, - output_scores=True, - output_hidden_states=True, - output_attentions=True, - return_dict_in_generate=True, - ) - - if model.config.is_encoder_decoder: - self.assertIsInstance(output_beam_sample, BeamSampleEncoderDecoderOutput) - self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput) - else: - self.assertIsInstance(output_beam_sample, BeamSampleDecoderOnlyOutput) - self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput) - - self.assertListEqual(output_generate.sequences.tolist(), output_beam_sample.sequences.tolist()) - - @slow - def test_beam_sample_generate(self): - r""" - This test needs to be overriden with a larger model since it fails for very small models due to precision issues. - """ - for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, max_length = self._get_input_ids_and_config() - - # It is important set set the eos_token_id to None to ensure that no sequences - # shorter than `max_length` can be generated which could lead to flaky circle ci - # failures if the top `num_return_sequences` beams are all shorter than the longest beam - config.eos_token_id = None - config.forced_eos_token_id = None - - logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=2) - - model = model_class.from_pretrained("google/switch-base-8").to(torch_device).eval() - - # check `generate()` and `beam_search()` are equal - # change `num_return_sequences = 2` but not for `beam_scorer` - num_return_sequences = 2 - if model.config.is_encoder_decoder: - max_length = 4 - beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs( - input_ids.shape[0] * num_return_sequences, max_length - ) - beam_kwargs["num_return_sequences"] = num_return_sequences - - output_generate, output_beam_sample = self._beam_sample_generate( - model=model, - input_ids=input_ids, - attention_mask=attention_mask, - max_length=max_length, - num_return_sequences=num_return_sequences, - beam_scorer=beam_scorer, - beam_kwargs=beam_kwargs, - logits_warper=logits_warper, - logits_warper_kwargs=logits_warper_kwargs, - ) - - self.assertListEqual(output_generate.tolist(), output_beam_sample.tolist()) - def test_decoder_model_past_with_3d_attn_mask(self): ( config,