Do not remove half seq length in generation tests (#30016)
* remove seq length from generation tests * style and quality * [test_all] & PR suggestion Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update tests/generation/test_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * [test all] remove unused variables --------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
b4fd49b6c5
commit
b1cd48740e
@@ -82,43 +82,35 @@ class GenerationTesterMixin:
|
|||||||
model_tester = None
|
model_tester = None
|
||||||
all_generative_model_classes = ()
|
all_generative_model_classes = ()
|
||||||
input_name = "input_ids"
|
input_name = "input_ids"
|
||||||
|
max_new_tokens = 3
|
||||||
|
|
||||||
def _get_input_ids_and_config(self, batch_size=2):
|
def _get_input_ids_and_config(self, batch_size=2):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
input_ids = inputs_dict[self.input_name]
|
input_ids = inputs_dict[self.input_name]
|
||||||
|
|
||||||
# cut to half length & take max batch_size 3
|
input_ids = input_ids[:batch_size]
|
||||||
sequence_length = input_ids.shape[-1] // 2
|
|
||||||
input_ids = input_ids[:batch_size, :sequence_length]
|
|
||||||
|
|
||||||
# generate max 3 tokens
|
|
||||||
if config.is_encoder_decoder:
|
|
||||||
max_length = 4
|
|
||||||
else:
|
|
||||||
max_length = input_ids.shape[-1] + 3
|
|
||||||
if config.eos_token_id is not None and config.pad_token_id is None:
|
if config.eos_token_id is not None and config.pad_token_id is None:
|
||||||
# hack to allow generate for models such as GPT2 as is done in `generate()`
|
# hack to allow generate for models such as GPT2 as is done in `generate()`
|
||||||
if isinstance(config.eos_token_id, int):
|
if isinstance(config.eos_token_id, int):
|
||||||
config.eos_token_id = [config.eos_token_id]
|
config.eos_token_id = [config.eos_token_id]
|
||||||
config.pad_token_id = config.eos_token_id[0]
|
config.pad_token_id = config.eos_token_id[0]
|
||||||
attention_mask = torch.ones_like(input_ids, dtype=torch.long)[:batch_size, :sequence_length]
|
attention_mask = torch.ones_like(input_ids, dtype=torch.long)
|
||||||
|
|
||||||
# It is important set set the eos_token_id to None to ensure that no sequences
|
# It is important set set the eos_token_id to None to ensure that no sequences
|
||||||
# shorter than `max_length` can be generated
|
# shorter than `max_length` can be generated
|
||||||
config.eos_token_id = None
|
config.eos_token_id = None
|
||||||
config.forced_eos_token_id = None
|
config.forced_eos_token_id = None
|
||||||
|
|
||||||
return config, input_ids, attention_mask, max_length
|
return config, input_ids, attention_mask
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_logits_processor_and_warper_kwargs(
|
def _get_logits_processor_and_warper_kwargs(
|
||||||
input_length,
|
input_length,
|
||||||
forced_bos_token_id=None,
|
forced_bos_token_id=None,
|
||||||
forced_eos_token_id=None,
|
forced_eos_token_id=None,
|
||||||
max_length=None,
|
|
||||||
):
|
):
|
||||||
process_kwargs = {
|
process_kwargs = {
|
||||||
"min_length": input_length + 1 if max_length is None else max_length - 1,
|
|
||||||
"bad_words_ids": [[1, 0]],
|
"bad_words_ids": [[1, 0]],
|
||||||
"repetition_penalty": 1.2,
|
"repetition_penalty": 1.2,
|
||||||
"remove_invalid_values": True,
|
"remove_invalid_values": True,
|
||||||
@@ -185,7 +177,6 @@ class GenerationTesterMixin:
|
|||||||
model,
|
model,
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
max_length,
|
|
||||||
output_scores=False,
|
output_scores=False,
|
||||||
output_logits=False,
|
output_logits=False,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
@@ -196,7 +187,6 @@ class GenerationTesterMixin:
|
|||||||
input_ids.shape[-1],
|
input_ids.shape[-1],
|
||||||
forced_bos_token_id=model.config.forced_bos_token_id,
|
forced_bos_token_id=model.config.forced_bos_token_id,
|
||||||
forced_eos_token_id=model.config.forced_eos_token_id,
|
forced_eos_token_id=model.config.forced_eos_token_id,
|
||||||
max_length=max_length,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||||
@@ -204,7 +194,7 @@ class GenerationTesterMixin:
|
|||||||
input_ids,
|
input_ids,
|
||||||
do_sample=False,
|
do_sample=False,
|
||||||
num_beams=1,
|
num_beams=1,
|
||||||
max_length=max_length,
|
max_new_tokens=self.max_new_tokens,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
output_scores=output_scores,
|
output_scores=output_scores,
|
||||||
@@ -221,7 +211,6 @@ class GenerationTesterMixin:
|
|||||||
model,
|
model,
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
max_length,
|
|
||||||
num_return_sequences,
|
num_return_sequences,
|
||||||
logits_warper_kwargs,
|
logits_warper_kwargs,
|
||||||
process_kwargs,
|
process_kwargs,
|
||||||
@@ -237,7 +226,7 @@ class GenerationTesterMixin:
|
|||||||
input_ids,
|
input_ids,
|
||||||
do_sample=True,
|
do_sample=True,
|
||||||
num_beams=1,
|
num_beams=1,
|
||||||
max_length=max_length,
|
max_new_tokens=self.max_new_tokens,
|
||||||
num_return_sequences=num_return_sequences,
|
num_return_sequences=num_return_sequences,
|
||||||
output_scores=output_scores,
|
output_scores=output_scores,
|
||||||
output_logits=output_logits,
|
output_logits=output_logits,
|
||||||
@@ -256,7 +245,6 @@ class GenerationTesterMixin:
|
|||||||
model,
|
model,
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
max_length,
|
|
||||||
beam_kwargs,
|
beam_kwargs,
|
||||||
logits_process_kwargs,
|
logits_process_kwargs,
|
||||||
output_scores=False,
|
output_scores=False,
|
||||||
@@ -269,7 +257,7 @@ class GenerationTesterMixin:
|
|||||||
output_generate = model.generate(
|
output_generate = model.generate(
|
||||||
input_ids,
|
input_ids,
|
||||||
do_sample=False,
|
do_sample=False,
|
||||||
max_length=max_length,
|
max_new_tokens=self.max_new_tokens,
|
||||||
output_scores=output_scores,
|
output_scores=output_scores,
|
||||||
output_logits=output_logits,
|
output_logits=output_logits,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
@@ -287,7 +275,6 @@ class GenerationTesterMixin:
|
|||||||
model,
|
model,
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
max_length,
|
|
||||||
beam_kwargs,
|
beam_kwargs,
|
||||||
logits_warper_kwargs,
|
logits_warper_kwargs,
|
||||||
output_scores=False,
|
output_scores=False,
|
||||||
@@ -301,7 +288,7 @@ class GenerationTesterMixin:
|
|||||||
output_generate = model.generate(
|
output_generate = model.generate(
|
||||||
input_ids,
|
input_ids,
|
||||||
do_sample=True,
|
do_sample=True,
|
||||||
max_length=max_length,
|
max_new_tokens=self.max_new_tokens,
|
||||||
output_scores=output_scores,
|
output_scores=output_scores,
|
||||||
output_logits=output_logits,
|
output_logits=output_logits,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
@@ -319,7 +306,6 @@ class GenerationTesterMixin:
|
|||||||
model,
|
model,
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
max_length,
|
|
||||||
beam_kwargs,
|
beam_kwargs,
|
||||||
logits_process_kwargs,
|
logits_process_kwargs,
|
||||||
output_scores=False,
|
output_scores=False,
|
||||||
@@ -332,7 +318,7 @@ class GenerationTesterMixin:
|
|||||||
output_generate = model.generate(
|
output_generate = model.generate(
|
||||||
input_ids,
|
input_ids,
|
||||||
do_sample=False,
|
do_sample=False,
|
||||||
max_length=max_length,
|
max_new_tokens=self.max_new_tokens,
|
||||||
output_scores=output_scores,
|
output_scores=output_scores,
|
||||||
output_logits=output_logits,
|
output_logits=output_logits,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
@@ -350,7 +336,6 @@ class GenerationTesterMixin:
|
|||||||
model,
|
model,
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
max_length,
|
|
||||||
constraints,
|
constraints,
|
||||||
beam_kwargs,
|
beam_kwargs,
|
||||||
logits_process_kwargs,
|
logits_process_kwargs,
|
||||||
@@ -364,7 +349,7 @@ class GenerationTesterMixin:
|
|||||||
output_generate = model.generate(
|
output_generate = model.generate(
|
||||||
input_ids,
|
input_ids,
|
||||||
do_sample=False,
|
do_sample=False,
|
||||||
max_length=max_length,
|
max_new_tokens=self.max_new_tokens,
|
||||||
output_scores=output_scores,
|
output_scores=output_scores,
|
||||||
output_logits=output_logits,
|
output_logits=output_logits,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
@@ -383,7 +368,6 @@ class GenerationTesterMixin:
|
|||||||
model,
|
model,
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
max_length,
|
|
||||||
output_scores=False,
|
output_scores=False,
|
||||||
output_logits=False,
|
output_logits=False,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
@@ -399,7 +383,6 @@ class GenerationTesterMixin:
|
|||||||
input_ids.shape[-1],
|
input_ids.shape[-1],
|
||||||
forced_bos_token_id=model.config.forced_bos_token_id,
|
forced_bos_token_id=model.config.forced_bos_token_id,
|
||||||
forced_eos_token_id=model.config.forced_eos_token_id,
|
forced_eos_token_id=model.config.forced_eos_token_id,
|
||||||
max_length=max_length,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||||
@@ -407,7 +390,7 @@ class GenerationTesterMixin:
|
|||||||
input_ids,
|
input_ids,
|
||||||
do_sample=False,
|
do_sample=False,
|
||||||
num_beams=1,
|
num_beams=1,
|
||||||
max_length=max_length,
|
max_new_tokens=self.max_new_tokens,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
output_scores=output_scores,
|
output_scores=output_scores,
|
||||||
@@ -422,18 +405,19 @@ class GenerationTesterMixin:
|
|||||||
|
|
||||||
def test_greedy_generate(self):
|
def test_greedy_generate(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
output_generate = self._greedy_generate(
|
output_generate = self._greedy_generate(model=model, input_ids=input_ids, attention_mask=attention_mask)
|
||||||
model=model, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertTrue(output_generate.shape[-1] == max_length)
|
if model.config.is_encoder_decoder:
|
||||||
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||||
|
else:
|
||||||
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||||
|
|
||||||
def test_greedy_generate_dict_outputs(self):
|
def test_greedy_generate_dict_outputs(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
|
|
||||||
config.use_cache = False
|
config.use_cache = False
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
@@ -441,7 +425,6 @@ class GenerationTesterMixin:
|
|||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
max_length=max_length,
|
|
||||||
output_scores=True,
|
output_scores=True,
|
||||||
output_logits=True,
|
output_logits=True,
|
||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
@@ -450,20 +433,21 @@ class GenerationTesterMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
|
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||||
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
|
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
|
||||||
# Retrocompatibility check
|
# Retrocompatibility check
|
||||||
self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)
|
self.assertIsInstance(output_generate, GreedySearchEncoderDecoderOutput)
|
||||||
else:
|
else:
|
||||||
|
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
||||||
# Retrocompatibility check
|
# Retrocompatibility check
|
||||||
self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
|
self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
|
||||||
|
|
||||||
self.assertTrue(output_generate.sequences.shape[-1] == max_length)
|
|
||||||
self._check_outputs(output_generate, input_ids, model.config)
|
self._check_outputs(output_generate, input_ids, model.config)
|
||||||
|
|
||||||
def test_greedy_generate_dict_outputs_use_cache(self):
|
def test_greedy_generate_dict_outputs_use_cache(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
|
|
||||||
if not hasattr(config, "use_cache"):
|
if not hasattr(config, "use_cache"):
|
||||||
self.skipTest("This model doesn't support caching")
|
self.skipTest("This model doesn't support caching")
|
||||||
@@ -475,7 +459,6 @@ class GenerationTesterMixin:
|
|||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
max_length=max_length,
|
|
||||||
output_scores=True,
|
output_scores=True,
|
||||||
output_logits=True,
|
output_logits=True,
|
||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
@@ -483,57 +466,54 @@ class GenerationTesterMixin:
|
|||||||
return_dict_in_generate=True,
|
return_dict_in_generate=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertTrue(output_generate.sequences.shape[-1] == max_length)
|
if model.config.is_encoder_decoder:
|
||||||
|
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||||
|
else:
|
||||||
|
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||||
self._check_outputs(output_generate, input_ids, model.config, use_cache=True)
|
self._check_outputs(output_generate, input_ids, model.config, use_cache=True)
|
||||||
|
|
||||||
def test_sample_generate(self):
|
def test_sample_generate(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
if model.config.is_encoder_decoder:
|
|
||||||
max_length = 4
|
|
||||||
|
|
||||||
process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(
|
process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(
|
||||||
input_ids.shape[-1],
|
input_ids.shape[-1],
|
||||||
forced_bos_token_id=model.config.forced_bos_token_id,
|
forced_bos_token_id=model.config.forced_bos_token_id,
|
||||||
forced_eos_token_id=model.config.forced_eos_token_id,
|
forced_eos_token_id=model.config.forced_eos_token_id,
|
||||||
max_length=max_length,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
output_generate = self._sample_generate(
|
output_generate = self._sample_generate(
|
||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
max_length=max_length,
|
|
||||||
num_return_sequences=1,
|
num_return_sequences=1,
|
||||||
logits_warper_kwargs=logits_warper_kwargs,
|
logits_warper_kwargs=logits_warper_kwargs,
|
||||||
process_kwargs=process_kwargs,
|
process_kwargs=process_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertTrue(output_generate.shape[-1] == max_length)
|
if model.config.is_encoder_decoder:
|
||||||
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||||
|
else:
|
||||||
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||||
|
|
||||||
def test_sample_generate_dict_output(self):
|
def test_sample_generate_dict_output(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
|
|
||||||
config.use_cache = False
|
config.use_cache = False
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
if model.config.is_encoder_decoder:
|
|
||||||
max_length = 4
|
|
||||||
|
|
||||||
process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(
|
process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(
|
||||||
input_ids.shape[-1],
|
input_ids.shape[-1],
|
||||||
forced_bos_token_id=model.config.forced_bos_token_id,
|
forced_bos_token_id=model.config.forced_bos_token_id,
|
||||||
forced_eos_token_id=model.config.forced_eos_token_id,
|
forced_eos_token_id=model.config.forced_eos_token_id,
|
||||||
max_length=max_length,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
output_generate = self._sample_generate(
|
output_generate = self._sample_generate(
|
||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
max_length=max_length,
|
|
||||||
num_return_sequences=2,
|
num_return_sequences=2,
|
||||||
logits_warper_kwargs=logits_warper_kwargs,
|
logits_warper_kwargs=logits_warper_kwargs,
|
||||||
process_kwargs=process_kwargs,
|
process_kwargs=process_kwargs,
|
||||||
@@ -545,30 +525,28 @@ class GenerationTesterMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
|
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||||
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
|
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
|
||||||
# Retrocompatibility check
|
# Retrocompatibility check
|
||||||
self.assertIsInstance(output_generate, SampleEncoderDecoderOutput)
|
self.assertIsInstance(output_generate, SampleEncoderDecoderOutput)
|
||||||
else:
|
else:
|
||||||
|
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
||||||
# Retrocompatibility check
|
# Retrocompatibility check
|
||||||
self.assertIsInstance(output_generate, SampleDecoderOnlyOutput)
|
self.assertIsInstance(output_generate, SampleDecoderOnlyOutput)
|
||||||
|
|
||||||
self.assertTrue(output_generate.sequences.shape[-1] == max_length)
|
|
||||||
self._check_outputs(output_generate, input_ids, model.config, num_return_sequences=2)
|
self._check_outputs(output_generate, input_ids, model.config, num_return_sequences=2)
|
||||||
|
|
||||||
def test_beam_search_generate(self):
|
def test_beam_search_generate(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
if model.config.is_encoder_decoder:
|
|
||||||
max_length = 4
|
|
||||||
|
|
||||||
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
||||||
input_ids.shape[-1],
|
input_ids.shape[-1],
|
||||||
config.forced_bos_token_id,
|
config.forced_bos_token_id,
|
||||||
config.forced_eos_token_id,
|
config.forced_eos_token_id,
|
||||||
max_length,
|
|
||||||
)
|
)
|
||||||
beam_kwargs = self._get_beam_kwargs()
|
beam_kwargs = self._get_beam_kwargs()
|
||||||
|
|
||||||
@@ -576,36 +554,33 @@ class GenerationTesterMixin:
|
|||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
max_length=max_length,
|
|
||||||
beam_kwargs=beam_kwargs,
|
beam_kwargs=beam_kwargs,
|
||||||
logits_process_kwargs=logits_process_kwargs,
|
logits_process_kwargs=logits_process_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertTrue(output_generate.shape[-1] == max_length)
|
if model.config.is_encoder_decoder:
|
||||||
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||||
|
else:
|
||||||
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||||
|
|
||||||
def test_beam_search_generate_dict_output(self):
|
def test_beam_search_generate_dict_output(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
|
|
||||||
# disable cache
|
# disable cache
|
||||||
config.use_cache = False
|
config.use_cache = False
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
if model.config.is_encoder_decoder:
|
|
||||||
max_length = 4
|
|
||||||
|
|
||||||
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
||||||
input_ids.shape[-1],
|
input_ids.shape[-1],
|
||||||
config.forced_bos_token_id,
|
config.forced_bos_token_id,
|
||||||
config.forced_eos_token_id,
|
config.forced_eos_token_id,
|
||||||
max_length,
|
|
||||||
)
|
)
|
||||||
beam_kwargs = self._get_beam_kwargs()
|
beam_kwargs = self._get_beam_kwargs()
|
||||||
output_generate = self._beam_search_generate(
|
output_generate = self._beam_search_generate(
|
||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
max_length=max_length,
|
|
||||||
beam_kwargs=beam_kwargs,
|
beam_kwargs=beam_kwargs,
|
||||||
logits_process_kwargs=logits_process_kwargs,
|
logits_process_kwargs=logits_process_kwargs,
|
||||||
output_scores=True,
|
output_scores=True,
|
||||||
@@ -615,15 +590,16 @@ class GenerationTesterMixin:
|
|||||||
return_dict_in_generate=True,
|
return_dict_in_generate=True,
|
||||||
)
|
)
|
||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
|
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||||
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
|
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
|
||||||
# Retrocompatibility check
|
# Retrocompatibility check
|
||||||
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
|
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
|
||||||
else:
|
else:
|
||||||
|
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||||
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
|
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
|
||||||
# Retrocompatibility check
|
# Retrocompatibility check
|
||||||
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
|
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
|
||||||
|
|
||||||
self.assertTrue(output_generate.sequences.shape[-1] == max_length)
|
|
||||||
self._check_outputs(
|
self._check_outputs(
|
||||||
output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"]
|
output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"]
|
||||||
)
|
)
|
||||||
@@ -631,20 +607,16 @@ class GenerationTesterMixin:
|
|||||||
def test_beam_search_generate_dict_outputs_use_cache(self):
|
def test_beam_search_generate_dict_outputs_use_cache(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
# enable cache
|
# enable cache
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
|
|
||||||
if not hasattr(config, "use_cache"):
|
if not hasattr(config, "use_cache"):
|
||||||
self.skipTest("This model doesn't support caching")
|
self.skipTest("This model doesn't support caching")
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
if model.config.is_encoder_decoder:
|
|
||||||
max_length = 4
|
|
||||||
|
|
||||||
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
||||||
input_ids.shape[-1],
|
input_ids.shape[-1],
|
||||||
config.forced_bos_token_id,
|
config.forced_bos_token_id,
|
||||||
config.forced_eos_token_id,
|
config.forced_eos_token_id,
|
||||||
max_length,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
beam_kwargs = self._get_beam_kwargs()
|
beam_kwargs = self._get_beam_kwargs()
|
||||||
@@ -656,7 +628,6 @@ class GenerationTesterMixin:
|
|||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
max_length=max_length,
|
|
||||||
beam_kwargs=beam_kwargs,
|
beam_kwargs=beam_kwargs,
|
||||||
logits_process_kwargs=logits_process_kwargs,
|
logits_process_kwargs=logits_process_kwargs,
|
||||||
output_scores=True,
|
output_scores=True,
|
||||||
@@ -666,7 +637,10 @@ class GenerationTesterMixin:
|
|||||||
return_dict_in_generate=True,
|
return_dict_in_generate=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertTrue(output_generate.sequences.shape[-1] == max_length)
|
if model.config.is_encoder_decoder:
|
||||||
|
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||||
|
else:
|
||||||
|
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||||
self._check_outputs(
|
self._check_outputs(
|
||||||
output_generate, input_ids, model.config, use_cache=True, num_return_sequences=beam_kwargs["num_beams"]
|
output_generate, input_ids, model.config, use_cache=True, num_return_sequences=beam_kwargs["num_beams"]
|
||||||
)
|
)
|
||||||
@@ -681,7 +655,7 @@ class GenerationTesterMixin:
|
|||||||
if model_class._no_split_modules is None:
|
if model_class._no_split_modules is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
|
|
||||||
model = model_class(config).eval()
|
model = model_class(config).eval()
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
@@ -691,32 +665,32 @@ class GenerationTesterMixin:
|
|||||||
new_model.generate(
|
new_model.generate(
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
max_length=max_length,
|
max_new_tokens=self.max_new_tokens,
|
||||||
num_beams=2,
|
num_beams=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_beam_sample_generate(self):
|
def test_beam_sample_generate(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
|
|
||||||
_, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(input_ids.shape[-1])
|
_, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(input_ids.shape[-1])
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
|
|
||||||
if model.config.is_encoder_decoder:
|
|
||||||
max_length = 4
|
|
||||||
beam_kwargs = self._get_beam_kwargs()
|
beam_kwargs = self._get_beam_kwargs()
|
||||||
|
|
||||||
output_generate = self._beam_sample_generate(
|
output_generate = self._beam_sample_generate(
|
||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
max_length=max_length,
|
|
||||||
beam_kwargs=beam_kwargs,
|
beam_kwargs=beam_kwargs,
|
||||||
logits_warper_kwargs=logits_warper_kwargs,
|
logits_warper_kwargs=logits_warper_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertTrue(output_generate.shape[-1] == max_length)
|
if model.config.is_encoder_decoder:
|
||||||
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||||
|
else:
|
||||||
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||||
|
|
||||||
if "inputs_embeds" in set(inspect.signature(model.prepare_inputs_for_generation).parameters):
|
if "inputs_embeds" in set(inspect.signature(model.prepare_inputs_for_generation).parameters):
|
||||||
input_embeds = model.get_input_embeddings()(input_ids)
|
input_embeds = model.get_input_embeddings()(input_ids)
|
||||||
beam_kwargs.update({"inputs_embeds": input_embeds})
|
beam_kwargs.update({"inputs_embeds": input_embeds})
|
||||||
@@ -724,7 +698,6 @@ class GenerationTesterMixin:
|
|||||||
model=model,
|
model=model,
|
||||||
input_ids=None,
|
input_ids=None,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
max_length=max_length,
|
|
||||||
beam_kwargs=beam_kwargs,
|
beam_kwargs=beam_kwargs,
|
||||||
logits_warper_kwargs=logits_warper_kwargs,
|
logits_warper_kwargs=logits_warper_kwargs,
|
||||||
)
|
)
|
||||||
@@ -733,23 +706,19 @@ class GenerationTesterMixin:
|
|||||||
|
|
||||||
def test_beam_sample_generate_dict_output(self):
|
def test_beam_sample_generate_dict_output(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
|
|
||||||
# disable cache
|
# disable cache
|
||||||
config.use_cache = False
|
config.use_cache = False
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
_, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(input_ids.shape[-1])
|
_, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(input_ids.shape[-1])
|
||||||
|
|
||||||
if model.config.is_encoder_decoder:
|
|
||||||
max_length = 4
|
|
||||||
beam_kwargs = self._get_beam_kwargs()
|
beam_kwargs = self._get_beam_kwargs()
|
||||||
|
|
||||||
output_generate = self._beam_sample_generate(
|
output_generate = self._beam_sample_generate(
|
||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
max_length=max_length,
|
|
||||||
beam_kwargs=beam_kwargs,
|
beam_kwargs=beam_kwargs,
|
||||||
logits_warper_kwargs=logits_warper_kwargs,
|
logits_warper_kwargs=logits_warper_kwargs,
|
||||||
output_scores=True,
|
output_scores=True,
|
||||||
@@ -760,21 +729,22 @@ class GenerationTesterMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
|
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||||
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
|
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
|
||||||
# Retrocompatibility check
|
# Retrocompatibility check
|
||||||
self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput)
|
self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput)
|
||||||
else:
|
else:
|
||||||
|
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||||
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
|
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
|
||||||
# Retrocompatibility check
|
# Retrocompatibility check
|
||||||
self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput)
|
self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput)
|
||||||
|
|
||||||
self.assertTrue(output_generate.sequences.shape[-1] == max_length)
|
|
||||||
self._check_outputs(
|
self._check_outputs(
|
||||||
output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"]
|
output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"]
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_generate_without_input_ids(self):
|
def test_generate_without_input_ids(self):
|
||||||
config, _, _, max_length = self._get_input_ids_and_config()
|
config, _, _ = self._get_input_ids_and_config()
|
||||||
|
|
||||||
# if no bos token id => cannot generate from None
|
# if no bos token id => cannot generate from None
|
||||||
if config.bos_token_id is None:
|
if config.bos_token_id is None:
|
||||||
@@ -788,22 +758,20 @@ class GenerationTesterMixin:
|
|||||||
model = model_class(config).to(torch_device)
|
model = model_class(config).to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
output_ids_generate = model.generate(do_sample=False, max_length=max_length, remove_invalid_values=True)
|
output_ids_generate = model.generate(
|
||||||
|
do_sample=False, max_new_tokens=self.max_new_tokens, remove_invalid_values=True
|
||||||
|
)
|
||||||
self.assertIsNotNone(output_ids_generate)
|
self.assertIsNotNone(output_ids_generate)
|
||||||
|
|
||||||
def test_group_beam_search_generate(self):
|
def test_group_beam_search_generate(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
if model.config.is_encoder_decoder:
|
|
||||||
max_length = 4
|
|
||||||
|
|
||||||
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
||||||
input_ids.shape[-1],
|
input_ids.shape[-1],
|
||||||
config.forced_bos_token_id,
|
config.forced_bos_token_id,
|
||||||
config.forced_eos_token_id,
|
config.forced_eos_token_id,
|
||||||
max_length,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# check `generate()` and `group_beam_search()` are equal
|
# check `generate()` and `group_beam_search()` are equal
|
||||||
@@ -812,11 +780,13 @@ class GenerationTesterMixin:
|
|||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
max_length=max_length,
|
|
||||||
beam_kwargs=beam_kwargs,
|
beam_kwargs=beam_kwargs,
|
||||||
logits_process_kwargs=logits_process_kwargs,
|
logits_process_kwargs=logits_process_kwargs,
|
||||||
)
|
)
|
||||||
self.assertTrue(output_generate.shape[-1] == max_length)
|
if model.config.is_encoder_decoder:
|
||||||
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||||
|
else:
|
||||||
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||||
|
|
||||||
# check `group_beam_search` for higher than 1 `num_return_sequences`
|
# check `group_beam_search` for higher than 1 `num_return_sequences`
|
||||||
num_return_sequences = 2
|
num_return_sequences = 2
|
||||||
@@ -825,26 +795,24 @@ class GenerationTesterMixin:
|
|||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
max_length=max_length,
|
|
||||||
beam_kwargs=beam_kwargs,
|
beam_kwargs=beam_kwargs,
|
||||||
logits_process_kwargs=logits_process_kwargs,
|
logits_process_kwargs=logits_process_kwargs,
|
||||||
)
|
)
|
||||||
self.assertTrue(output_generate.shape[-1] == max_length)
|
if model.config.is_encoder_decoder:
|
||||||
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||||
|
else:
|
||||||
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||||
|
|
||||||
def test_group_beam_search_generate_dict_output(self):
|
def test_group_beam_search_generate_dict_output(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
config.use_cache = False
|
config.use_cache = False
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
if model.config.is_encoder_decoder:
|
|
||||||
max_length = 4
|
|
||||||
|
|
||||||
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
||||||
input_ids.shape[-1],
|
input_ids.shape[-1],
|
||||||
config.forced_bos_token_id,
|
config.forced_bos_token_id,
|
||||||
config.forced_eos_token_id,
|
config.forced_eos_token_id,
|
||||||
max_length,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
beam_kwargs = self._get_diverse_beam_kwargs()
|
beam_kwargs = self._get_diverse_beam_kwargs()
|
||||||
@@ -852,7 +820,6 @@ class GenerationTesterMixin:
|
|||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
max_length=max_length,
|
|
||||||
beam_kwargs=beam_kwargs,
|
beam_kwargs=beam_kwargs,
|
||||||
logits_process_kwargs=logits_process_kwargs,
|
logits_process_kwargs=logits_process_kwargs,
|
||||||
output_scores=True,
|
output_scores=True,
|
||||||
@@ -862,15 +829,16 @@ class GenerationTesterMixin:
|
|||||||
return_dict_in_generate=True,
|
return_dict_in_generate=True,
|
||||||
)
|
)
|
||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
|
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||||
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
|
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
|
||||||
# Retrocompatibility check
|
# Retrocompatibility check
|
||||||
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
|
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
|
||||||
else:
|
else:
|
||||||
|
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||||
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
|
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
|
||||||
# Retrocompatibility check
|
# Retrocompatibility check
|
||||||
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
|
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
|
||||||
|
|
||||||
self.assertTrue(output_generate.sequences.shape[-1] == max_length)
|
|
||||||
self._check_outputs(
|
self._check_outputs(
|
||||||
output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"]
|
output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"]
|
||||||
)
|
)
|
||||||
@@ -879,16 +847,14 @@ class GenerationTesterMixin:
|
|||||||
@is_flaky()
|
@is_flaky()
|
||||||
def test_constrained_beam_search_generate(self):
|
def test_constrained_beam_search_generate(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
max_length = 20
|
|
||||||
|
|
||||||
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
||||||
input_ids.shape[-1],
|
input_ids.shape[-1],
|
||||||
config.forced_bos_token_id,
|
config.forced_bos_token_id,
|
||||||
config.forced_eos_token_id,
|
config.forced_eos_token_id,
|
||||||
max_length,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Sample constraints
|
# Sample constraints
|
||||||
@@ -905,12 +871,16 @@ class GenerationTesterMixin:
|
|||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
max_length=max_length,
|
|
||||||
constraints=constraints,
|
constraints=constraints,
|
||||||
beam_kwargs=beam_kwargs,
|
beam_kwargs=beam_kwargs,
|
||||||
logits_process_kwargs=logits_process_kwargs,
|
logits_process_kwargs=logits_process_kwargs,
|
||||||
)
|
)
|
||||||
self.assertTrue(output_generate.shape[-1] == max_length)
|
|
||||||
|
if model.config.is_encoder_decoder:
|
||||||
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||||
|
else:
|
||||||
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||||
|
|
||||||
for generation_output in output_generate:
|
for generation_output in output_generate:
|
||||||
self._check_sequence_inside_sequence(force_tokens, generation_output)
|
self._check_sequence_inside_sequence(force_tokens, generation_output)
|
||||||
|
|
||||||
@@ -921,39 +891,37 @@ class GenerationTesterMixin:
|
|||||||
PhrasalConstraint(force_tokens),
|
PhrasalConstraint(force_tokens),
|
||||||
]
|
]
|
||||||
|
|
||||||
max_length = 20
|
|
||||||
beam_kwargs = self._get_constrained_beam_kwargs(num_return_sequences=2)
|
beam_kwargs = self._get_constrained_beam_kwargs(num_return_sequences=2)
|
||||||
|
|
||||||
output_generate = self._constrained_beam_search_generate(
|
output_generate = self._constrained_beam_search_generate(
|
||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
max_length=max_length,
|
|
||||||
constraints=constraints,
|
constraints=constraints,
|
||||||
beam_kwargs=beam_kwargs,
|
beam_kwargs=beam_kwargs,
|
||||||
logits_process_kwargs=logits_process_kwargs,
|
logits_process_kwargs=logits_process_kwargs,
|
||||||
)
|
)
|
||||||
self.assertTrue(output_generate.shape[-1] == max_length)
|
|
||||||
|
if model.config.is_encoder_decoder:
|
||||||
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||||
|
else:
|
||||||
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||||
|
|
||||||
for generation_output in output_generate:
|
for generation_output in output_generate:
|
||||||
self._check_sequence_inside_sequence(force_tokens, generation_output)
|
self._check_sequence_inside_sequence(force_tokens, generation_output)
|
||||||
|
|
||||||
def test_constrained_beam_search_generate_dict_output(self):
|
def test_constrained_beam_search_generate_dict_output(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
|
|
||||||
# disable cache
|
# disable cache
|
||||||
config.use_cache = False
|
config.use_cache = False
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
if model.config.is_encoder_decoder:
|
|
||||||
max_length = 20
|
|
||||||
|
|
||||||
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
||||||
input_ids.shape[-1],
|
input_ids.shape[-1],
|
||||||
config.forced_bos_token_id,
|
config.forced_bos_token_id,
|
||||||
config.forced_eos_token_id,
|
config.forced_eos_token_id,
|
||||||
max_length,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Sample constraints
|
# Sample constraints
|
||||||
@@ -969,7 +937,6 @@ class GenerationTesterMixin:
|
|||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
max_length=max_length,
|
|
||||||
constraints=constraints,
|
constraints=constraints,
|
||||||
beam_kwargs=beam_kwargs,
|
beam_kwargs=beam_kwargs,
|
||||||
logits_process_kwargs=logits_process_kwargs,
|
logits_process_kwargs=logits_process_kwargs,
|
||||||
@@ -981,15 +948,16 @@ class GenerationTesterMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if model.config.is_encoder_decoder:
|
if model.config.is_encoder_decoder:
|
||||||
|
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||||
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
|
self.assertIsInstance(output_generate, GenerateBeamEncoderDecoderOutput)
|
||||||
# Retrocompatibility check
|
# Retrocompatibility check
|
||||||
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
|
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput)
|
||||||
else:
|
else:
|
||||||
|
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||||
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
|
self.assertIsInstance(output_generate, GenerateBeamDecoderOnlyOutput)
|
||||||
# Retrocompatibility check
|
# Retrocompatibility check
|
||||||
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
|
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
|
||||||
|
|
||||||
self.assertTrue(output_generate.sequences.shape[-1] == max_length)
|
|
||||||
self._check_outputs(
|
self._check_outputs(
|
||||||
output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"]
|
output_generate, input_ids, model.config, num_return_sequences=beam_kwargs["num_beams"]
|
||||||
)
|
)
|
||||||
@@ -1000,7 +968,7 @@ class GenerationTesterMixin:
|
|||||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||||
self.skipTest("Won't fix: old model with different cache format")
|
self.skipTest("Won't fix: old model with different cache format")
|
||||||
|
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
|
|
||||||
# NOTE: contrastive search only works with cache on at the moment.
|
# NOTE: contrastive search only works with cache on at the moment.
|
||||||
if not hasattr(config, "use_cache"):
|
if not hasattr(config, "use_cache"):
|
||||||
@@ -1011,9 +979,12 @@ class GenerationTesterMixin:
|
|||||||
# test old generation output for backwards compatibility
|
# test old generation output for backwards compatibility
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
output_generate = self._contrastive_generate(
|
output_generate = self._contrastive_generate(
|
||||||
model=model, input_ids=input_ids, attention_mask=attention_mask, max_length=max_length
|
model=model, input_ids=input_ids, attention_mask=attention_mask
|
||||||
)
|
)
|
||||||
self.assertTrue(output_generate.shape[-1] == max_length)
|
if model.config.is_encoder_decoder:
|
||||||
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + 1)
|
||||||
|
else:
|
||||||
|
self.assertTrue(output_generate.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||||
|
|
||||||
def test_contrastive_generate_dict_outputs_use_cache(self):
|
def test_contrastive_generate_dict_outputs_use_cache(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
@@ -1021,7 +992,7 @@ class GenerationTesterMixin:
|
|||||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||||
self.skipTest("Won't fix: old model with different cache format")
|
self.skipTest("Won't fix: old model with different cache format")
|
||||||
|
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
|
|
||||||
# NOTE: contrastive search only works with cache on at the moment.
|
# NOTE: contrastive search only works with cache on at the moment.
|
||||||
if not hasattr(config, "use_cache"):
|
if not hasattr(config, "use_cache"):
|
||||||
@@ -1034,7 +1005,6 @@ class GenerationTesterMixin:
|
|||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
max_length=max_length,
|
|
||||||
output_scores=True,
|
output_scores=True,
|
||||||
output_logits=True,
|
output_logits=True,
|
||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
@@ -1042,7 +1012,10 @@ class GenerationTesterMixin:
|
|||||||
return_dict_in_generate=True,
|
return_dict_in_generate=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertTrue(output_generate.sequences.shape[-1] == max_length)
|
if model.config.is_encoder_decoder:
|
||||||
|
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||||
|
else:
|
||||||
|
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + input_ids.shape[-1])
|
||||||
self._check_outputs(output_generate, input_ids, model.config, use_cache=True)
|
self._check_outputs(output_generate, input_ids, model.config, use_cache=True)
|
||||||
|
|
||||||
def test_contrastive_generate_low_memory(self):
|
def test_contrastive_generate_low_memory(self):
|
||||||
@@ -1053,7 +1026,7 @@ class GenerationTesterMixin:
|
|||||||
if any(model_name in model_class.__name__.lower() for model_name in ["gptbigcode", "jamba"]):
|
if any(model_name in model_class.__name__.lower() for model_name in ["gptbigcode", "jamba"]):
|
||||||
self.skipTest("TODO: fix me")
|
self.skipTest("TODO: fix me")
|
||||||
|
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config(batch_size=1)
|
config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1)
|
||||||
|
|
||||||
# NOTE: contrastive search only works with cache on at the moment.
|
# NOTE: contrastive search only works with cache on at the moment.
|
||||||
if not hasattr(config, "use_cache"):
|
if not hasattr(config, "use_cache"):
|
||||||
@@ -1070,7 +1043,7 @@ class GenerationTesterMixin:
|
|||||||
top_k=4,
|
top_k=4,
|
||||||
penalty_alpha=0.6,
|
penalty_alpha=0.6,
|
||||||
low_memory=True,
|
low_memory=True,
|
||||||
max_length=max_length,
|
max_new_tokens=self.max_new_tokens,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1079,7 +1052,7 @@ class GenerationTesterMixin:
|
|||||||
top_k=4,
|
top_k=4,
|
||||||
penalty_alpha=0.6,
|
penalty_alpha=0.6,
|
||||||
low_memory=False,
|
low_memory=False,
|
||||||
max_length=max_length,
|
max_new_tokens=self.max_new_tokens,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
)
|
)
|
||||||
self.assertListEqual(low_output.tolist(), high_output.tolist())
|
self.assertListEqual(low_output.tolist(), high_output.tolist())
|
||||||
@@ -1102,7 +1075,7 @@ class GenerationTesterMixin:
|
|||||||
]
|
]
|
||||||
):
|
):
|
||||||
self.skipTest("May fix in the future: need model-specific fixes")
|
self.skipTest("May fix in the future: need model-specific fixes")
|
||||||
config, input_ids, _, _ = self._get_input_ids_and_config(batch_size=2)
|
config, input_ids, _ = self._get_input_ids_and_config(batch_size=2)
|
||||||
# batch_size=1 is ok, but batch_size>1 will cause non-identical output
|
# batch_size=1 is ok, but batch_size>1 will cause non-identical output
|
||||||
|
|
||||||
config.use_cache = True
|
config.use_cache = True
|
||||||
@@ -1150,7 +1123,7 @@ class GenerationTesterMixin:
|
|||||||
self.skipTest("May fix in the future: need model-specific fixes")
|
self.skipTest("May fix in the future: need model-specific fixes")
|
||||||
|
|
||||||
# enable cache
|
# enable cache
|
||||||
config, input_ids, attention_mask, _ = self._get_input_ids_and_config(batch_size=1)
|
config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1)
|
||||||
|
|
||||||
# NOTE: assisted generation only works with cache on at the moment.
|
# NOTE: assisted generation only works with cache on at the moment.
|
||||||
if not hasattr(config, "use_cache"):
|
if not hasattr(config, "use_cache"):
|
||||||
@@ -1213,7 +1186,7 @@ class GenerationTesterMixin:
|
|||||||
self.skipTest("May fix in the future: need model-specific fixes")
|
self.skipTest("May fix in the future: need model-specific fixes")
|
||||||
|
|
||||||
# enable cache
|
# enable cache
|
||||||
config, input_ids, attention_mask, _ = self._get_input_ids_and_config(batch_size=1)
|
config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1)
|
||||||
|
|
||||||
# NOTE: assisted generation only works with cache on at the moment.
|
# NOTE: assisted generation only works with cache on at the moment.
|
||||||
if not hasattr(config, "use_cache"):
|
if not hasattr(config, "use_cache"):
|
||||||
@@ -1273,7 +1246,7 @@ class GenerationTesterMixin:
|
|||||||
self.skipTest("May fix in the future: need model-specific fixes")
|
self.skipTest("May fix in the future: need model-specific fixes")
|
||||||
|
|
||||||
# enable cache
|
# enable cache
|
||||||
config, input_ids, attention_mask, _ = self._get_input_ids_and_config(batch_size=1)
|
config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1)
|
||||||
|
|
||||||
# NOTE: assisted generation only works with cache on at the moment.
|
# NOTE: assisted generation only works with cache on at the moment.
|
||||||
if not hasattr(config, "use_cache"):
|
if not hasattr(config, "use_cache"):
|
||||||
@@ -1311,7 +1284,7 @@ class GenerationTesterMixin:
|
|||||||
"""Test designed for encoder-decoder models to ensure the attention head masking is used."""
|
"""Test designed for encoder-decoder models to ensure the attention head masking is used."""
|
||||||
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
|
attention_names = ["encoder_attentions", "decoder_attentions", "cross_attentions"]
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
# We want to test only encoder-decoder models
|
# We want to test only encoder-decoder models
|
||||||
if not config.is_encoder_decoder:
|
if not config.is_encoder_decoder:
|
||||||
continue
|
continue
|
||||||
@@ -1358,7 +1331,7 @@ class GenerationTesterMixin:
|
|||||||
# - The model must be a decoder-only architecture (encoder-based architectures use right-padding)
|
# - The model must be a decoder-only architecture (encoder-based architectures use right-padding)
|
||||||
decoder_only_classes = []
|
decoder_only_classes = []
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, _, _, _ = self._get_input_ids_and_config()
|
config, _, _ = self._get_input_ids_and_config()
|
||||||
if config.is_encoder_decoder:
|
if config.is_encoder_decoder:
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
@@ -1391,7 +1364,7 @@ class GenerationTesterMixin:
|
|||||||
return model_kwargs
|
return model_kwargs
|
||||||
|
|
||||||
for model_class in decoder_only_classes:
|
for model_class in decoder_only_classes:
|
||||||
config, input_ids, attention_mask, _ = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
signature = inspect.signature(model.forward).parameters.keys()
|
signature = inspect.signature(model.forward).parameters.keys()
|
||||||
|
|
||||||
@@ -1485,7 +1458,7 @@ class GenerationTesterMixin:
|
|||||||
# When supported, tests that the decoder model can generate from `inputs_embeds` instead of `input_ids`
|
# When supported, tests that the decoder model can generate from `inputs_embeds` instead of `input_ids`
|
||||||
# if fails, you should probably update the `prepare_inputs_for_generation` function
|
# if fails, you should probably update the `prepare_inputs_for_generation` function
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, input_ids, _, _ = self._get_input_ids_and_config()
|
config, input_ids, _ = self._get_input_ids_and_config()
|
||||||
|
|
||||||
# Ignore:
|
# Ignore:
|
||||||
# a) eos (to always output 20 tokens) and pad (so we don't try to infer the attn mask from the input_ids,
|
# a) eos (to always output 20 tokens) and pad (so we don't try to infer the attn mask from the input_ids,
|
||||||
@@ -1616,7 +1589,7 @@ class GenerationTesterMixin:
|
|||||||
if not model_class._supports_cache_class:
|
if not model_class._supports_cache_class:
|
||||||
self.skipTest("This model does not support the new cache format")
|
self.skipTest("This model does not support the new cache format")
|
||||||
|
|
||||||
config, input_ids, attention_mask, _ = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
config.use_cache = True
|
config.use_cache = True
|
||||||
config.is_decoder = True
|
config.is_decoder = True
|
||||||
|
|
||||||
|
|||||||
@@ -299,12 +299,10 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
|
|||||||
input_ids = input_ids[:batch_size, :sequence_length]
|
input_ids = input_ids[:batch_size, :sequence_length]
|
||||||
attention_mask = attention_mask[:batch_size, :sequence_length]
|
attention_mask = attention_mask[:batch_size, :sequence_length]
|
||||||
|
|
||||||
# generate max 3 tokens
|
|
||||||
max_length = input_ids.shape[-1] + 3
|
|
||||||
if config.eos_token_id is not None and config.pad_token_id is None:
|
if config.eos_token_id is not None and config.pad_token_id is None:
|
||||||
# hack to allow generate for models such as GPT2 as is done in `generate()`
|
# hack to allow generate for models such as GPT2 as is done in `generate()`
|
||||||
config.pad_token_id = config.eos_token_id
|
config.pad_token_id = config.eos_token_id
|
||||||
return config, input_ids, attention_mask, max_length
|
return config, input_ids, attention_mask
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = BigBirdPegasusModelTester(self)
|
self.model_tester = BigBirdPegasusModelTester(self)
|
||||||
|
|||||||
@@ -457,6 +457,20 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length):
|
||||||
|
# overwrite because LED does not have (bs, num_heads, seq_len, seq_len) shape
|
||||||
|
encoder_expected_shape = (
|
||||||
|
batch_size,
|
||||||
|
config.num_attention_heads,
|
||||||
|
seq_length,
|
||||||
|
self.model_tester.attention_window // 2 * 2 + 1,
|
||||||
|
)
|
||||||
|
self.assertIsInstance(attentions, tuple)
|
||||||
|
self.assertListEqual(
|
||||||
|
[layer_attentions.shape for layer_attentions in attentions],
|
||||||
|
[encoder_expected_shape] * len(attentions),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def assert_tensors_close(a, b, atol=1e-12, prefix=""):
|
def assert_tensors_close(a, b, atol=1e-12, prefix=""):
|
||||||
"""If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""
|
"""If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""
|
||||||
|
|||||||
@@ -752,7 +752,7 @@ class LongT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
|||||||
|
|
||||||
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length):
|
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length):
|
||||||
block_len = getattr(self.model_tester, "block_len", None)
|
block_len = getattr(self.model_tester, "block_len", None)
|
||||||
encoder_expected_shape = (batch_size, 1, config.num_attention_heads, block_len, 3 * block_len)
|
encoder_expected_shape = (batch_size, 2, config.num_attention_heads, block_len, 3 * block_len)
|
||||||
self.assertIsInstance(attentions, tuple)
|
self.assertIsInstance(attentions, tuple)
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
[layer_attentions.shape for layer_attentions in attentions],
|
[layer_attentions.shape for layer_attentions in attentions],
|
||||||
@@ -885,7 +885,7 @@ class LongT5TGlobalModelTest(LongT5ModelTest):
|
|||||||
global_seq_length = seq_length // global_block_size
|
global_seq_length = seq_length // global_block_size
|
||||||
encoder_expected_shape = (
|
encoder_expected_shape = (
|
||||||
batch_size,
|
batch_size,
|
||||||
1,
|
2,
|
||||||
config.num_attention_heads,
|
config.num_attention_heads,
|
||||||
block_len,
|
block_len,
|
||||||
3 * block_len + global_seq_length,
|
3 * block_len + global_seq_length,
|
||||||
|
|||||||
@@ -245,34 +245,28 @@ class MusicgenDecoderTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
|||||||
sequence_length = input_ids.shape[-1]
|
sequence_length = input_ids.shape[-1]
|
||||||
input_ids = input_ids[: batch_size * config.num_codebooks, :]
|
input_ids = input_ids[: batch_size * config.num_codebooks, :]
|
||||||
|
|
||||||
# generate max 3 tokens
|
|
||||||
max_length = input_ids.shape[-1] + 3
|
|
||||||
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
|
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
|
||||||
return config, input_ids, attention_mask, max_length
|
return config, input_ids, attention_mask
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_logits_processor_and_warper_kwargs(
|
def _get_logits_processor_and_warper_kwargs(
|
||||||
input_length,
|
input_length,
|
||||||
forced_bos_token_id=None,
|
forced_bos_token_id=None,
|
||||||
forced_eos_token_id=None,
|
forced_eos_token_id=None,
|
||||||
max_length=None,
|
|
||||||
):
|
):
|
||||||
process_kwargs = {
|
process_kwargs = {}
|
||||||
"min_length": input_length + 1 if max_length is None else max_length - 1,
|
|
||||||
}
|
|
||||||
warper_kwargs = {}
|
warper_kwargs = {}
|
||||||
return process_kwargs, warper_kwargs
|
return process_kwargs, warper_kwargs
|
||||||
|
|
||||||
def test_greedy_generate_stereo_outputs(self):
|
def test_greedy_generate_stereo_outputs(self):
|
||||||
for model_class in self.greedy_sample_model_classes:
|
for model_class in self.greedy_sample_model_classes:
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
config.audio_channels = 2
|
config.audio_channels = 2
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
output_generate = self._greedy_generate(
|
output_generate = self._greedy_generate(
|
||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids.to(torch_device),
|
input_ids=input_ids.to(torch_device),
|
||||||
attention_mask=attention_mask.to(torch_device),
|
attention_mask=attention_mask.to(torch_device),
|
||||||
max_length=max_length,
|
|
||||||
output_scores=True,
|
output_scores=True,
|
||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
output_attentions=True,
|
output_attentions=True,
|
||||||
@@ -1327,9 +1321,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||||||
input_ids = input_ids[:batch_size, :]
|
input_ids = input_ids[:batch_size, :]
|
||||||
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
|
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
|
||||||
|
|
||||||
# generate max 3 tokens
|
return config, input_ids, attention_mask
|
||||||
max_length = 3
|
|
||||||
return config, input_ids, attention_mask, max_length
|
|
||||||
|
|
||||||
# override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen (input / outputs are
|
# override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen (input / outputs are
|
||||||
# different modalities -> different shapes)
|
# different modalities -> different shapes)
|
||||||
@@ -1338,29 +1330,22 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||||||
model,
|
model,
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
max_length,
|
|
||||||
output_scores=False,
|
output_scores=False,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
output_hidden_states=False,
|
output_hidden_states=False,
|
||||||
return_dict_in_generate=False,
|
return_dict_in_generate=False,
|
||||||
):
|
):
|
||||||
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
|
||||||
input_ids.shape[-1],
|
|
||||||
max_length=max_length,
|
|
||||||
)
|
|
||||||
|
|
||||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||||
output_generate = model.generate(
|
output_generate = model.generate(
|
||||||
input_ids,
|
input_ids,
|
||||||
do_sample=False,
|
do_sample=False,
|
||||||
num_beams=1,
|
num_beams=1,
|
||||||
max_length=max_length,
|
max_new_tokens=self.max_new_tokens,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
output_scores=output_scores,
|
output_scores=output_scores,
|
||||||
return_dict_in_generate=return_dict_in_generate,
|
return_dict_in_generate=return_dict_in_generate,
|
||||||
remove_invalid_values=True,
|
remove_invalid_values=True,
|
||||||
**logits_process_kwargs,
|
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1373,10 +1358,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||||||
model,
|
model,
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
max_length,
|
|
||||||
num_return_sequences,
|
num_return_sequences,
|
||||||
logits_warper_kwargs,
|
|
||||||
process_kwargs,
|
|
||||||
output_scores=False,
|
output_scores=False,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
output_hidden_states=False,
|
output_hidden_states=False,
|
||||||
@@ -1388,15 +1370,13 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||||||
input_ids,
|
input_ids,
|
||||||
do_sample=True,
|
do_sample=True,
|
||||||
num_beams=1,
|
num_beams=1,
|
||||||
max_length=max_length,
|
max_new_tokens=self.max_new_tokens,
|
||||||
num_return_sequences=num_return_sequences,
|
num_return_sequences=num_return_sequences,
|
||||||
output_scores=output_scores,
|
output_scores=output_scores,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict_in_generate=return_dict_in_generate,
|
return_dict_in_generate=return_dict_in_generate,
|
||||||
remove_invalid_values=True,
|
remove_invalid_values=True,
|
||||||
**logits_warper_kwargs,
|
|
||||||
**process_kwargs,
|
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1407,25 +1387,21 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||||||
input_length,
|
input_length,
|
||||||
forced_bos_token_id=None,
|
forced_bos_token_id=None,
|
||||||
forced_eos_token_id=None,
|
forced_eos_token_id=None,
|
||||||
max_length=None,
|
|
||||||
):
|
):
|
||||||
process_kwargs = {
|
process_kwargs = {}
|
||||||
"min_length": input_length + 1 if max_length is None else max_length - 1,
|
|
||||||
}
|
|
||||||
warper_kwargs = {}
|
warper_kwargs = {}
|
||||||
return process_kwargs, warper_kwargs
|
return process_kwargs, warper_kwargs
|
||||||
|
|
||||||
def test_greedy_generate_dict_outputs(self):
|
def test_greedy_generate_dict_outputs(self):
|
||||||
for model_class in self.greedy_sample_model_classes:
|
for model_class in self.greedy_sample_model_classes:
|
||||||
# disable cache
|
# disable cache
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
config.use_cache = False
|
config.use_cache = False
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
output_generate = self._greedy_generate(
|
output_generate = self._greedy_generate(
|
||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids.to(torch_device),
|
input_ids=input_ids.to(torch_device),
|
||||||
attention_mask=attention_mask.to(torch_device),
|
attention_mask=attention_mask.to(torch_device),
|
||||||
max_length=max_length,
|
|
||||||
output_scores=True,
|
output_scores=True,
|
||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
output_attentions=True,
|
output_attentions=True,
|
||||||
@@ -1439,7 +1415,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||||||
def test_greedy_generate_dict_outputs_use_cache(self):
|
def test_greedy_generate_dict_outputs_use_cache(self):
|
||||||
for model_class in self.greedy_sample_model_classes:
|
for model_class in self.greedy_sample_model_classes:
|
||||||
# enable cache
|
# enable cache
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
|
|
||||||
config.use_cache = True
|
config.use_cache = True
|
||||||
config.is_decoder = True
|
config.is_decoder = True
|
||||||
@@ -1448,7 +1424,6 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids.to(torch_device),
|
input_ids=input_ids.to(torch_device),
|
||||||
attention_mask=attention_mask.to(torch_device),
|
attention_mask=attention_mask.to(torch_device),
|
||||||
max_length=max_length,
|
|
||||||
output_scores=True,
|
output_scores=True,
|
||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
output_attentions=True,
|
output_attentions=True,
|
||||||
@@ -1459,46 +1434,30 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||||||
|
|
||||||
def test_sample_generate(self):
|
def test_sample_generate(self):
|
||||||
for model_class in self.greedy_sample_model_classes:
|
for model_class in self.greedy_sample_model_classes:
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
|
|
||||||
process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(
|
|
||||||
input_ids.shape[-1],
|
|
||||||
max_length=max_length,
|
|
||||||
)
|
|
||||||
|
|
||||||
# check `generate()` and `sample()` are equal
|
# check `generate()` and `sample()` are equal
|
||||||
output_generate = self._sample_generate(
|
output_generate = self._sample_generate(
|
||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids.to(torch_device),
|
input_ids=input_ids.to(torch_device),
|
||||||
attention_mask=attention_mask.to(torch_device),
|
attention_mask=attention_mask.to(torch_device),
|
||||||
max_length=max_length,
|
|
||||||
num_return_sequences=1,
|
num_return_sequences=1,
|
||||||
logits_warper_kwargs=logits_warper_kwargs,
|
|
||||||
process_kwargs=process_kwargs,
|
|
||||||
)
|
)
|
||||||
self.assertIsInstance(output_generate, torch.Tensor)
|
self.assertIsInstance(output_generate, torch.Tensor)
|
||||||
|
|
||||||
def test_sample_generate_dict_output(self):
|
def test_sample_generate_dict_output(self):
|
||||||
for model_class in self.greedy_sample_model_classes:
|
for model_class in self.greedy_sample_model_classes:
|
||||||
# disable cache
|
# disable cache
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
config.use_cache = False
|
config.use_cache = False
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
|
|
||||||
process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(
|
|
||||||
input_ids.shape[-1],
|
|
||||||
max_length=max_length,
|
|
||||||
)
|
|
||||||
|
|
||||||
output_generate = self._sample_generate(
|
output_generate = self._sample_generate(
|
||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids.to(torch_device),
|
input_ids=input_ids.to(torch_device),
|
||||||
attention_mask=attention_mask.to(torch_device),
|
attention_mask=attention_mask.to(torch_device),
|
||||||
max_length=max_length,
|
|
||||||
num_return_sequences=3,
|
num_return_sequences=3,
|
||||||
logits_warper_kwargs=logits_warper_kwargs,
|
|
||||||
process_kwargs=process_kwargs,
|
|
||||||
output_scores=True,
|
output_scores=True,
|
||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
output_attentions=True,
|
output_attentions=True,
|
||||||
@@ -1508,7 +1467,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||||||
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
|
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
|
||||||
|
|
||||||
def test_generate_without_input_ids(self):
|
def test_generate_without_input_ids(self):
|
||||||
config, _, _, max_length = self._get_input_ids_and_config()
|
config, _, _ = self._get_input_ids_and_config()
|
||||||
|
|
||||||
# if no bos token id => cannot generate from None
|
# if no bos token id => cannot generate from None
|
||||||
if config.bos_token_id is None:
|
if config.bos_token_id is None:
|
||||||
@@ -1518,7 +1477,9 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||||||
model = model_class(config).to(torch_device)
|
model = model_class(config).to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
output_ids_generate = model.generate(do_sample=False, max_length=max_length, remove_invalid_values=True)
|
output_ids_generate = model.generate(
|
||||||
|
do_sample=False, max_new_tokens=self.max_new_tokens, remove_invalid_values=True
|
||||||
|
)
|
||||||
self.assertIsNotNone(output_ids_generate)
|
self.assertIsNotNone(output_ids_generate)
|
||||||
|
|
||||||
@require_torch_fp16
|
@require_torch_fp16
|
||||||
@@ -1537,7 +1498,7 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||||||
|
|
||||||
def test_greedy_generate_stereo_outputs(self):
|
def test_greedy_generate_stereo_outputs(self):
|
||||||
for model_class in self.greedy_sample_model_classes:
|
for model_class in self.greedy_sample_model_classes:
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
config.audio_channels = 2
|
config.audio_channels = 2
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
@@ -1545,7 +1506,6 @@ class MusicgenTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids.to(torch_device),
|
input_ids=input_ids.to(torch_device),
|
||||||
attention_mask=attention_mask.to(torch_device),
|
attention_mask=attention_mask.to(torch_device),
|
||||||
max_length=max_length,
|
|
||||||
output_scores=True,
|
output_scores=True,
|
||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
output_attentions=True,
|
output_attentions=True,
|
||||||
|
|||||||
@@ -246,34 +246,28 @@ class MusicgenMelodyDecoderTest(ModelTesterMixin, GenerationTesterMixin, unittes
|
|||||||
sequence_length = input_ids.shape[-1]
|
sequence_length = input_ids.shape[-1]
|
||||||
input_ids = input_ids[: batch_size * config.num_codebooks, :]
|
input_ids = input_ids[: batch_size * config.num_codebooks, :]
|
||||||
|
|
||||||
# generate max 3 tokens
|
|
||||||
max_length = input_ids.shape[-1] + 3
|
|
||||||
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
|
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
|
||||||
return config, input_ids, attention_mask, max_length
|
return config, input_ids, attention_mask
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_logits_processor_and_warper_kwargs(
|
def _get_logits_processor_and_warper_kwargs(
|
||||||
input_length,
|
input_length,
|
||||||
forced_bos_token_id=None,
|
forced_bos_token_id=None,
|
||||||
forced_eos_token_id=None,
|
forced_eos_token_id=None,
|
||||||
max_length=None,
|
|
||||||
):
|
):
|
||||||
process_kwargs = {
|
process_kwargs = {}
|
||||||
"min_length": input_length + 1 if max_length is None else max_length - 1,
|
|
||||||
}
|
|
||||||
warper_kwargs = {}
|
warper_kwargs = {}
|
||||||
return process_kwargs, warper_kwargs
|
return process_kwargs, warper_kwargs
|
||||||
|
|
||||||
def test_greedy_generate_stereo_outputs(self):
|
def test_greedy_generate_stereo_outputs(self):
|
||||||
for model_class in self.greedy_sample_model_classes:
|
for model_class in self.greedy_sample_model_classes:
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
config.audio_channels = 2
|
config.audio_channels = 2
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
output_generate = self._greedy_generate(
|
output_generate = self._greedy_generate(
|
||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids.to(torch_device),
|
input_ids=input_ids.to(torch_device),
|
||||||
attention_mask=attention_mask.to(torch_device),
|
attention_mask=attention_mask.to(torch_device),
|
||||||
max_length=max_length,
|
|
||||||
output_scores=True,
|
output_scores=True,
|
||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
output_attentions=True,
|
output_attentions=True,
|
||||||
@@ -1309,9 +1303,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
|||||||
input_ids = input_ids[:batch_size, :]
|
input_ids = input_ids[:batch_size, :]
|
||||||
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
|
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long)
|
||||||
|
|
||||||
# generate max 3 tokens
|
return config, input_ids, attention_mask
|
||||||
max_length = 3
|
|
||||||
return config, input_ids, attention_mask, max_length
|
|
||||||
|
|
||||||
# override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen_melody (input / outputs are
|
# override since the `input_ids` cannot be used as the `decoder_input_ids` for musicgen_melody (input / outputs are
|
||||||
# different modalities -> different shapes)
|
# different modalities -> different shapes)
|
||||||
@@ -1320,29 +1312,22 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
|||||||
model,
|
model,
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
max_length,
|
|
||||||
output_scores=False,
|
output_scores=False,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
output_hidden_states=False,
|
output_hidden_states=False,
|
||||||
return_dict_in_generate=False,
|
return_dict_in_generate=False,
|
||||||
):
|
):
|
||||||
logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
|
|
||||||
input_ids.shape[-1],
|
|
||||||
max_length=max_length,
|
|
||||||
)
|
|
||||||
|
|
||||||
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
model_kwargs = {"attention_mask": attention_mask} if attention_mask is not None else {}
|
||||||
output_generate = model.generate(
|
output_generate = model.generate(
|
||||||
input_ids,
|
input_ids,
|
||||||
do_sample=False,
|
do_sample=False,
|
||||||
num_beams=1,
|
num_beams=1,
|
||||||
max_length=max_length,
|
max_new_tokens=self.max_new_tokens,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
output_scores=output_scores,
|
output_scores=output_scores,
|
||||||
return_dict_in_generate=return_dict_in_generate,
|
return_dict_in_generate=return_dict_in_generate,
|
||||||
remove_invalid_values=True,
|
remove_invalid_values=True,
|
||||||
**logits_process_kwargs,
|
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1355,10 +1340,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
|||||||
model,
|
model,
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
max_length,
|
|
||||||
num_return_sequences,
|
num_return_sequences,
|
||||||
logits_warper_kwargs,
|
|
||||||
process_kwargs,
|
|
||||||
output_scores=False,
|
output_scores=False,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
output_hidden_states=False,
|
output_hidden_states=False,
|
||||||
@@ -1370,15 +1352,13 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
|||||||
input_ids,
|
input_ids,
|
||||||
do_sample=True,
|
do_sample=True,
|
||||||
num_beams=1,
|
num_beams=1,
|
||||||
max_length=max_length,
|
max_new_tokens=self.max_new_tokens,
|
||||||
num_return_sequences=num_return_sequences,
|
num_return_sequences=num_return_sequences,
|
||||||
output_scores=output_scores,
|
output_scores=output_scores,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict_in_generate=return_dict_in_generate,
|
return_dict_in_generate=return_dict_in_generate,
|
||||||
remove_invalid_values=True,
|
remove_invalid_values=True,
|
||||||
**logits_warper_kwargs,
|
|
||||||
**process_kwargs,
|
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1389,25 +1369,21 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
|||||||
input_length,
|
input_length,
|
||||||
forced_bos_token_id=None,
|
forced_bos_token_id=None,
|
||||||
forced_eos_token_id=None,
|
forced_eos_token_id=None,
|
||||||
max_length=None,
|
|
||||||
):
|
):
|
||||||
process_kwargs = {
|
process_kwargs = {}
|
||||||
"min_length": input_length + 1 if max_length is None else max_length - 1,
|
|
||||||
}
|
|
||||||
warper_kwargs = {}
|
warper_kwargs = {}
|
||||||
return process_kwargs, warper_kwargs
|
return process_kwargs, warper_kwargs
|
||||||
|
|
||||||
def test_greedy_generate_dict_outputs(self):
|
def test_greedy_generate_dict_outputs(self):
|
||||||
for model_class in self.greedy_sample_model_classes:
|
for model_class in self.greedy_sample_model_classes:
|
||||||
# disable cache
|
# disable cache
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
config.use_cache = False
|
config.use_cache = False
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
output_generate = self._greedy_generate(
|
output_generate = self._greedy_generate(
|
||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids.to(torch_device),
|
input_ids=input_ids.to(torch_device),
|
||||||
attention_mask=attention_mask.to(torch_device),
|
attention_mask=attention_mask.to(torch_device),
|
||||||
max_length=max_length,
|
|
||||||
output_scores=True,
|
output_scores=True,
|
||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
output_attentions=True,
|
output_attentions=True,
|
||||||
@@ -1421,7 +1397,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
|||||||
def test_greedy_generate_dict_outputs_use_cache(self):
|
def test_greedy_generate_dict_outputs_use_cache(self):
|
||||||
for model_class in self.greedy_sample_model_classes:
|
for model_class in self.greedy_sample_model_classes:
|
||||||
# enable cache
|
# enable cache
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
|
|
||||||
config.use_cache = True
|
config.use_cache = True
|
||||||
config.is_decoder = True
|
config.is_decoder = True
|
||||||
@@ -1430,7 +1406,6 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
|||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids.to(torch_device),
|
input_ids=input_ids.to(torch_device),
|
||||||
attention_mask=attention_mask.to(torch_device),
|
attention_mask=attention_mask.to(torch_device),
|
||||||
max_length=max_length,
|
|
||||||
output_scores=True,
|
output_scores=True,
|
||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
output_attentions=True,
|
output_attentions=True,
|
||||||
@@ -1441,46 +1416,30 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
|||||||
|
|
||||||
def test_sample_generate(self):
|
def test_sample_generate(self):
|
||||||
for model_class in self.greedy_sample_model_classes:
|
for model_class in self.greedy_sample_model_classes:
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
|
|
||||||
process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(
|
|
||||||
input_ids.shape[-1],
|
|
||||||
max_length=max_length,
|
|
||||||
)
|
|
||||||
|
|
||||||
# check `generate()` and `sample()` are equal
|
# check `generate()` and `sample()` are equal
|
||||||
output_generate = self._sample_generate(
|
output_generate = self._sample_generate(
|
||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids.to(torch_device),
|
input_ids=input_ids.to(torch_device),
|
||||||
attention_mask=attention_mask.to(torch_device),
|
attention_mask=attention_mask.to(torch_device),
|
||||||
max_length=max_length,
|
|
||||||
num_return_sequences=1,
|
num_return_sequences=1,
|
||||||
logits_warper_kwargs=logits_warper_kwargs,
|
|
||||||
process_kwargs=process_kwargs,
|
|
||||||
)
|
)
|
||||||
self.assertIsInstance(output_generate, torch.Tensor)
|
self.assertIsInstance(output_generate, torch.Tensor)
|
||||||
|
|
||||||
def test_sample_generate_dict_output(self):
|
def test_sample_generate_dict_output(self):
|
||||||
for model_class in self.greedy_sample_model_classes:
|
for model_class in self.greedy_sample_model_classes:
|
||||||
# disable cache
|
# disable cache
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
config.use_cache = False
|
config.use_cache = False
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
|
|
||||||
process_kwargs, logits_warper_kwargs = self._get_logits_processor_and_warper_kwargs(
|
|
||||||
input_ids.shape[-1],
|
|
||||||
max_length=max_length,
|
|
||||||
)
|
|
||||||
|
|
||||||
output_generate = self._sample_generate(
|
output_generate = self._sample_generate(
|
||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids.to(torch_device),
|
input_ids=input_ids.to(torch_device),
|
||||||
attention_mask=attention_mask.to(torch_device),
|
attention_mask=attention_mask.to(torch_device),
|
||||||
max_length=max_length,
|
|
||||||
num_return_sequences=3,
|
num_return_sequences=3,
|
||||||
logits_warper_kwargs=logits_warper_kwargs,
|
|
||||||
process_kwargs=process_kwargs,
|
|
||||||
output_scores=True,
|
output_scores=True,
|
||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
output_attentions=True,
|
output_attentions=True,
|
||||||
@@ -1490,7 +1449,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
|||||||
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
||||||
|
|
||||||
def test_generate_without_input_ids(self):
|
def test_generate_without_input_ids(self):
|
||||||
config, _, _, max_length = self._get_input_ids_and_config()
|
config, _, _ = self._get_input_ids_and_config()
|
||||||
|
|
||||||
# if no bos token id => cannot generate from None
|
# if no bos token id => cannot generate from None
|
||||||
if config.bos_token_id is None:
|
if config.bos_token_id is None:
|
||||||
@@ -1500,7 +1459,9 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
|||||||
model = model_class(config).to(torch_device)
|
model = model_class(config).to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
output_ids_generate = model.generate(do_sample=False, max_length=max_length, remove_invalid_values=True)
|
output_ids_generate = model.generate(
|
||||||
|
do_sample=False, max_new_tokens=self.max_new_tokens, remove_invalid_values=True
|
||||||
|
)
|
||||||
self.assertIsNotNone(output_ids_generate)
|
self.assertIsNotNone(output_ids_generate)
|
||||||
|
|
||||||
@require_torch_fp16
|
@require_torch_fp16
|
||||||
@@ -1519,7 +1480,7 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
|||||||
|
|
||||||
def test_greedy_generate_stereo_outputs(self):
|
def test_greedy_generate_stereo_outputs(self):
|
||||||
for model_class in self.greedy_sample_model_classes:
|
for model_class in self.greedy_sample_model_classes:
|
||||||
config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
config.audio_channels = 2
|
config.audio_channels = 2
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
@@ -1527,7 +1488,6 @@ class MusicgenMelodyTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
|||||||
model=model,
|
model=model,
|
||||||
input_ids=input_ids.to(torch_device),
|
input_ids=input_ids.to(torch_device),
|
||||||
attention_mask=attention_mask.to(torch_device),
|
attention_mask=attention_mask.to(torch_device),
|
||||||
max_length=max_length,
|
|
||||||
output_scores=True,
|
output_scores=True,
|
||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
output_attentions=True,
|
output_attentions=True,
|
||||||
|
|||||||
@@ -686,6 +686,18 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, Mod
|
|||||||
def test_left_padding_compatibility(self):
|
def test_left_padding_compatibility(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def _get_input_ids_and_config(self, batch_size=2):
|
||||||
|
# override because overwise we hit max possible seq length for model (4*8=32)
|
||||||
|
# decreasing the seq_length in tester causes errors for "training_tests", those need exactly max seq length
|
||||||
|
# NOTE: seq_length has to be multiple of 4, otherwise it fails for other tests
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
input_ids = inputs_dict[self.input_name]
|
||||||
|
input_ids = input_ids[:batch_size, :16]
|
||||||
|
attention_mask = torch.ones_like(input_ids, dtype=torch.long)[:batch_size, :16]
|
||||||
|
config.eos_token_id = None
|
||||||
|
config.forced_eos_token_id = None
|
||||||
|
return config, input_ids, attention_mask
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class ReformerLSHAttnModelTest(
|
class ReformerLSHAttnModelTest(
|
||||||
|
|||||||
@@ -285,7 +285,7 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
|
|||||||
input_name = "input_features"
|
input_name = "input_features"
|
||||||
|
|
||||||
def _get_input_ids_and_config(self, batch_size=2):
|
def _get_input_ids_and_config(self, batch_size=2):
|
||||||
config, input_ids, attention_mask, max_length = GenerationTesterMixin._get_input_ids_and_config(self)
|
config, input_ids, attention_mask = GenerationTesterMixin._get_input_ids_and_config(self)
|
||||||
|
|
||||||
# `input_ids` is actually `input_features` which is a 3D tensor.
|
# `input_ids` is actually `input_features` which is a 3D tensor.
|
||||||
# We must overwrite the mask to make it 2D since the original `_get_input_ids_and_config` creates an
|
# We must overwrite the mask to make it 2D since the original `_get_input_ids_and_config` creates an
|
||||||
@@ -294,7 +294,7 @@ class Speech2TextModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTest
|
|||||||
sequence_length = input_ids.shape[1]
|
sequence_length = input_ids.shape[1]
|
||||||
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=attention_mask.device)
|
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=attention_mask.device)
|
||||||
|
|
||||||
return config, input_ids, attention_mask, max_length
|
return config, input_ids, attention_mask
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = Speech2TextModelTester(self)
|
self.model_tester = Speech2TextModelTester(self)
|
||||||
|
|||||||
@@ -477,13 +477,11 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
# cut to half length & take max batch_size=batch_size
|
# cut to half length & take max batch_size=batch_size
|
||||||
input_ids = input_ids[:batch_size, :, :]
|
input_ids = input_ids[:batch_size, :, :]
|
||||||
|
|
||||||
# generate max 3 tokens
|
|
||||||
max_length = 4
|
|
||||||
if config.eos_token_id is not None and config.pad_token_id is None:
|
if config.eos_token_id is not None and config.pad_token_id is None:
|
||||||
# hack to allow generate for models such as GPT2 as is done in `generate()`
|
# hack to allow generate for models such as GPT2 as is done in `generate()`
|
||||||
config.pad_token_id = config.eos_token_id
|
config.pad_token_id = config.eos_token_id
|
||||||
|
|
||||||
return config, input_ids, None, max_length
|
return config, input_ids, None
|
||||||
|
|
||||||
def test_inputs_embeds(self):
|
def test_inputs_embeds(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|||||||
@@ -646,7 +646,8 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
seq_len = 1
|
seq_len = 1
|
||||||
else:
|
else:
|
||||||
# for first item dummy PAD token is appended so need one more
|
# for first item dummy PAD token is appended so need one more
|
||||||
seq_len = (min_length + 1) if idx == 0 else min_length
|
# else offset+dummy_token when using cache
|
||||||
|
seq_len = (min_length + 1) if idx == 0 else 3
|
||||||
|
|
||||||
expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
|
expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
|
||||||
self.assertEqual(layer_hidden_states.shape, expected_shape)
|
self.assertEqual(layer_hidden_states.shape, expected_shape)
|
||||||
@@ -665,8 +666,11 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
tgt_len = min_length
|
tgt_len = min_length
|
||||||
|
|
||||||
# for first item dummy PAD token is appended so need one more
|
# for first item dummy PAD token is appended so need one more
|
||||||
|
# every token after consists of offset+dummy_token length when using cache
|
||||||
if idx == 0:
|
if idx == 0:
|
||||||
tgt_len += 1
|
tgt_len += 1
|
||||||
|
else:
|
||||||
|
tgt_len = 3
|
||||||
|
|
||||||
src_len = min_length + idx + 1
|
src_len = min_length + idx + 1
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user