Switch Transformers: remove overwritten beam sample test (#25458)
This commit is contained in:
@@ -37,7 +37,6 @@ if is_torch_available():
|
|||||||
SwitchTransformersModel,
|
SwitchTransformersModel,
|
||||||
SwitchTransformersTop1Router,
|
SwitchTransformersTop1Router,
|
||||||
)
|
)
|
||||||
from transformers.generation import BeamSampleDecoderOnlyOutput, BeamSampleEncoderDecoderOutput
|
|
||||||
from transformers.models.switch_transformers.modeling_switch_transformers import (
|
from transformers.models.switch_transformers.modeling_switch_transformers import (
|
||||||
SWITCH_TRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST,
|
SWITCH_TRANSFORMERS_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
load_balancing_loss_func,
|
load_balancing_loss_func,
|
||||||
@@ -613,101 +612,6 @@ class SwitchTransformersModelTest(ModelTesterMixin, GenerationTesterMixin, Pipel
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs)
|
self.model_tester.create_and_check_decoder_model_attention_mask_past(*config_and_inputs)
|
||||||
|
|
||||||
@slow
|
|
||||||
def test_beam_sample_generate_dict_output(self):
|
|
||||||
r"""
|
|
||||||
This test needs to be overridden with a larger model since it fails for very small models due to precision issues.
|
|
||||||
"""
|
|
||||||
for model_class in self.all_generative_model_classes:
|
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
|
||||||
|
|
||||||
# disable cache
|
|
||||||
config.use_cache = False
|
|
||||||
|
|
||||||
# It is important to set the eos_token_id to None to ensure that no sequences
|
|
||||||
# shorter than `max_length` can be generated which could lead to flaky circle ci
|
|
||||||
# failures if the top `num_return_sequences` beams are all shorter than the longest beam
|
|
||||||
config.eos_token_id = None
|
|
||||||
config.forced_eos_token_id = None
|
|
||||||
|
|
||||||
model = model_class.from_pretrained("google/switch-base-8").to(torch_device).eval()
|
|
||||||
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=2)
|
|
||||||
|
|
||||||
num_return_sequences = 2
|
|
||||||
if model.config.is_encoder_decoder:
|
|
||||||
max_length = 4
|
|
||||||
beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(
|
|
||||||
input_ids.shape[0] * num_return_sequences, max_length
|
|
||||||
)
|
|
||||||
beam_kwargs["num_return_sequences"] = num_return_sequences
|
|
||||||
|
|
||||||
output_beam_sample, output_generate = self._beam_sample_generate(
|
|
||||||
model=model,
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
max_length=max_length,
|
|
||||||
num_return_sequences=num_return_sequences,
|
|
||||||
beam_scorer=beam_scorer,
|
|
||||||
beam_kwargs=beam_kwargs,
|
|
||||||
logits_warper=logits_warper,
|
|
||||||
logits_warper_kwargs=logits_warper_kwargs,
|
|
||||||
output_scores=True,
|
|
||||||
output_hidden_states=True,
|
|
||||||
output_attentions=True,
|
|
||||||
return_dict_in_generate=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
if model.config.is_encoder_decoder:
|
|
||||||
self.assertIsInstance(output_beam_sample, BeamSampleEncoderDecoderOutput)
|
|
||||||
self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput)
|
|
||||||
else:
|
|
||||||
self.assertIsInstance(output_beam_sample, BeamSampleDecoderOnlyOutput)
|
|
||||||
self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput)
|
|
||||||
|
|
||||||
self.assertListEqual(output_generate.sequences.tolist(), output_beam_sample.sequences.tolist())
|
|
||||||
|
|
||||||
@slow
|
|
||||||
def test_beam_sample_generate(self):
|
|
||||||
r"""
|
|
||||||
This test needs to be overridden with a larger model since it fails for very small models due to precision issues.
|
|
||||||
"""
|
|
||||||
for model_class in self.all_generative_model_classes:
|
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
|
||||||
|
|
||||||
# It is important to set the eos_token_id to None to ensure that no sequences
|
|
||||||
# shorter than `max_length` can be generated which could lead to flaky circle ci
|
|
||||||
# failures if the top `num_return_sequences` beams are all shorter than the longest beam
|
|
||||||
config.eos_token_id = None
|
|
||||||
config.forced_eos_token_id = None
|
|
||||||
|
|
||||||
logits_warper_kwargs, logits_warper = self._get_warper_and_kwargs(num_beams=2)
|
|
||||||
|
|
||||||
model = model_class.from_pretrained("google/switch-base-8").to(torch_device).eval()
|
|
||||||
|
|
||||||
# check `generate()` and `beam_sample()` are equal
|
|
||||||
# change `num_return_sequences = 2` but not for `beam_scorer`
|
|
||||||
num_return_sequences = 2
|
|
||||||
if model.config.is_encoder_decoder:
|
|
||||||
max_length = 4
|
|
||||||
beam_kwargs, beam_scorer = self._get_beam_scorer_and_kwargs(
|
|
||||||
input_ids.shape[0] * num_return_sequences, max_length
|
|
||||||
)
|
|
||||||
beam_kwargs["num_return_sequences"] = num_return_sequences
|
|
||||||
|
|
||||||
output_generate, output_beam_sample = self._beam_sample_generate(
|
|
||||||
model=model,
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
max_length=max_length,
|
|
||||||
num_return_sequences=num_return_sequences,
|
|
||||||
beam_scorer=beam_scorer,
|
|
||||||
beam_kwargs=beam_kwargs,
|
|
||||||
logits_warper=logits_warper,
|
|
||||||
logits_warper_kwargs=logits_warper_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertListEqual(output_generate.tolist(), output_beam_sample.tolist())
|
|
||||||
|
|
||||||
def test_decoder_model_past_with_3d_attn_mask(self):
|
def test_decoder_model_past_with_3d_attn_mask(self):
|
||||||
(
|
(
|
||||||
config,
|
config,
|
||||||
|
|||||||
Reference in New Issue
Block a user