Generation tests: don't rely on main input name (#34228)
* don't rely on main input name * update
This commit is contained in:
committed by
GitHub
parent
816f442496
commit
ca541bd4f4
@@ -410,7 +410,6 @@ class GenerationTesterMixin:
|
|||||||
def test_greedy_generate(self):
|
def test_greedy_generate(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||||
main_input = inputs_dict[model_class.main_input_name]
|
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
output_generate = self._greedy_generate(model=model, inputs_dict=inputs_dict)
|
output_generate = self._greedy_generate(model=model, inputs_dict=inputs_dict)
|
||||||
@@ -418,7 +417,7 @@ class GenerationTesterMixin:
|
|||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||||
else:
|
else:
|
||||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + main_input.shape[-1])
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
|
||||||
|
|
||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
def test_greedy_generate_dict_outputs(self):
|
def test_greedy_generate_dict_outputs(self):
|
||||||
@@ -444,7 +443,9 @@ class GenerationTesterMixin:
|
|||||||
# Retrocompatibility check
|
# Retrocompatibility check
|
||||||
self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)
|
self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)
|
||||||
else:
|
else:
|
||||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + main_input.shape[-1])
|
self.assertTrue(
|
||||||
|
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
|
||||||
|
)
|
||||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
||||||
# Retrocompatibility check
|
# Retrocompatibility check
|
||||||
self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
|
self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
|
||||||
@@ -478,7 +479,9 @@ class GenerationTesterMixin:
|
|||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||||
else:
|
else:
|
||||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + main_input.shape[-1])
|
self.assertTrue(
|
||||||
|
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
|
||||||
|
)
|
||||||
|
|
||||||
self._check_outputs(output_generate, main_input, model.config, use_cache=True)
|
self._check_outputs(output_generate, main_input, model.config, use_cache=True)
|
||||||
|
|
||||||
@@ -486,7 +489,6 @@ class GenerationTesterMixin:
|
|||||||
def test_sample_generate(self):
|
def test_sample_generate(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||||
main_input = inputs_dict[model_class.main_input_name]
|
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
output_generate = self._sample_generate(model=model, inputs_dict=inputs_dict, num_return_sequences=1)
|
output_generate = self._sample_generate(model=model, inputs_dict=inputs_dict, num_return_sequences=1)
|
||||||
@@ -494,7 +496,7 @@ class GenerationTesterMixin:
|
|||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||||
else:
|
else:
|
||||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + main_input.shape[-1])
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
|
||||||
|
|
||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
def test_sample_generate_dict_output(self):
|
def test_sample_generate_dict_output(self):
|
||||||
@@ -521,7 +523,9 @@ class GenerationTesterMixin:
|
|||||||
# Retrocompatibility check
|
# Retrocompatibility check
|
||||||
self.assertIsInstance(output_generate, SampleEncoderDecoderOutput)
|
self.assertIsInstance(output_generate, SampleEncoderDecoderOutput)
|
||||||
else:
|
else:
|
||||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + main_input.shape[-1])
|
self.assertTrue(
|
||||||
|
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
|
||||||
|
)
|
||||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
||||||
# Retrocompatibility check
|
# Retrocompatibility check
|
||||||
self.assertIsInstance(output_generate, SampleDecoderOnlyOutput)
|
self.assertIsInstance(output_generate, SampleDecoderOnlyOutput)
|
||||||
@@ -532,7 +536,6 @@ class GenerationTesterMixin:
|
|||||||
def test_beam_search_generate(self):
|
def test_beam_search_generate(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||||
main_input = inputs_dict[model_class.main_input_name]
|
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
|
|
||||||
@@ -542,7 +545,7 @@ class GenerationTesterMixin:
|
|||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||||
else:
|
else:
|
||||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + main_input.shape[-1])
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
|
||||||
|
|
||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
def test_beam_search_generate_dict_output(self):
|
def test_beam_search_generate_dict_output(self):
|
||||||
@@ -569,7 +572,9 @@ class GenerationTesterMixin:
|
|||||||
# Retrocompatibility check
|
# Retrocompatibility check
|
||||||
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
|
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
|
||||||
else:
|
else:
|
||||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + main_input.shape[-1])
|
self.assertTrue(
|
||||||
|
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
|
||||||
|
)
|
||||||
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
|
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
|
||||||
# Retrocompatibility check
|
# Retrocompatibility check
|
||||||
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
|
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
|
||||||
@@ -609,7 +614,9 @@ class GenerationTesterMixin:
|
|||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||||
else:
|
else:
|
||||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + main_input.shape[-1])
|
self.assertTrue(
|
||||||
|
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
|
||||||
|
)
|
||||||
|
|
||||||
self._check_outputs(
|
self._check_outputs(
|
||||||
output_generate,
|
output_generate,
|
||||||
@@ -647,7 +654,6 @@ class GenerationTesterMixin:
|
|||||||
def test_beam_sample_generate(self):
|
def test_beam_sample_generate(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||||
main_input = inputs_dict[model_class.main_input_name]
|
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
beam_kwargs = self._get_beam_kwargs()
|
beam_kwargs = self._get_beam_kwargs()
|
||||||
@@ -660,7 +666,7 @@ class GenerationTesterMixin:
|
|||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||||
else:
|
else:
|
||||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + main_input.shape[-1])
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
|
||||||
|
|
||||||
# for VLMs inputs embeds won't match input ids unless images are encoded and merged with ids properly
|
# for VLMs inputs embeds won't match input ids unless images are encoded and merged with ids properly
|
||||||
# no quick fix available, since obtaining image embeddings step is very model-specific
|
# no quick fix available, since obtaining image embeddings step is very model-specific
|
||||||
@@ -712,7 +718,9 @@ class GenerationTesterMixin:
|
|||||||
# Retrocompatibility check
|
# Retrocompatibility check
|
||||||
self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput)
|
self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput)
|
||||||
else:
|
else:
|
||||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + main_input.shape[-1])
|
self.assertTrue(
|
||||||
|
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
|
||||||
|
)
|
||||||
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
|
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
|
||||||
# Retrocompatibility check
|
# Retrocompatibility check
|
||||||
self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput)
|
self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput)
|
||||||
@@ -746,7 +754,6 @@ class GenerationTesterMixin:
|
|||||||
def test_group_beam_search_generate(self):
|
def test_group_beam_search_generate(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||||
main_input = inputs_dict[model_class.main_input_name]
|
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
# check `generate()` and `group_beam_search()` are equal
|
# check `generate()` and `group_beam_search()` are equal
|
||||||
@@ -759,7 +766,7 @@ class GenerationTesterMixin:
|
|||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||||
else:
|
else:
|
||||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + main_input.shape[-1])
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
|
||||||
|
|
||||||
# check `group_beam_search` for higher than 1 `num_return_sequences`
|
# check `group_beam_search` for higher than 1 `num_return_sequences`
|
||||||
num_return_sequences = 2
|
num_return_sequences = 2
|
||||||
@@ -772,7 +779,7 @@ class GenerationTesterMixin:
|
|||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||||
else:
|
else:
|
||||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + main_input.shape[-1])
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
|
||||||
|
|
||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
def test_group_beam_search_generate_dict_output(self):
|
def test_group_beam_search_generate_dict_output(self):
|
||||||
@@ -799,7 +806,9 @@ class GenerationTesterMixin:
|
|||||||
# Retrocompatibility check
|
# Retrocompatibility check
|
||||||
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
|
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
|
||||||
else:
|
else:
|
||||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + main_input.shape[-1])
|
self.assertTrue(
|
||||||
|
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
|
||||||
|
)
|
||||||
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
|
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
|
||||||
# Retrocompatibility check
|
# Retrocompatibility check
|
||||||
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
|
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
|
||||||
@@ -814,7 +823,6 @@ class GenerationTesterMixin:
|
|||||||
def test_constrained_beam_search_generate(self):
|
def test_constrained_beam_search_generate(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||||
main_input = inputs_dict[model_class.main_input_name]
|
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
|
|
||||||
@@ -838,7 +846,7 @@ class GenerationTesterMixin:
|
|||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||||
else:
|
else:
|
||||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + main_input.shape[-1])
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
|
||||||
|
|
||||||
for generation_output in output_generate:
|
for generation_output in output_generate:
|
||||||
self._check_sequence_inside_sequence(force_tokens, generation_output)
|
self._check_sequence_inside_sequence(force_tokens, generation_output)
|
||||||
@@ -862,7 +870,7 @@ class GenerationTesterMixin:
|
|||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||||
else:
|
else:
|
||||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + main_input.shape[-1])
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
|
||||||
|
|
||||||
for generation_output in output_generate:
|
for generation_output in output_generate:
|
||||||
self._check_sequence_inside_sequence(force_tokens, generation_output)
|
self._check_sequence_inside_sequence(force_tokens, generation_output)
|
||||||
@@ -903,7 +911,9 @@ class GenerationTesterMixin:
|
|||||||
# Retrocompatibility check
|
# Retrocompatibility check
|
||||||
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
|
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
|
||||||
else:
|
else:
|
||||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + main_input.shape[-1])
|
self.assertTrue(
|
||||||
|
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
|
||||||
|
)
|
||||||
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
|
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
|
||||||
# Retrocompatibility check
|
# Retrocompatibility check
|
||||||
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
|
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
|
||||||
@@ -923,7 +933,6 @@ class GenerationTesterMixin:
|
|||||||
self.skipTest(reason="Won't fix: old model with different cache format")
|
self.skipTest(reason="Won't fix: old model with different cache format")
|
||||||
|
|
||||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||||
main_input = inputs_dict[model_class.main_input_name]
|
|
||||||
|
|
||||||
# NOTE: contrastive search only works with cache on at the moment.
|
# NOTE: contrastive search only works with cache on at the moment.
|
||||||
if not hasattr(config, "use_cache"):
|
if not hasattr(config, "use_cache"):
|
||||||
@@ -940,7 +949,7 @@ class GenerationTesterMixin:
|
|||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||||
else:
|
else:
|
||||||
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + main_input.shape[-1])
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1])
|
||||||
|
|
||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
def test_contrastive_generate_dict_outputs_use_cache(self):
|
def test_contrastive_generate_dict_outputs_use_cache(self):
|
||||||
@@ -975,7 +984,9 @@ class GenerationTesterMixin:
|
|||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||||
else:
|
else:
|
||||||
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + main_input.shape[-1])
|
self.assertTrue(
|
||||||
|
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
|
||||||
|
)
|
||||||
|
|
||||||
self._check_outputs(output_generate, main_input, model.config, use_cache=True)
|
self._check_outputs(output_generate, main_input, model.config, use_cache=True)
|
||||||
|
|
||||||
@@ -2035,8 +2046,14 @@ class GenerationTesterMixin:
|
|||||||
self.assertTrue("GenerationMixin" in str(model_class.__bases__))
|
self.assertTrue("GenerationMixin" in str(model_class.__bases__))
|
||||||
|
|
||||||
def _check_outputs(self, output, main_input, config, use_cache=False, num_return_sequences=1):
|
def _check_outputs(self, output, main_input, config, use_cache=False, num_return_sequences=1):
|
||||||
|
# we can be sure what is batch size from main input but seq length depends on model type and whether input is text/audio/image
|
||||||
|
# so we infer actual text seq length from model_tester, same was as it is done in `test_modeling_common.py` tests`
|
||||||
batch_size = main_input.shape[0]
|
batch_size = main_input.shape[0]
|
||||||
seq_length = main_input.shape[-1]
|
|
||||||
|
seq_length = getattr(self.model_tester, "seq_length", None)
|
||||||
|
seq_length = getattr(self.model_tester, "encoder_seq_length", seq_length)
|
||||||
|
seq_length = getattr(self.model_tester, "text_seq_length", seq_length)
|
||||||
|
|
||||||
config = config.text_config if hasattr(config, "text_config") else config
|
config = config.text_config if hasattr(config, "text_config") else config
|
||||||
num_sequences_in_output = batch_size * num_return_sequences
|
num_sequences_in_output = batch_size * num_return_sequences
|
||||||
|
|
||||||
|
|||||||
@@ -53,6 +53,7 @@ class ReformerModelTester:
|
|||||||
parent,
|
parent,
|
||||||
batch_size=13,
|
batch_size=13,
|
||||||
seq_length=32,
|
seq_length=32,
|
||||||
|
text_seq_length=None,
|
||||||
is_training=True,
|
is_training=True,
|
||||||
is_decoder=True,
|
is_decoder=True,
|
||||||
use_input_mask=True,
|
use_input_mask=True,
|
||||||
@@ -128,6 +129,7 @@ class ReformerModelTester:
|
|||||||
self.attn_layers = attn_layers
|
self.attn_layers = attn_layers
|
||||||
self.pad_token_id = pad_token_id
|
self.pad_token_id = pad_token_id
|
||||||
self.hash_seed = hash_seed
|
self.hash_seed = hash_seed
|
||||||
|
self.text_seq_length = text_seq_length or seq_length
|
||||||
|
|
||||||
attn_chunk_length = local_attn_chunk_length if local_attn_chunk_length is not None else lsh_attn_chunk_length
|
attn_chunk_length = local_attn_chunk_length if local_attn_chunk_length is not None else lsh_attn_chunk_length
|
||||||
num_chunks_after = local_num_chunks_after if local_num_chunks_after is not None else lsh_num_chunks_after
|
num_chunks_after = local_num_chunks_after if local_num_chunks_after is not None else lsh_num_chunks_after
|
||||||
@@ -608,7 +610,7 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, Mod
|
|||||||
test_sequence_classification_problem_types = True
|
test_sequence_classification_problem_types = True
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = ReformerModelTester(self)
|
self.model_tester = ReformerModelTester(self, text_seq_length=16)
|
||||||
self.config_tester = ConfigTester(self, config_class=ReformerConfig, hidden_size=37)
|
self.config_tester = ConfigTester(self, config_class=ReformerConfig, hidden_size=37)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
@@ -689,7 +691,7 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, Mod
|
|||||||
# decreasing the seq_length in tester causes errors for "training_tests", those need exactly max seq length
|
# decreasing the seq_length in tester causes errors for "training_tests", those need exactly max seq length
|
||||||
# NOTE: seq_length has to be multiple of 4, otherwise it fails for other tests
|
# NOTE: seq_length has to be multiple of 4, otherwise it fails for other tests
|
||||||
original_sequence_length = self.model_tester.seq_length
|
original_sequence_length = self.model_tester.seq_length
|
||||||
self.model_tester.seq_length = 16
|
self.model_tester.seq_length = self.model_tester.text_seq_length
|
||||||
test_inputs = super().prepare_config_and_inputs_for_generate(*args, **kwargs)
|
test_inputs = super().prepare_config_and_inputs_for_generate(*args, **kwargs)
|
||||||
self.model_tester.seq_length = original_sequence_length
|
self.model_tester.seq_length = original_sequence_length
|
||||||
return test_inputs
|
return test_inputs
|
||||||
|
|||||||
@@ -618,14 +618,6 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
|
|||||||
def test_generate_without_input_ids(self):
|
def test_generate_without_input_ids(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _check_outputs(self, output, main_input, config, use_cache=False, num_return_sequences=1):
|
|
||||||
# In this model, the index of `batch_size` and `sequence_length`` in `main_input` is different: they are the
|
|
||||||
# first two dimensions of the tensor.
|
|
||||||
main_input = main_input[:, :, 0]
|
|
||||||
super()._check_outputs(
|
|
||||||
output, main_input, config, use_cache=use_cache, num_return_sequences=num_return_sequences
|
|
||||||
)
|
|
||||||
|
|
||||||
def _create_and_check_torchscript(self, config, inputs_dict):
|
def _create_and_check_torchscript(self, config, inputs_dict):
|
||||||
if not self.test_torchscript:
|
if not self.test_torchscript:
|
||||||
self.skipTest(reason="test_torchscript is set to False")
|
self.skipTest(reason="test_torchscript is set to False")
|
||||||
|
|||||||
Reference in New Issue
Block a user