Mamba: add generative tests (#31478)
This commit is contained in:
@@ -102,7 +102,11 @@ class GenerationTesterMixin:
|
||||
if isinstance(config.eos_token_id, int):
|
||||
config.eos_token_id = [config.eos_token_id]
|
||||
config.pad_token_id = config.eos_token_id[0]
|
||||
attention_mask = torch.ones_like(input_ids, dtype=torch.long)
|
||||
|
||||
if self.has_attentions:
|
||||
attention_mask = torch.ones_like(input_ids, dtype=torch.long)
|
||||
else:
|
||||
attention_mask = None
|
||||
|
||||
# It is important set set the eos_token_id to None to ensure that no sequences
|
||||
# shorter than `max_length` can be generated
|
||||
@@ -437,7 +441,7 @@ class GenerationTesterMixin:
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
output_attentions=self.has_attentions,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
@@ -471,7 +475,7 @@ class GenerationTesterMixin:
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
output_attentions=self.has_attentions,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
@@ -529,7 +533,7 @@ class GenerationTesterMixin:
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
output_attentions=self.has_attentions,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
@@ -595,7 +599,7 @@ class GenerationTesterMixin:
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
output_attentions=self.has_attentions,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
if model.config.is_encoder_decoder:
|
||||
@@ -642,7 +646,7 @@ class GenerationTesterMixin:
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
output_attentions=self.has_attentions,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
@@ -733,7 +737,7 @@ class GenerationTesterMixin:
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
output_attentions=self.has_attentions,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
@@ -834,7 +838,7 @@ class GenerationTesterMixin:
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
output_attentions=self.has_attentions,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
if model.config.is_encoder_decoder:
|
||||
@@ -952,7 +956,7 @@ class GenerationTesterMixin:
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
output_attentions=self.has_attentions,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
@@ -973,6 +977,9 @@ class GenerationTesterMixin:
|
||||
|
||||
def test_contrastive_generate(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if model_class._is_stateful:
|
||||
self.skipTest("Stateful models don't support contrastive search generation")
|
||||
|
||||
# won't fix: FSMT and Reformer have a different cache variable type (and format).
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||
self.skipTest("Won't fix: old model with different cache format")
|
||||
@@ -997,6 +1004,9 @@ class GenerationTesterMixin:
|
||||
|
||||
def test_contrastive_generate_dict_outputs_use_cache(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if model_class._is_stateful:
|
||||
self.skipTest("Stateful models don't support contrastive search generation")
|
||||
|
||||
# won't fix: FSMT and Reformer have a different cache variable type (and format).
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||
self.skipTest("Won't fix: old model with different cache format")
|
||||
@@ -1017,7 +1027,7 @@ class GenerationTesterMixin:
|
||||
output_scores=True,
|
||||
output_logits=True,
|
||||
output_hidden_states=True,
|
||||
output_attentions=True,
|
||||
output_attentions=self.has_attentions,
|
||||
return_dict_in_generate=True,
|
||||
)
|
||||
|
||||
@@ -1030,9 +1040,12 @@ class GenerationTesterMixin:
|
||||
def test_contrastive_generate_low_memory(self):
|
||||
# Check that choosing 'low_memory' does not change the model output
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if model_class._is_stateful:
|
||||
self.skipTest("Stateful models don't support contrastive search generation")
|
||||
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer", "speech2text"]):
|
||||
self.skipTest("Won't fix: old model with different cache format")
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["gptbigcode", "jamba"]):
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["gptbigcode"]):
|
||||
self.skipTest("TODO: fix me")
|
||||
|
||||
config, input_ids, attention_mask = self._get_input_ids_and_config(batch_size=1)
|
||||
@@ -1069,6 +1082,8 @@ class GenerationTesterMixin:
|
||||
def test_beam_search_low_memory(self):
|
||||
# Check that choosing 'low_memory' does not change the model output
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if model_class._is_stateful:
|
||||
self.skipTest("May fix in the future: need custom cache handling")
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||
self.skipTest("Won't fix: old model with different cache format")
|
||||
if any(
|
||||
@@ -1115,6 +1130,8 @@ class GenerationTesterMixin:
|
||||
# - assisted_decoding does not support `batch_size > 1`
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if model_class._is_stateful:
|
||||
self.skipTest("Stateful models don't support assisted generation")
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||
self.skipTest("Won't fix: old model with different cache format")
|
||||
if any(
|
||||
@@ -1156,7 +1173,7 @@ class GenerationTesterMixin:
|
||||
"output_scores": True,
|
||||
"output_logits": True,
|
||||
"output_hidden_states": True,
|
||||
"output_attentions": True,
|
||||
"output_attentions": self.has_attentions,
|
||||
"return_dict_in_generate": True,
|
||||
}
|
||||
output_greedy = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||
@@ -1184,6 +1201,8 @@ class GenerationTesterMixin:
|
||||
# This test is mostly a copy of test_assisted_decoding_matches_greedy_search
|
||||
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if model_class._is_stateful:
|
||||
self.skipTest("Stateful models don't support assisted generation")
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||
self.skipTest("Won't fix: old model with different cache format")
|
||||
if any(
|
||||
@@ -1225,7 +1244,7 @@ class GenerationTesterMixin:
|
||||
"output_scores": True,
|
||||
"output_logits": True,
|
||||
"output_hidden_states": True,
|
||||
"output_attentions": True,
|
||||
"output_attentions": self.has_attentions,
|
||||
"return_dict_in_generate": True,
|
||||
}
|
||||
|
||||
@@ -1244,6 +1263,8 @@ class GenerationTesterMixin:
|
||||
# match sample for the same seed, as the forward pass does not return the exact same logits (due to matmul with
|
||||
# different shapes, see https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535).
|
||||
for model_class in self.all_generative_model_classes:
|
||||
if model_class._is_stateful:
|
||||
self.skipTest("Stateful models don't support assisted generation")
|
||||
if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
|
||||
self.skipTest("Won't fix: old model with different cache format")
|
||||
if any(
|
||||
@@ -1289,7 +1310,7 @@ class GenerationTesterMixin:
|
||||
"output_scores": True,
|
||||
"output_logits": True,
|
||||
"output_hidden_states": True,
|
||||
"output_attentions": True,
|
||||
"output_attentions": self.has_attentions,
|
||||
"return_dict_in_generate": True,
|
||||
}
|
||||
output_assisted = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||
@@ -1326,7 +1347,7 @@ class GenerationTesterMixin:
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
num_beams=1,
|
||||
output_attentions=True,
|
||||
output_attentions=self.has_attentions,
|
||||
return_dict_in_generate=True,
|
||||
remove_invalid_values=True,
|
||||
**{name: mask},
|
||||
@@ -1344,6 +1365,10 @@ class GenerationTesterMixin:
|
||||
if len(self.all_generative_model_classes) == 0:
|
||||
self.skipTest(reason="No generative architecture available for this model.")
|
||||
|
||||
# - The model must support padding
|
||||
if not self.has_attentions:
|
||||
self.skipTest(reason="This model doesn't support padding.")
|
||||
|
||||
# - The model must be a decoder-only architecture (encoder-based architectures use right-padding)
|
||||
decoder_only_classes = []
|
||||
for model_class in self.all_generative_model_classes:
|
||||
@@ -1704,30 +1729,31 @@ class GenerationTesterMixin:
|
||||
self._check_logits(num_sequences_in_output, output.logits, config=config)
|
||||
|
||||
# Attentions
|
||||
if config.is_encoder_decoder:
|
||||
# encoder
|
||||
self._check_encoder_attention_for_generate(output.encoder_attentions, batch_size, config, seq_length)
|
||||
# decoder
|
||||
self._check_attentions_for_generate(
|
||||
num_sequences_in_output,
|
||||
output.decoder_attentions,
|
||||
min_length=1,
|
||||
max_length=output.sequences.shape[-1],
|
||||
config=config,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
else:
|
||||
# if use_cache first input is equal to no use_cache, so skip here
|
||||
attentions = output.attentions if not use_cache else output.attentions[1:]
|
||||
min_length = seq_length if not use_cache else seq_length + 1
|
||||
self._check_attentions_for_generate(
|
||||
num_sequences_in_output,
|
||||
attentions=attentions,
|
||||
min_length=min_length,
|
||||
max_length=output.sequences.shape[-1],
|
||||
config=config,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
if self.has_attentions:
|
||||
if config.is_encoder_decoder:
|
||||
# encoder
|
||||
self._check_encoder_attention_for_generate(output.encoder_attentions, batch_size, config, seq_length)
|
||||
# decoder
|
||||
self._check_attentions_for_generate(
|
||||
num_sequences_in_output,
|
||||
output.decoder_attentions,
|
||||
min_length=1,
|
||||
max_length=output.sequences.shape[-1],
|
||||
config=config,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
else:
|
||||
# if use_cache first input is equal to no use_cache, so skip here
|
||||
attentions = output.attentions if not use_cache else output.attentions[1:]
|
||||
min_length = seq_length if not use_cache else seq_length + 1
|
||||
self._check_attentions_for_generate(
|
||||
num_sequences_in_output,
|
||||
attentions=attentions,
|
||||
min_length=min_length,
|
||||
max_length=output.sequences.shape[-1],
|
||||
config=config,
|
||||
use_cache=use_cache,
|
||||
)
|
||||
|
||||
# Hidden States
|
||||
if config.is_encoder_decoder:
|
||||
@@ -1763,7 +1789,7 @@ class GenerationTesterMixin:
|
||||
# 2. Some old models still return `output.past_key_values` even without `use_cache=True`
|
||||
# 3. TODO (joao): A few models have different formats/types, skipping those until the cache refactor is
|
||||
# complete
|
||||
models_without_standard_cache = ("bloom", "ctrl", "fsmt", "gptbigcode", "mega", "reformer", "jamba")
|
||||
models_without_standard_cache = ("bloom", "ctrl", "fsmt", "gptbigcode", "mega", "reformer", "jamba", "mamba")
|
||||
has_standard_cache = not any(
|
||||
model_name in config.__class__.__name__.lower() for model_name in models_without_standard_cache
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user