[generate] shape checks in tests compatible with fixed-length caches (+ some minor fixes) (#35993)
* shape checks compatible with static cache
* add test
* tmp
* manually turn on eager attn when we want to output attn
* typo
* generalize to encoder-decoder models
* force compilation on cpu
* tmp commit
* fix static cache shape checks
* models with odd caches
* fix copies
* shorter cache search loop
* use decoder_past_key_values everywhere
* better test variable names and comments
* signature
* rename _check_outputs into _check_generate_outputs
* add comments
* HybridCache future test note
This commit is contained in:
@@ -116,6 +116,16 @@ if is_accelerate_available():
|
|||||||
from accelerate.hooks import AlignDevicesHook, add_hook_to_module
|
from accelerate.hooks import AlignDevicesHook, add_hook_to_module
|
||||||
|
|
||||||
|
|
||||||
|
# Variable names used to hold the cache at generation time
|
||||||
|
ALL_CACHE_NAMES = [
|
||||||
|
"past_key_values", # default
|
||||||
|
"cache_params", # mamba-based models
|
||||||
|
"state", # rwkv
|
||||||
|
"mems", # xlnet
|
||||||
|
"past_buckets_states", # reformer
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GenerateDecoderOnlyOutput(ModelOutput):
|
class GenerateDecoderOnlyOutput(ModelOutput):
|
||||||
"""
|
"""
|
||||||
@@ -756,21 +766,6 @@ class GenerationMixin:
|
|||||||
|
|
||||||
return input_ids, model_kwargs
|
return input_ids, model_kwargs
|
||||||
|
|
||||||
def _extract_past_from_model_output(self, outputs: ModelOutput):
|
|
||||||
past_key_values = None
|
|
||||||
cache_name = "past_key_values"
|
|
||||||
if "past_key_values" in outputs:
|
|
||||||
past_key_values = outputs.past_key_values
|
|
||||||
elif "mems" in outputs:
|
|
||||||
past_key_values = outputs.mems
|
|
||||||
elif "past_buckets_states" in outputs:
|
|
||||||
past_key_values = outputs.past_buckets_states
|
|
||||||
elif "cache_params" in outputs:
|
|
||||||
past_key_values = outputs.cache_params
|
|
||||||
cache_name = "cache_params"
|
|
||||||
|
|
||||||
return cache_name, past_key_values
|
|
||||||
|
|
||||||
def _update_model_kwargs_for_generation(
|
def _update_model_kwargs_for_generation(
|
||||||
self,
|
self,
|
||||||
outputs: ModelOutput,
|
outputs: ModelOutput,
|
||||||
@@ -779,10 +774,15 @@ class GenerationMixin:
|
|||||||
num_new_tokens: int = 1,
|
num_new_tokens: int = 1,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
# update past_key_values keeping its naming used in model code
|
# update past_key_values keeping its naming used in model code
|
||||||
cache_name, cache = self._extract_past_from_model_output(outputs)
|
for possible_cache_name in ALL_CACHE_NAMES:
|
||||||
model_kwargs[cache_name] = cache
|
if possible_cache_name in outputs:
|
||||||
if getattr(outputs, "state", None) is not None:
|
# TODO (joao): remove output/input mismatch when these old models (xlnet, reformer) are deprecated
|
||||||
model_kwargs["state"] = outputs.state
|
if possible_cache_name in ("past_buckets_states", "mems"):
|
||||||
|
cache_name = "past_key_values"
|
||||||
|
else:
|
||||||
|
cache_name = possible_cache_name
|
||||||
|
model_kwargs[cache_name] = getattr(outputs, possible_cache_name)
|
||||||
|
break
|
||||||
|
|
||||||
# update token_type_ids with last value
|
# update token_type_ids with last value
|
||||||
if "token_type_ids" in model_kwargs:
|
if "token_type_ids" in model_kwargs:
|
||||||
@@ -2087,7 +2087,7 @@ class GenerationMixin:
|
|||||||
# - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`.
|
# - `model_kwargs` may be updated in place with a cache as defined by the parameters in `generation_config`.
|
||||||
# - different models have a different cache name expected by the model (default = "past_key_values")
|
# - different models have a different cache name expected by the model (default = "past_key_values")
|
||||||
# - `max_length`, prepared above, is used to determine the maximum cache length
|
# - `max_length`, prepared above, is used to determine the maximum cache length
|
||||||
max_cache_length = generation_config.max_length
|
max_cache_length = generation_config.max_length - 1
|
||||||
if (
|
if (
|
||||||
inputs_tensor.shape[1] != input_ids_length
|
inputs_tensor.shape[1] != input_ids_length
|
||||||
and model_input_name == "inputs_embeds"
|
and model_input_name == "inputs_embeds"
|
||||||
@@ -2994,7 +2994,9 @@ class GenerationMixin:
|
|||||||
next_past_key_values = selected_outputs["past_key_values"]
|
next_past_key_values = selected_outputs["past_key_values"]
|
||||||
|
|
||||||
else:
|
else:
|
||||||
_, next_past_key_values = self._extract_past_from_model_output(outputs)
|
next_past_key_values = None
|
||||||
|
for possible_cache_name in ALL_CACHE_NAMES:
|
||||||
|
next_past_key_values = next_past_key_values or getattr(outputs, possible_cache_name, None)
|
||||||
# Do it in-place layer per layer to save memory
|
# Do it in-place layer per layer to save memory
|
||||||
if isinstance(next_past_key_values, DynamicCache) or (
|
if isinstance(next_past_key_values, DynamicCache) or (
|
||||||
isinstance(next_past_key_values, EncoderDecoderCache)
|
isinstance(next_past_key_values, EncoderDecoderCache)
|
||||||
|
|||||||
@@ -2557,32 +2557,3 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel, GenerationMixin):
|
|||||||
return outputs
|
return outputs
|
||||||
else:
|
else:
|
||||||
return output_values
|
return output_values
|
||||||
|
|
||||||
def _update_model_kwargs_for_generation(
|
|
||||||
self,
|
|
||||||
outputs: ModelOutput,
|
|
||||||
model_kwargs: Dict[str, Any],
|
|
||||||
is_encoder_decoder: bool = False,
|
|
||||||
model_inputs: Optional[Dict[str, Any]] = None,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
# update past_key_values
|
|
||||||
cache_name, cache = self._extract_past_from_model_output(outputs)
|
|
||||||
model_kwargs[cache_name] = cache
|
|
||||||
|
|
||||||
if getattr(outputs, "state", None) is not None:
|
|
||||||
model_kwargs["state"] = outputs.state
|
|
||||||
|
|
||||||
# update token_type_ids with last value
|
|
||||||
if "token_type_ids" in model_kwargs:
|
|
||||||
token_type_ids = model_kwargs["token_type_ids"]
|
|
||||||
model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
|
|
||||||
|
|
||||||
# update decoder attention mask
|
|
||||||
if "decoder_attention_mask" in model_kwargs:
|
|
||||||
decoder_attention_mask = model_kwargs["decoder_attention_mask"]
|
|
||||||
model_kwargs["decoder_attention_mask"] = torch.cat(
|
|
||||||
[decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
|
|
||||||
dim=-1,
|
|
||||||
)
|
|
||||||
|
|
||||||
return model_kwargs
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@
|
|||||||
|
|
||||||
import math
|
import math
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
@@ -1329,54 +1329,6 @@ class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMi
|
|||||||
)
|
)
|
||||||
return model_inputs
|
return model_inputs
|
||||||
|
|
||||||
def _update_model_kwargs_for_generation(
|
|
||||||
self,
|
|
||||||
outputs: ModelOutput,
|
|
||||||
model_kwargs: Dict[str, Any],
|
|
||||||
is_encoder_decoder: bool = False,
|
|
||||||
num_new_tokens: int = 1,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
# update past_key_values keeping its naming used in model code
|
|
||||||
cache_name, cache = self._extract_past_from_model_output(outputs)
|
|
||||||
model_kwargs[cache_name] = cache
|
|
||||||
if getattr(outputs, "state", None) is not None:
|
|
||||||
model_kwargs["state"] = outputs.state
|
|
||||||
|
|
||||||
# update attention_mask
|
|
||||||
if getattr(outputs, "attention_mask", None) is not None:
|
|
||||||
model_kwargs["attention_mask"] = outputs.attention_mask
|
|
||||||
|
|
||||||
# update token_type_ids with last value
|
|
||||||
if "token_type_ids" in model_kwargs:
|
|
||||||
token_type_ids = model_kwargs["token_type_ids"]
|
|
||||||
model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
|
|
||||||
|
|
||||||
if not is_encoder_decoder:
|
|
||||||
# update attention mask
|
|
||||||
if "attention_mask" in model_kwargs:
|
|
||||||
attention_mask = model_kwargs["attention_mask"]
|
|
||||||
model_kwargs["attention_mask"] = torch.cat(
|
|
||||||
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# update decoder attention mask
|
|
||||||
if "decoder_attention_mask" in model_kwargs:
|
|
||||||
decoder_attention_mask = model_kwargs["decoder_attention_mask"]
|
|
||||||
model_kwargs["decoder_attention_mask"] = torch.cat(
|
|
||||||
[decoder_attention_mask, decoder_attention_mask.new_ones((decoder_attention_mask.shape[0], 1))],
|
|
||||||
dim=-1,
|
|
||||||
)
|
|
||||||
|
|
||||||
if model_kwargs.get("use_cache", True):
|
|
||||||
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
|
|
||||||
else:
|
|
||||||
past_positions = model_kwargs.pop("cache_position")
|
|
||||||
new_positions = torch.arange(
|
|
||||||
past_positions[-1] + 1, past_positions[-1] + num_new_tokens + 1, dtype=past_positions.dtype
|
|
||||||
).to(past_positions.device)
|
|
||||||
model_kwargs["cache_position"] = torch.cat((past_positions, new_positions))
|
|
||||||
return model_kwargs
|
|
||||||
|
|
||||||
def _reorder_cache(self, *args, **kwargs):
|
def _reorder_cache(self, *args, **kwargs):
|
||||||
return self.language_model._reorder_cache(*args, **kwargs)
|
return self.language_model._reorder_cache(*args, **kwargs)
|
||||||
|
|
||||||
|
|||||||
@@ -72,7 +72,14 @@ if is_torch_available():
|
|||||||
SpeechEncoderDecoderModel,
|
SpeechEncoderDecoderModel,
|
||||||
T5ForConditionalGeneration,
|
T5ForConditionalGeneration,
|
||||||
)
|
)
|
||||||
from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache
|
from transformers.cache_utils import (
|
||||||
|
Cache,
|
||||||
|
DynamicCache,
|
||||||
|
EncoderDecoderCache,
|
||||||
|
HybridCache,
|
||||||
|
QuantoQuantizedCache,
|
||||||
|
StaticCache,
|
||||||
|
)
|
||||||
from transformers.generation import (
|
from transformers.generation import (
|
||||||
BeamSampleDecoderOnlyOutput,
|
BeamSampleDecoderOnlyOutput,
|
||||||
BeamSampleEncoderDecoderOutput,
|
BeamSampleEncoderDecoderOutput,
|
||||||
@@ -473,6 +480,8 @@ class GenerationTesterMixin:
|
|||||||
def test_greedy_generate_dict_outputs(self):
|
def test_greedy_generate_dict_outputs(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||||
|
if self.has_attentions:
|
||||||
|
config._attn_implementation = "eager" # can't output attentions otherwise
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
output_generate = self._greedy_generate(
|
output_generate = self._greedy_generate(
|
||||||
@@ -499,12 +508,14 @@ class GenerationTesterMixin:
|
|||||||
# Retrocompatibility check
|
# Retrocompatibility check
|
||||||
self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
|
self.assertIsInstance(output_generate, GreedySearchDecoderOnlyOutput)
|
||||||
|
|
||||||
self._check_outputs(output_generate, model.config)
|
self._check_generate_outputs(output_generate, model.config)
|
||||||
|
|
||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
def test_greedy_generate_dict_outputs_use_cache(self):
|
def test_greedy_generate_dict_outputs_use_cache(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||||
|
if self.has_attentions:
|
||||||
|
config._attn_implementation = "eager" # can't output attentions otherwise
|
||||||
|
|
||||||
if not hasattr(config, "use_cache"):
|
if not hasattr(config, "use_cache"):
|
||||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||||
@@ -531,7 +542,7 @@ class GenerationTesterMixin:
|
|||||||
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
|
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
|
||||||
)
|
)
|
||||||
|
|
||||||
self._check_outputs(output_generate, model.config, use_cache=True)
|
self._check_generate_outputs(output_generate, model.config, use_cache=True)
|
||||||
|
|
||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
def test_sample_generate(self):
|
def test_sample_generate(self):
|
||||||
@@ -550,6 +561,8 @@ class GenerationTesterMixin:
|
|||||||
def test_sample_generate_dict_output(self):
|
def test_sample_generate_dict_output(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||||
|
if self.has_attentions:
|
||||||
|
config._attn_implementation = "eager" # can't output attentions otherwise
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
output_generate = self._sample_generate(
|
output_generate = self._sample_generate(
|
||||||
@@ -577,7 +590,7 @@ class GenerationTesterMixin:
|
|||||||
# Retrocompatibility check
|
# Retrocompatibility check
|
||||||
self.assertIsInstance(output_generate, SampleDecoderOnlyOutput)
|
self.assertIsInstance(output_generate, SampleDecoderOnlyOutput)
|
||||||
|
|
||||||
self._check_outputs(output_generate, model.config, num_return_sequences=2)
|
self._check_generate_outputs(output_generate, model.config, num_return_sequences=2)
|
||||||
|
|
||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
def test_beam_search_generate(self):
|
def test_beam_search_generate(self):
|
||||||
@@ -598,6 +611,8 @@ class GenerationTesterMixin:
|
|||||||
def test_beam_search_generate_dict_output(self):
|
def test_beam_search_generate_dict_output(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||||
|
if self.has_attentions:
|
||||||
|
config._attn_implementation = "eager" # can't output attentions otherwise
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
beam_kwargs = self._get_beam_kwargs()
|
beam_kwargs = self._get_beam_kwargs()
|
||||||
@@ -625,7 +640,7 @@ class GenerationTesterMixin:
|
|||||||
# Retrocompatibility check
|
# Retrocompatibility check
|
||||||
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
|
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
|
||||||
|
|
||||||
self._check_outputs(
|
self._check_generate_outputs(
|
||||||
output_generate,
|
output_generate,
|
||||||
model.config,
|
model.config,
|
||||||
num_return_sequences=beam_kwargs["num_return_sequences"],
|
num_return_sequences=beam_kwargs["num_return_sequences"],
|
||||||
@@ -642,6 +657,8 @@ class GenerationTesterMixin:
|
|||||||
if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
|
if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
|
||||||
self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes")
|
self.skipTest(reason="Won't fix: model with non-standard dictionary output shapes")
|
||||||
|
|
||||||
|
if self.has_attentions:
|
||||||
|
config._attn_implementation = "eager" # can't output attentions otherwise
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
beam_kwargs = self._get_beam_kwargs()
|
beam_kwargs = self._get_beam_kwargs()
|
||||||
|
|
||||||
@@ -666,7 +683,7 @@ class GenerationTesterMixin:
|
|||||||
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
|
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
|
||||||
)
|
)
|
||||||
|
|
||||||
self._check_outputs(
|
self._check_generate_outputs(
|
||||||
output_generate,
|
output_generate,
|
||||||
model.config,
|
model.config,
|
||||||
use_cache=True,
|
use_cache=True,
|
||||||
@@ -721,6 +738,8 @@ class GenerationTesterMixin:
|
|||||||
def test_beam_sample_generate_dict_output(self):
|
def test_beam_sample_generate_dict_output(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||||
|
if self.has_attentions:
|
||||||
|
config._attn_implementation = "eager" # can't output attentions otherwise
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
beam_kwargs = self._get_beam_kwargs()
|
beam_kwargs = self._get_beam_kwargs()
|
||||||
@@ -750,7 +769,7 @@ class GenerationTesterMixin:
|
|||||||
# Retrocompatibility check
|
# Retrocompatibility check
|
||||||
self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput)
|
self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput)
|
||||||
|
|
||||||
self._check_outputs(
|
self._check_generate_outputs(
|
||||||
output_generate,
|
output_generate,
|
||||||
model.config,
|
model.config,
|
||||||
num_return_sequences=beam_kwargs["num_return_sequences"],
|
num_return_sequences=beam_kwargs["num_return_sequences"],
|
||||||
@@ -813,6 +832,8 @@ class GenerationTesterMixin:
|
|||||||
def test_group_beam_search_generate_dict_output(self):
|
def test_group_beam_search_generate_dict_output(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||||
|
if self.has_attentions:
|
||||||
|
config._attn_implementation = "eager" # can't output attentions otherwise
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
beam_kwargs = self._get_diverse_beam_kwargs()
|
beam_kwargs = self._get_diverse_beam_kwargs()
|
||||||
@@ -840,7 +861,7 @@ class GenerationTesterMixin:
|
|||||||
# Retrocompatibility check
|
# Retrocompatibility check
|
||||||
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
|
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
|
||||||
|
|
||||||
self._check_outputs(
|
self._check_generate_outputs(
|
||||||
output_generate,
|
output_generate,
|
||||||
model.config,
|
model.config,
|
||||||
num_return_sequences=beam_kwargs["num_return_sequences"],
|
num_return_sequences=beam_kwargs["num_return_sequences"],
|
||||||
@@ -909,6 +930,8 @@ class GenerationTesterMixin:
|
|||||||
def test_constrained_beam_search_generate_dict_output(self):
|
def test_constrained_beam_search_generate_dict_output(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||||
|
if self.has_attentions:
|
||||||
|
config._attn_implementation = "eager" # can't output attentions otherwise
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
|
|
||||||
@@ -947,7 +970,7 @@ class GenerationTesterMixin:
|
|||||||
# Retrocompatibility check
|
# Retrocompatibility check
|
||||||
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
|
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput)
|
||||||
|
|
||||||
self._check_outputs(
|
self._check_generate_outputs(
|
||||||
output_generate,
|
output_generate,
|
||||||
model.config,
|
model.config,
|
||||||
num_return_sequences=beam_kwargs["num_return_sequences"],
|
num_return_sequences=beam_kwargs["num_return_sequences"],
|
||||||
@@ -999,6 +1022,8 @@ class GenerationTesterMixin:
|
|||||||
if not hasattr(config, "use_cache"):
|
if not hasattr(config, "use_cache"):
|
||||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||||
config.is_decoder = True
|
config.is_decoder = True
|
||||||
|
if self.has_attentions:
|
||||||
|
config._attn_implementation = "eager" # can't output attentions otherwise
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
output_generate = self._contrastive_generate(
|
output_generate = self._contrastive_generate(
|
||||||
@@ -1019,7 +1044,7 @@ class GenerationTesterMixin:
|
|||||||
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
|
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
|
||||||
)
|
)
|
||||||
|
|
||||||
self._check_outputs(output_generate, model.config, use_cache=True)
|
self._check_generate_outputs(output_generate, model.config, use_cache=True)
|
||||||
|
|
||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
def test_contrastive_generate_low_memory(self):
|
def test_contrastive_generate_low_memory(self):
|
||||||
@@ -1205,7 +1230,7 @@ class GenerationTesterMixin:
|
|||||||
# The two outputs must match and their shape must be as expected
|
# The two outputs must match and their shape must be as expected
|
||||||
self._check_similar_generate_outputs(output_greedy, output_assisted)
|
self._check_similar_generate_outputs(output_greedy, output_assisted)
|
||||||
for output in (output_greedy, output_assisted):
|
for output in (output_greedy, output_assisted):
|
||||||
self._check_outputs(output, model.config, use_cache=True)
|
self._check_generate_outputs(output, model.config, use_cache=True)
|
||||||
|
|
||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
def test_prompt_lookup_decoding_matches_greedy_search(self):
|
def test_prompt_lookup_decoding_matches_greedy_search(self):
|
||||||
@@ -1270,7 +1295,7 @@ class GenerationTesterMixin:
|
|||||||
# The two outputs must match and their shape must be as expected
|
# The two outputs must match and their shape must be as expected
|
||||||
self._check_similar_generate_outputs(output_greedy, output_prompt_lookup)
|
self._check_similar_generate_outputs(output_greedy, output_prompt_lookup)
|
||||||
for output in (output_greedy, output_prompt_lookup):
|
for output in (output_greedy, output_prompt_lookup):
|
||||||
self._check_outputs(output, model.config, use_cache=True)
|
self._check_generate_outputs(output, model.config, use_cache=True)
|
||||||
|
|
||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
def test_dola_decoding_sample(self):
|
def test_dola_decoding_sample(self):
|
||||||
@@ -1320,7 +1345,7 @@ class GenerationTesterMixin:
|
|||||||
"dola_layers": "low",
|
"dola_layers": "low",
|
||||||
}
|
}
|
||||||
output_dola = model.generate(**generation_kwargs, **logits_processor_kwargs, **inputs_dict)
|
output_dola = model.generate(**generation_kwargs, **logits_processor_kwargs, **inputs_dict)
|
||||||
self._check_outputs(output_dola, model.config, use_cache=getattr(config, "use_cache", False))
|
self._check_generate_outputs(output_dola, model.config, use_cache=getattr(config, "use_cache", False))
|
||||||
|
|
||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
def test_assisted_decoding_sample(self):
|
def test_assisted_decoding_sample(self):
|
||||||
@@ -1381,7 +1406,7 @@ class GenerationTesterMixin:
|
|||||||
}
|
}
|
||||||
output_assisted = model.generate(**generation_kwargs, **inputs_dict)
|
output_assisted = model.generate(**generation_kwargs, **inputs_dict)
|
||||||
|
|
||||||
self._check_outputs(output_assisted, config, use_cache=True)
|
self._check_generate_outputs(output_assisted, config, use_cache=True)
|
||||||
|
|
||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
def test_prompt_lookup_decoding_stops_at_eos(self):
|
def test_prompt_lookup_decoding_stops_at_eos(self):
|
||||||
@@ -1419,6 +1444,8 @@ class GenerationTesterMixin:
|
|||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||||
text_config = config.get_text_config()
|
text_config = config.get_text_config()
|
||||||
|
if self.has_attentions:
|
||||||
|
config._attn_implementation = "eager" # can't output attentions otherwise
|
||||||
|
|
||||||
# We want to test only encoder-decoder models
|
# We want to test only encoder-decoder models
|
||||||
if not text_config.is_encoder_decoder:
|
if not text_config.is_encoder_decoder:
|
||||||
@@ -1765,7 +1792,7 @@ class GenerationTesterMixin:
|
|||||||
num_hidden_layers = text_config.num_hidden_layers
|
num_hidden_layers = text_config.num_hidden_layers
|
||||||
|
|
||||||
inputs_embeds = model.get_input_embeddings()(input_ids)
|
inputs_embeds = model.get_input_embeddings()(input_ids)
|
||||||
max_cache_len += inputs_embeds.shape[1]
|
max_cache_len += inputs_embeds.shape[1] - 1 # the last generated token has no cache
|
||||||
outputs = model.generate(inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict)
|
outputs = model.generate(inputs_embeds=inputs_embeds, **generation_kwargs, **inputs_dict)
|
||||||
|
|
||||||
# we should get `max_length` in shape, not `max_length - embeds_length`
|
# we should get `max_length` in shape, not `max_length - embeds_length`
|
||||||
@@ -2003,7 +2030,7 @@ class GenerationTesterMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Check 1: The cache shapes must match the expected shapes
|
# Check 1: The cache shapes must match the expected shapes
|
||||||
max_cache_len = seq_length + max_new_tokens
|
max_cache_len = seq_length + max_new_tokens - 1 # cache len = gen len - 1, the last token has no cache
|
||||||
text_config = config.text_config if hasattr(config, "text_config") else config
|
text_config = config.text_config if hasattr(config, "text_config") else config
|
||||||
head_dim = (
|
head_dim = (
|
||||||
text_config.head_dim
|
text_config.head_dim
|
||||||
@@ -2138,6 +2165,58 @@ class GenerationTesterMixin:
|
|||||||
for dynamic_result, compiled_result in zip(dynamic_outputs, compiled_outputs):
|
for dynamic_result, compiled_result in zip(dynamic_outputs, compiled_outputs):
|
||||||
self._check_similar_generate_outputs(dynamic_result, compiled_result)
|
self._check_similar_generate_outputs(dynamic_result, compiled_result)
|
||||||
|
|
||||||
|
@pytest.mark.generate
|
||||||
|
def test_generate_compilation_all_outputs(self):
|
||||||
|
"""
|
||||||
|
Tests that all optional outputs are behaving as expected when compilation is triggered.
|
||||||
|
In essence, it's the same as `test_greedy_generate_dict_outputs`, but with automatic compilation triggered.
|
||||||
|
"""
|
||||||
|
for model_class in self.all_generative_model_classes:
|
||||||
|
if not model_class._supports_static_cache:
|
||||||
|
self.skipTest("This model doesn't support static cache (= no expectations of compilation support)")
|
||||||
|
|
||||||
|
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||||
|
if self.has_attentions:
|
||||||
|
config._attn_implementation = "eager" # can't output attentions otherwise
|
||||||
|
model = model_class(config).to(torch_device).eval()
|
||||||
|
|
||||||
|
# compilation-specific setup
|
||||||
|
torch.compiler.reset() # prevent cached compilation from being used in the test
|
||||||
|
has_defined_cache_implementation = model.generation_config.cache_implementation is not None
|
||||||
|
model.generation_config.compile_config._compile_all_devices = True # force compilation (e.g. fast CI, CPU)
|
||||||
|
if not has_defined_cache_implementation:
|
||||||
|
model.generation_config.cache_implementation = "static"
|
||||||
|
|
||||||
|
logits_processor_kwargs = self._get_logits_processor_kwargs(do_sample=False, config=model.config)
|
||||||
|
output_generate = model.generate(
|
||||||
|
do_sample=False,
|
||||||
|
num_beams=1,
|
||||||
|
max_new_tokens=self.max_new_tokens,
|
||||||
|
min_new_tokens=self.max_new_tokens,
|
||||||
|
output_attentions=True,
|
||||||
|
output_hidden_states=True,
|
||||||
|
output_scores=True,
|
||||||
|
output_logits=True,
|
||||||
|
return_dict_in_generate=True,
|
||||||
|
use_cache=True,
|
||||||
|
**logits_processor_kwargs,
|
||||||
|
**inputs_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Sanity check: compilation has happened
|
||||||
|
self.assertTrue(hasattr(model, "_compiled_call"))
|
||||||
|
|
||||||
|
if model.config.is_encoder_decoder:
|
||||||
|
self.assertTrue(output_generate.sequences.shape[-1] == self.max_new_tokens + 1)
|
||||||
|
self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
|
||||||
|
else:
|
||||||
|
self.assertTrue(
|
||||||
|
output_generate.sequences.shape[-1] == self.max_new_tokens + inputs_dict["input_ids"].shape[-1]
|
||||||
|
)
|
||||||
|
self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
|
||||||
|
|
||||||
|
self._check_generate_outputs(output_generate, model.config, use_cache=True)
|
||||||
|
|
||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
def test_generate_methods_with_logits_to_keep(self):
|
def test_generate_methods_with_logits_to_keep(self):
|
||||||
for model_class in self.all_generative_model_classes:
|
for model_class in self.all_generative_model_classes:
|
||||||
@@ -2290,86 +2369,90 @@ class GenerationTesterMixin:
|
|||||||
# check whether we still need the overwrites
|
# check whether we still need the overwrites
|
||||||
self._test_attention_implementation("flash_attention_2")
|
self._test_attention_implementation("flash_attention_2")
|
||||||
|
|
||||||
def _check_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1):
|
def _check_generate_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1):
|
||||||
input_batch_size = int(output.sequences.shape[0] / num_return_sequences)
|
input_batch_size = int(output.sequences.shape[0] / num_return_sequences)
|
||||||
internal_batch_size = (
|
internal_batch_size = (
|
||||||
input_batch_size * num_beams if num_beams > 1 else input_batch_size * num_return_sequences
|
input_batch_size * num_beams if num_beams > 1 else input_batch_size * num_return_sequences
|
||||||
)
|
)
|
||||||
|
|
||||||
seq_length = getattr(self.model_tester, "seq_length", None)
|
prompt_length = getattr(self.model_tester, "seq_length", None)
|
||||||
seq_length = getattr(self.model_tester, "encoder_seq_length", seq_length)
|
prompt_length = getattr(self.model_tester, "encoder_seq_length", prompt_length)
|
||||||
seq_length = getattr(self.model_tester, "text_seq_length", seq_length)
|
prompt_length = getattr(self.model_tester, "text_seq_length", prompt_length)
|
||||||
|
|
||||||
config = config.text_config if hasattr(config, "text_config") else config
|
config = config.text_config if hasattr(config, "text_config") else config
|
||||||
|
|
||||||
gen_len = (
|
generated_length = (
|
||||||
output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length
|
output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - prompt_length
|
||||||
)
|
)
|
||||||
|
decoder_past_key_values = getattr(output, "past_key_values", None)
|
||||||
|
if config.is_encoder_decoder and isinstance(decoder_past_key_values, EncoderDecoderCache):
|
||||||
|
decoder_past_key_values = decoder_past_key_values.self_attention_cache
|
||||||
|
|
||||||
# in some models we subsample the sequence length in inner layers
|
# in some models we subsample the sequence length in inner layers
|
||||||
if hasattr(self.model_tester, "get_subsampled_output_lengths"):
|
if hasattr(self.model_tester, "get_subsampled_output_lengths"):
|
||||||
seq_length = self.model_tester.get_subsampled_output_lengths(seq_length)
|
prompt_length = self.model_tester.get_subsampled_output_lengths(prompt_length)
|
||||||
|
|
||||||
# scores
|
# scores
|
||||||
self._check_scores(internal_batch_size, output.scores, length=gen_len, config=config)
|
self._check_scores(
|
||||||
|
batch_size=internal_batch_size, scores=output.scores, generated_length=generated_length, config=config
|
||||||
|
)
|
||||||
|
|
||||||
# unprocessed logits
|
# unprocessed logits
|
||||||
self._check_logits(internal_batch_size, output.logits, config=config)
|
self._check_logits(batch_size=internal_batch_size, logits=output.logits, config=config)
|
||||||
|
|
||||||
# Attentions
|
# Attentions
|
||||||
if self.has_attentions:
|
if self.has_attentions:
|
||||||
if config.is_encoder_decoder:
|
if config.is_encoder_decoder:
|
||||||
# encoder
|
# encoder
|
||||||
self._check_encoder_attention_for_generate(
|
self._check_encoder_attention_for_generate(
|
||||||
output.encoder_attentions, input_batch_size, config, seq_length
|
attentions=output.encoder_attentions,
|
||||||
|
batch_size=input_batch_size,
|
||||||
|
config=config,
|
||||||
|
prompt_length=prompt_length,
|
||||||
)
|
)
|
||||||
# decoder
|
# decoder
|
||||||
self._check_attentions_for_generate(
|
self._check_attentions_for_generate(
|
||||||
internal_batch_size,
|
batch_size=internal_batch_size,
|
||||||
output.decoder_attentions,
|
attentions=output.decoder_attentions,
|
||||||
min_length=1,
|
prompt_length=1, # the BOS token
|
||||||
max_length=output.sequences.shape[-1],
|
output_length=output.sequences.shape[-1],
|
||||||
config=config,
|
config=config,
|
||||||
use_cache=use_cache,
|
decoder_past_key_values=decoder_past_key_values,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# if use_cache first input is equal to no use_cache, so skip here
|
|
||||||
attentions = output.attentions if not use_cache else output.attentions[1:]
|
|
||||||
min_length = seq_length if not use_cache else seq_length + 1
|
|
||||||
self._check_attentions_for_generate(
|
self._check_attentions_for_generate(
|
||||||
internal_batch_size,
|
batch_size=internal_batch_size,
|
||||||
attentions=attentions,
|
attentions=output.attentions,
|
||||||
min_length=min_length,
|
prompt_length=prompt_length,
|
||||||
max_length=output.sequences.shape[-1],
|
output_length=output.sequences.shape[-1],
|
||||||
config=config,
|
config=config,
|
||||||
use_cache=use_cache,
|
decoder_past_key_values=decoder_past_key_values,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Hidden States
|
# Hidden States
|
||||||
if config.is_encoder_decoder:
|
if config.is_encoder_decoder:
|
||||||
# encoder
|
# encoder
|
||||||
self._check_encoder_hidden_states_for_generate(
|
self._check_encoder_hidden_states_for_generate(
|
||||||
output.encoder_hidden_states, input_batch_size, config, seq_length
|
hidden_states=output.encoder_hidden_states,
|
||||||
|
batch_size=input_batch_size,
|
||||||
|
config=config,
|
||||||
|
prompt_length=prompt_length,
|
||||||
)
|
)
|
||||||
|
|
||||||
# decoder
|
# decoder
|
||||||
self._check_hidden_states_for_generate(
|
self._check_hidden_states_for_generate(
|
||||||
internal_batch_size,
|
batch_size=internal_batch_size,
|
||||||
output.decoder_hidden_states,
|
hidden_states=output.decoder_hidden_states,
|
||||||
min_length=1,
|
prompt_length=1, # the BOS token
|
||||||
max_length=output.sequences.shape[-1],
|
output_length=output.sequences.shape[-1],
|
||||||
config=config,
|
config=config,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# if use_cache first input is equal to no use_cache, so skip here
|
|
||||||
hidden_states = output.hidden_states if not use_cache else output.hidden_states[1:]
|
|
||||||
min_length = seq_length if not use_cache else seq_length + 1
|
|
||||||
self._check_hidden_states_for_generate(
|
self._check_hidden_states_for_generate(
|
||||||
internal_batch_size,
|
batch_size=internal_batch_size,
|
||||||
hidden_states,
|
hidden_states=output.hidden_states,
|
||||||
min_length=min_length,
|
prompt_length=prompt_length,
|
||||||
max_length=output.sequences.shape[-1],
|
output_length=output.sequences.shape[-1],
|
||||||
config=config,
|
config=config,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
)
|
)
|
||||||
@@ -2396,59 +2479,73 @@ class GenerationTesterMixin:
|
|||||||
)
|
)
|
||||||
if has_standard_cache:
|
if has_standard_cache:
|
||||||
if use_cache:
|
if use_cache:
|
||||||
past_key_values = output.past_key_values
|
cache_length = output.sequences.shape[-1] - 1
|
||||||
past_sequence_length = output.sequences.shape[-1] - 1
|
|
||||||
self._check_past_key_values_for_generate(
|
self._check_past_key_values_for_generate(
|
||||||
internal_batch_size,
|
batch_size=internal_batch_size,
|
||||||
past_key_values,
|
decoder_past_key_values=decoder_past_key_values,
|
||||||
seq_length=past_sequence_length,
|
cache_length=cache_length,
|
||||||
config=config,
|
config=config,
|
||||||
)
|
)
|
||||||
elif use_cache is False:
|
elif use_cache is False:
|
||||||
self.assertTrue(output.past_key_values is None)
|
self.assertTrue(decoder_past_key_values is None)
|
||||||
|
|
||||||
def _check_scores(self, batch_size, scores, length, config):
|
def _check_scores(self, batch_size, scores, generated_length, config):
|
||||||
vocab_size = config.get_text_config(decoder=True).vocab_size
|
vocab_size = config.get_text_config(decoder=True).vocab_size
|
||||||
expected_shape = (batch_size, vocab_size)
|
expected_shape = (batch_size, vocab_size)
|
||||||
self.assertIsInstance(scores, tuple)
|
self.assertIsInstance(scores, tuple)
|
||||||
self.assertEqual(len(scores), length)
|
self.assertEqual(len(scores), generated_length)
|
||||||
self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores))
|
self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores))
|
||||||
|
|
||||||
def _check_logits(self, batch_size, scores, config):
|
def _check_logits(self, batch_size, logits, config):
|
||||||
vocab_size = config.get_text_config(decoder=True).vocab_size
|
vocab_size = config.get_text_config(decoder=True).vocab_size
|
||||||
self.assertIsInstance(scores, tuple)
|
self.assertIsInstance(logits, tuple)
|
||||||
self.assertListEqual([iter_scores.shape[0] for iter_scores in scores], [batch_size] * len(scores))
|
self.assertListEqual([iter_logits.shape[0] for iter_logits in logits], [batch_size] * len(logits))
|
||||||
# vocabulary difference equal to one (imagegptmodel?) or zero (all other models)
|
# vocabulary difference equal to one (imagegptmodel?) or zero (all other models)
|
||||||
vocab_diff = vocab_size - scores[0].shape[-1]
|
vocab_diff = vocab_size - logits[0].shape[-1]
|
||||||
self.assertTrue(vocab_diff in [0, 1])
|
self.assertTrue(vocab_diff in [0, 1])
|
||||||
self.assertListEqual([vocab_size - score.shape[-1] for score in scores], [vocab_diff] * len(scores))
|
self.assertListEqual([vocab_size - score.shape[-1] for score in logits], [vocab_diff] * len(logits))
|
||||||
|
|
||||||
def _check_attentions_for_generate(
|
def _check_attentions_for_generate(
|
||||||
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
|
||||||
):
|
):
|
||||||
self.assertIsInstance(attentions, tuple)
|
self.assertIsInstance(attentions, tuple)
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
|
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
|
||||||
)
|
)
|
||||||
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
|
self.assertEqual(len(attentions), (output_length - prompt_length))
|
||||||
|
|
||||||
for idx, iter_attentions in enumerate(attentions):
|
use_cache = decoder_past_key_values is not None
|
||||||
tgt_len = min_length + idx if not use_cache else 1
|
has_static_cache = isinstance(decoder_past_key_values, (StaticCache, HybridCache))
|
||||||
src_len = min_length + idx
|
|
||||||
|
# When `output_attentions=True`, each iteration of generate appends the attentions corresponding to the new
|
||||||
|
# token(s)
|
||||||
|
# NOTE: `HybridCache` may have different lengths on different layers, if this test starts failing add more
|
||||||
|
# elaborate checks
|
||||||
|
for generated_length, iter_attentions in enumerate(attentions):
|
||||||
|
# regardless of using cache, the first forward pass will have the full prompt as input
|
||||||
|
if use_cache and generated_length > 0:
|
||||||
|
model_input_length = 1
|
||||||
|
else:
|
||||||
|
model_input_length = prompt_length + generated_length
|
||||||
|
query_length = (
|
||||||
|
prompt_length + generated_length
|
||||||
|
if not has_static_cache
|
||||||
|
else decoder_past_key_values.get_max_cache_shape()
|
||||||
|
)
|
||||||
|
|
||||||
expected_shape = (
|
expected_shape = (
|
||||||
batch_size * num_beam_groups,
|
batch_size,
|
||||||
config.num_attention_heads,
|
config.num_attention_heads,
|
||||||
tgt_len,
|
model_input_length,
|
||||||
src_len,
|
query_length,
|
||||||
)
|
)
|
||||||
# check attn size
|
# check attn size
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
[layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
|
[layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
|
||||||
)
|
)
|
||||||
|
|
||||||
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length):
|
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, prompt_length):
|
||||||
encoder_expected_shape = (batch_size, config.num_attention_heads, seq_length, seq_length)
|
encoder_expected_shape = (batch_size, config.num_attention_heads, prompt_length, prompt_length)
|
||||||
self.assertIsInstance(attentions, tuple)
|
self.assertIsInstance(attentions, tuple)
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
[layer_attentions.shape for layer_attentions in attentions],
|
[layer_attentions.shape for layer_attentions in attentions],
|
||||||
@@ -2456,71 +2553,75 @@ class GenerationTesterMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _check_hidden_states_for_generate(
|
def _check_hidden_states_for_generate(
|
||||||
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
self, batch_size, hidden_states, prompt_length, output_length, config, use_cache=False
|
||||||
):
|
):
|
||||||
self.assertIsInstance(hidden_states, tuple)
|
self.assertIsInstance(hidden_states, tuple)
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
[isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
|
[isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
|
||||||
[True] * len(hidden_states),
|
[True] * len(hidden_states),
|
||||||
)
|
)
|
||||||
self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)
|
self.assertEqual(len(hidden_states), (output_length - prompt_length))
|
||||||
|
|
||||||
for idx, iter_hidden_states in enumerate(hidden_states):
|
# When `output_hidden_states=True`, each iteration of generate appends the hidden states corresponding to the
|
||||||
seq_len = min_length + idx if not use_cache else 1
|
# new token(s)
|
||||||
expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
|
# NOTE: `HybridCache` may have different lengths on different layers, if this test starts failing add more
|
||||||
|
# elaborate checks
|
||||||
|
for generated_length, iter_hidden_states in enumerate(hidden_states):
|
||||||
|
# regardless of using cache, the first forward pass will have the full prompt as input
|
||||||
|
if use_cache and generated_length > 0:
|
||||||
|
model_input_length = 1
|
||||||
|
else:
|
||||||
|
model_input_length = prompt_length + generated_length
|
||||||
|
expected_shape = (batch_size, model_input_length, config.hidden_size)
|
||||||
# check hidden size
|
# check hidden size
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
[layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
|
[layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
|
||||||
[expected_shape] * len(iter_hidden_states),
|
[expected_shape] * len(iter_hidden_states),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, seq_length):
|
def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, prompt_length):
|
||||||
encoder_expected_shape = (batch_size, seq_length, config.hidden_size)
|
encoder_expected_shape = (batch_size, prompt_length, config.hidden_size)
|
||||||
self.assertIsInstance(hidden_states, tuple)
|
self.assertIsInstance(hidden_states, tuple)
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
[layer_hidden_states.shape for layer_hidden_states in hidden_states],
|
[layer_hidden_states.shape for layer_hidden_states in hidden_states],
|
||||||
[encoder_expected_shape] * len(hidden_states),
|
[encoder_expected_shape] * len(hidden_states),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config, num_beam_groups=1):
|
def _check_past_key_values_for_generate(self, batch_size, decoder_past_key_values, cache_length, config):
|
||||||
self.assertIsInstance(past_key_values, (tuple, Cache))
|
self.assertIsInstance(decoder_past_key_values, (tuple, Cache))
|
||||||
|
|
||||||
# Encoder-decoder models: pull and verify the decoder cache
|
|
||||||
if isinstance(past_key_values, EncoderDecoderCache):
|
|
||||||
past_key_values = past_key_values.self_attention_cache
|
|
||||||
|
|
||||||
# (batch, head, seq_length, head_features)
|
# (batch, head, seq_length, head_features)
|
||||||
expected_shape = (
|
expected_shape = (
|
||||||
batch_size * num_beam_groups,
|
batch_size,
|
||||||
config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads,
|
config.num_key_value_heads if hasattr(config, "num_key_value_heads") else config.num_attention_heads,
|
||||||
seq_length,
|
cache_length,
|
||||||
config.hidden_size // config.num_attention_heads,
|
config.hidden_size // config.num_attention_heads,
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(past_key_values, Cache):
|
if isinstance(decoder_past_key_values, Cache):
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
[key_tensor.shape for key_tensor in past_key_values.key_cache],
|
[key_tensor.shape for key_tensor in decoder_past_key_values.key_cache],
|
||||||
[expected_shape] * len(past_key_values.key_cache),
|
[expected_shape] * len(decoder_past_key_values.key_cache),
|
||||||
)
|
)
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
[value_tensor.shape for value_tensor in past_key_values.value_cache],
|
[value_tensor.shape for value_tensor in decoder_past_key_values.value_cache],
|
||||||
[expected_shape] * len(past_key_values.value_cache),
|
[expected_shape] * len(decoder_past_key_values.value_cache),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Legacy cache format checks. This branch should be removed when all models use `Cache` by default
|
# Legacy cache format checks. This branch should be removed when all models use `Cache` by default
|
||||||
else:
|
else:
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
[isinstance(iter_past_key_values, tuple) for iter_past_key_values in past_key_values],
|
[isinstance(iter_past_key_values, tuple) for iter_past_key_values in decoder_past_key_values],
|
||||||
[True] * len(past_key_values),
|
[True] * len(decoder_past_key_values),
|
||||||
)
|
)
|
||||||
# check shape key, value
|
# check shape key, value
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
[layer_past_key_values[0].shape for layer_past_key_values in past_key_values],
|
[layer_past_key_values[0].shape for layer_past_key_values in decoder_past_key_values],
|
||||||
[expected_shape] * len(past_key_values),
|
[expected_shape] * len(decoder_past_key_values),
|
||||||
)
|
)
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
[layer_past_key_values[1].shape for layer_past_key_values in past_key_values],
|
[layer_past_key_values[1].shape for layer_past_key_values in decoder_past_key_values],
|
||||||
[expected_shape] * len(past_key_values),
|
[expected_shape] * len(decoder_past_key_values),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _check_sequence_inside_sequence(self, tensor_1, tensor_2):
|
def _check_sequence_inside_sequence(self, tensor_1, tensor_2):
|
||||||
|
|||||||
@@ -723,103 +723,12 @@ class Blip2ForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, GenerationT
|
|||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
# overwrite because BLIP internally calls LM.generate() with embeds thus it cannot operate in no cache format
|
# overwrite because BLIP internally calls LM.generate() with embeds thus it cannot operate in no cache format
|
||||||
def _check_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1):
|
def _check_generate_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1):
|
||||||
use_cache = True # force this to be True in case False is passed
|
use_cache = True # force this to be True in case False is passed
|
||||||
|
super()._check_generate_outputs(
|
||||||
input_batch_size = int(output.sequences.shape[0] / num_return_sequences)
|
output, config, use_cache=use_cache, num_return_sequences=num_return_sequences, num_beams=num_beams
|
||||||
internal_batch_size = (
|
|
||||||
input_batch_size * num_beams if num_beams > 1 else input_batch_size * num_return_sequences
|
|
||||||
)
|
)
|
||||||
|
|
||||||
seq_length = getattr(self.model_tester, "seq_length", None)
|
|
||||||
seq_length = getattr(self.model_tester, "encoder_seq_length", seq_length)
|
|
||||||
seq_length = getattr(self.model_tester, "text_seq_length", seq_length)
|
|
||||||
|
|
||||||
config = config.text_config if hasattr(config, "text_config") else config
|
|
||||||
|
|
||||||
gen_len = (
|
|
||||||
output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length
|
|
||||||
)
|
|
||||||
|
|
||||||
# in some models we subsample the sequence length in inner layers
|
|
||||||
if hasattr(self.model_tester, "get_subsampled_output_lengths"):
|
|
||||||
seq_length = self.model_tester.get_subsampled_output_lengths(seq_length)
|
|
||||||
|
|
||||||
# scores
|
|
||||||
self._check_scores(internal_batch_size, output.scores, length=gen_len, config=config)
|
|
||||||
|
|
||||||
# unprocessed logits
|
|
||||||
self._check_logits(internal_batch_size, output.logits, config=config)
|
|
||||||
|
|
||||||
# Attentions
|
|
||||||
if self.has_attentions:
|
|
||||||
if config.is_encoder_decoder:
|
|
||||||
# encoder
|
|
||||||
self._check_encoder_attention_for_generate(
|
|
||||||
output.encoder_attentions, input_batch_size, config, seq_length
|
|
||||||
)
|
|
||||||
# decoder
|
|
||||||
self._check_attentions_for_generate(
|
|
||||||
internal_batch_size,
|
|
||||||
output.decoder_attentions,
|
|
||||||
min_length=1,
|
|
||||||
max_length=output.sequences.shape[-1],
|
|
||||||
config=config,
|
|
||||||
use_cache=use_cache,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# if use_cache first input is equal to no use_cache, so skip here
|
|
||||||
attentions = output.attentions if not use_cache else output.attentions[1:]
|
|
||||||
min_length = seq_length if not use_cache else seq_length + 1
|
|
||||||
self._check_attentions_for_generate(
|
|
||||||
internal_batch_size,
|
|
||||||
attentions=attentions,
|
|
||||||
min_length=min_length,
|
|
||||||
max_length=output.sequences.shape[-1],
|
|
||||||
config=config,
|
|
||||||
use_cache=use_cache,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Hidden States
|
|
||||||
if config.is_encoder_decoder:
|
|
||||||
# encoder
|
|
||||||
self._check_encoder_hidden_states_for_generate(
|
|
||||||
output.encoder_hidden_states, input_batch_size, config, seq_length
|
|
||||||
)
|
|
||||||
|
|
||||||
# decoder
|
|
||||||
self._check_hidden_states_for_generate(
|
|
||||||
internal_batch_size,
|
|
||||||
output.decoder_hidden_states,
|
|
||||||
min_length=1,
|
|
||||||
max_length=output.sequences.shape[-1],
|
|
||||||
config=config,
|
|
||||||
use_cache=use_cache,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# if use_cache first input is equal to no use_cache, so skip here
|
|
||||||
hidden_states = output.hidden_states if not use_cache else output.hidden_states[1:]
|
|
||||||
min_length = seq_length if not use_cache else seq_length + 1
|
|
||||||
self._check_hidden_states_for_generate(
|
|
||||||
internal_batch_size,
|
|
||||||
hidden_states,
|
|
||||||
min_length=min_length,
|
|
||||||
max_length=output.sequences.shape[-1],
|
|
||||||
config=config,
|
|
||||||
use_cache=use_cache,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Past Key Value States
|
|
||||||
if use_cache:
|
|
||||||
past_key_values = output.past_key_values
|
|
||||||
past_sequence_length = output.sequences.shape[-1] - 1
|
|
||||||
self._check_past_key_values_for_generate(
|
|
||||||
internal_batch_size,
|
|
||||||
past_key_values,
|
|
||||||
seq_length=past_sequence_length,
|
|
||||||
config=config,
|
|
||||||
)
|
|
||||||
|
|
||||||
# overwrite because BLIP2 cannot generate only from input ids, and requires pixel values in all cases to be present
|
# overwrite because BLIP2 cannot generate only from input ids, and requires pixel values in all cases to be present
|
||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
def test_left_padding_compatibility(self):
|
def test_left_padding_compatibility(self):
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from packaging import version
|
|||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
from pytest import mark
|
from pytest import mark
|
||||||
|
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer, Cohere2Config, HybridCache, is_torch_available, pipeline
|
from transformers import AutoModelForCausalLM, AutoTokenizer, Cohere2Config, is_torch_available, pipeline
|
||||||
from transformers.generation.configuration_utils import GenerationConfig
|
from transformers.generation.configuration_utils import GenerationConfig
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
require_flash_attn,
|
require_flash_attn,
|
||||||
@@ -135,51 +135,6 @@ class Cohere2ModelTest(CohereModelTest, unittest.TestCase):
|
|||||||
def test_generate_continue_from_inputs_embeds(self):
|
def test_generate_continue_from_inputs_embeds(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# overwrite because HybridCache has fixed length for key/values
|
|
||||||
def _check_attentions_for_generate(
|
|
||||||
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
|
||||||
):
|
|
||||||
self.assertIsInstance(attentions, tuple)
|
|
||||||
self.assertListEqual(
|
|
||||||
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
|
|
||||||
)
|
|
||||||
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
|
|
||||||
|
|
||||||
for idx, iter_attentions in enumerate(attentions):
|
|
||||||
tgt_len = min_length + idx if not use_cache else 1
|
|
||||||
src_len = min_length + idx if not use_cache else max_length
|
|
||||||
|
|
||||||
expected_shape = (
|
|
||||||
batch_size * num_beam_groups,
|
|
||||||
config.num_attention_heads,
|
|
||||||
tgt_len,
|
|
||||||
src_len,
|
|
||||||
)
|
|
||||||
# check attn size
|
|
||||||
self.assertListEqual(
|
|
||||||
[layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
|
|
||||||
)
|
|
||||||
|
|
||||||
# overwrite because HybridCache has fixed length for key/values
|
|
||||||
def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config, num_beam_groups=1):
|
|
||||||
self.assertIsInstance(past_key_values, HybridCache)
|
|
||||||
|
|
||||||
# check shape key, value (batch, head, max_seq_length, head_features)
|
|
||||||
head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
|
|
||||||
num_key_value_heads = (
|
|
||||||
config.num_attention_heads
|
|
||||||
if getattr(config, "num_key_value_heads", None) is None
|
|
||||||
else config.num_key_value_heads
|
|
||||||
)
|
|
||||||
num_hidden_layers = config.num_hidden_layers
|
|
||||||
|
|
||||||
# we should get `max_length` in shape, not `max_length - embeds_length`
|
|
||||||
# `+1` because the test in Mixin subtracts 1 which is needed for tuple cache
|
|
||||||
static_cache_shape = (batch_size, num_key_value_heads, seq_length + 1, head_dim)
|
|
||||||
static_layers = [layer_idx for layer_idx, boolean in enumerate(past_key_values.is_sliding) if not boolean]
|
|
||||||
self.assertTrue(len(past_key_values.key_cache) == num_hidden_layers)
|
|
||||||
self.assertTrue(past_key_values.key_cache[static_layers[0]].shape == static_cache_shape)
|
|
||||||
|
|
||||||
@unittest.skip("Cohere2's eager attn/sdpa attn outputs are expected to be different")
|
@unittest.skip("Cohere2's eager attn/sdpa attn outputs are expected to be different")
|
||||||
def test_sdpa_equivalence(self):
|
def test_sdpa_equivalence(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from packaging import version
|
|||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
from pytest import mark
|
from pytest import mark
|
||||||
|
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma2Config, HybridCache, is_torch_available, pipeline
|
from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma2Config, is_torch_available, pipeline
|
||||||
from transformers.generation.configuration_utils import GenerationConfig
|
from transformers.generation.configuration_utils import GenerationConfig
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
require_flash_attn,
|
require_flash_attn,
|
||||||
@@ -150,51 +150,6 @@ class Gemma2ModelTest(GemmaModelTest, unittest.TestCase):
|
|||||||
def test_generate_continue_from_inputs_embeds(self):
|
def test_generate_continue_from_inputs_embeds(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# overwrite because HybridCache has fixed length for key/values
|
|
||||||
def _check_attentions_for_generate(
|
|
||||||
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
|
||||||
):
|
|
||||||
self.assertIsInstance(attentions, tuple)
|
|
||||||
self.assertListEqual(
|
|
||||||
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
|
|
||||||
)
|
|
||||||
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
|
|
||||||
|
|
||||||
for idx, iter_attentions in enumerate(attentions):
|
|
||||||
tgt_len = min_length + idx if not use_cache else 1
|
|
||||||
src_len = min_length + idx if not use_cache else max_length
|
|
||||||
|
|
||||||
expected_shape = (
|
|
||||||
batch_size * num_beam_groups,
|
|
||||||
config.num_attention_heads,
|
|
||||||
tgt_len,
|
|
||||||
src_len,
|
|
||||||
)
|
|
||||||
# check attn size
|
|
||||||
self.assertListEqual(
|
|
||||||
[layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
|
|
||||||
)
|
|
||||||
|
|
||||||
# overwrite because HybridCache has fixed length for key/values
|
|
||||||
def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config, num_beam_groups=1):
|
|
||||||
self.assertIsInstance(past_key_values, HybridCache)
|
|
||||||
|
|
||||||
# check shape key, value (batch, head, max_seq_length, head_features)
|
|
||||||
head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
|
|
||||||
num_key_value_heads = (
|
|
||||||
config.num_attention_heads
|
|
||||||
if getattr(config, "num_key_value_heads", None) is None
|
|
||||||
else config.num_key_value_heads
|
|
||||||
)
|
|
||||||
num_hidden_layers = config.num_hidden_layers
|
|
||||||
|
|
||||||
# we should get `max_length` in shape, not `max_length - embeds_length`
|
|
||||||
# `+1` because the test in Mixin subtracts 1 which is needed for tuple cache
|
|
||||||
static_cache_shape = (batch_size, num_key_value_heads, seq_length + 1, head_dim)
|
|
||||||
static_layers = [layer_idx for layer_idx, boolean in enumerate(past_key_values.is_sliding) if not boolean]
|
|
||||||
self.assertTrue(len(past_key_values.key_cache) == num_hidden_layers)
|
|
||||||
self.assertTrue(past_key_values.key_cache[static_layers[0]].shape == static_cache_shape)
|
|
||||||
|
|
||||||
@unittest.skip("Gemma2's eager attn/sdpa attn outputs are expected to be different")
|
@unittest.skip("Gemma2's eager attn/sdpa attn outputs are expected to be different")
|
||||||
def test_sdpa_equivalence(self):
|
def test_sdpa_equivalence(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -456,51 +456,26 @@ class GitModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||||
|
|
||||||
def _check_attentions_for_generate(
|
def _check_attentions_for_generate(
|
||||||
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
|
||||||
):
|
):
|
||||||
# GIT attention shape depends on image inputs, overwrite
|
# GIT attention shape depends on image inputs, overwrite
|
||||||
self.assertIsInstance(attentions, tuple)
|
|
||||||
self.assertListEqual(
|
|
||||||
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
|
|
||||||
)
|
|
||||||
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
|
|
||||||
image_length = int((config.vision_config.image_size / config.vision_config.patch_size) ** 2 + 1)
|
image_length = int((config.vision_config.image_size / config.vision_config.patch_size) ** 2 + 1)
|
||||||
|
prompt_length += image_length
|
||||||
for idx, iter_attentions in enumerate(attentions):
|
output_length += image_length
|
||||||
tgt_len = min_length + idx + image_length if not use_cache else 1
|
super()._check_attentions_for_generate(
|
||||||
src_len = min_length + idx + image_length
|
batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
|
||||||
|
)
|
||||||
expected_shape = (
|
|
||||||
batch_size * num_beam_groups,
|
|
||||||
config.num_attention_heads,
|
|
||||||
tgt_len,
|
|
||||||
src_len,
|
|
||||||
)
|
|
||||||
# check attn size
|
|
||||||
self.assertListEqual(
|
|
||||||
[layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
|
|
||||||
)
|
|
||||||
|
|
||||||
def _check_hidden_states_for_generate(
|
def _check_hidden_states_for_generate(
|
||||||
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
self, batch_size, hidden_states, prompt_length, output_length, config, use_cache=False
|
||||||
):
|
):
|
||||||
# GIT hidden states shape depends on image inputs, overwrite
|
# GIT hidden states shape depends on image inputs, overwrite
|
||||||
self.assertIsInstance(hidden_states, tuple)
|
|
||||||
self.assertListEqual(
|
|
||||||
[isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
|
|
||||||
[True] * len(hidden_states),
|
|
||||||
)
|
|
||||||
self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)
|
|
||||||
image_length = int((config.vision_config.image_size / config.vision_config.patch_size) ** 2 + 1)
|
image_length = int((config.vision_config.image_size / config.vision_config.patch_size) ** 2 + 1)
|
||||||
|
prompt_length += image_length
|
||||||
for idx, iter_hidden_states in enumerate(hidden_states):
|
output_length += image_length
|
||||||
seq_len = min_length + idx + image_length if not use_cache else 1
|
super()._check_hidden_states_for_generate(
|
||||||
expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
|
batch_size, hidden_states, prompt_length, output_length, config, use_cache=use_cache
|
||||||
# check hidden size
|
)
|
||||||
self.assertListEqual(
|
|
||||||
[layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
|
|
||||||
[expected_shape] * len(iter_hidden_states),
|
|
||||||
)
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
|
|||||||
@@ -815,7 +815,7 @@ class IdeficsForVisionText2TextTest(IdeficsModelTest, GenerationTesterMixin, uni
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _check_attentions_for_generate(
|
def _check_attentions_for_generate(
|
||||||
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Overwrite from generation tests because Idefics has only SDPA layers.
|
Overwrite from generation tests because Idefics has only SDPA layers.
|
||||||
|
|||||||
@@ -251,10 +251,10 @@ class ImageGPTModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
|||||||
return inputs_dict
|
return inputs_dict
|
||||||
|
|
||||||
# we overwrite the _check_scores method of GenerationTesterMixin, as ImageGPTForCausalImageModeling doesn't have tied input- and output embeddings
|
# we overwrite the _check_scores method of GenerationTesterMixin, as ImageGPTForCausalImageModeling doesn't have tied input- and output embeddings
|
||||||
def _check_scores(self, batch_size, scores, length, config):
|
def _check_scores(self, batch_size, scores, generated_length, config):
|
||||||
expected_shape = (batch_size, config.vocab_size - 1)
|
expected_shape = (batch_size, config.vocab_size - 1)
|
||||||
self.assertIsInstance(scores, tuple)
|
self.assertIsInstance(scores, tuple)
|
||||||
self.assertEqual(len(scores), length)
|
self.assertEqual(len(scores), generated_length)
|
||||||
self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores))
|
self.assertListEqual([iter_scores.shape for iter_scores in scores], [expected_shape] * len(scores))
|
||||||
|
|
||||||
@run_test_using_subprocess
|
@run_test_using_subprocess
|
||||||
|
|||||||
@@ -565,103 +565,12 @@ class InstructBlipForConditionalGenerationDecoderOnlyTest(ModelTesterMixin, Gene
|
|||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
# overwrite because InstructBLIP internally calls LM.generate() with embeds thus it cannot operate in no cache format
|
# overwrite because InstructBLIP internally calls LM.generate() with embeds thus it cannot operate in no cache format
|
||||||
def _check_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1):
|
def _check_generate_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1):
|
||||||
use_cache = True # force this to be True in case False is passed
|
use_cache = True # force this to be True in case False is passed
|
||||||
|
super()._check_generate_outputs(
|
||||||
input_batch_size = int(output.sequences.shape[0] / num_return_sequences)
|
output, config, use_cache=use_cache, num_return_sequences=num_return_sequences, num_beams=num_beams
|
||||||
internal_batch_size = (
|
|
||||||
input_batch_size * num_beams if num_beams > 1 else input_batch_size * num_return_sequences
|
|
||||||
)
|
)
|
||||||
|
|
||||||
seq_length = getattr(self.model_tester, "seq_length", None)
|
|
||||||
seq_length = getattr(self.model_tester, "encoder_seq_length", seq_length)
|
|
||||||
seq_length = getattr(self.model_tester, "text_seq_length", seq_length)
|
|
||||||
|
|
||||||
config = config.text_config if hasattr(config, "text_config") else config
|
|
||||||
|
|
||||||
gen_len = (
|
|
||||||
output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length
|
|
||||||
)
|
|
||||||
|
|
||||||
# in some models we subsample the sequence length in inner layers
|
|
||||||
if hasattr(self.model_tester, "get_subsampled_output_lengths"):
|
|
||||||
seq_length = self.model_tester.get_subsampled_output_lengths(seq_length)
|
|
||||||
|
|
||||||
# scores
|
|
||||||
self._check_scores(internal_batch_size, output.scores, length=gen_len, config=config)
|
|
||||||
|
|
||||||
# unprocessed logits
|
|
||||||
self._check_logits(internal_batch_size, output.logits, config=config)
|
|
||||||
|
|
||||||
# Attentions
|
|
||||||
if self.has_attentions:
|
|
||||||
if config.is_encoder_decoder:
|
|
||||||
# encoder
|
|
||||||
self._check_encoder_attention_for_generate(
|
|
||||||
output.encoder_attentions, input_batch_size, config, seq_length
|
|
||||||
)
|
|
||||||
# decoder
|
|
||||||
self._check_attentions_for_generate(
|
|
||||||
internal_batch_size,
|
|
||||||
output.decoder_attentions,
|
|
||||||
min_length=1,
|
|
||||||
max_length=output.sequences.shape[-1],
|
|
||||||
config=config,
|
|
||||||
use_cache=use_cache,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# if use_cache first input is equal to no use_cache, so skip here
|
|
||||||
attentions = output.attentions if not use_cache else output.attentions[1:]
|
|
||||||
min_length = seq_length if not use_cache else seq_length + 1
|
|
||||||
self._check_attentions_for_generate(
|
|
||||||
internal_batch_size,
|
|
||||||
attentions=attentions,
|
|
||||||
min_length=min_length,
|
|
||||||
max_length=output.sequences.shape[-1],
|
|
||||||
config=config,
|
|
||||||
use_cache=use_cache,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Hidden States
|
|
||||||
if config.is_encoder_decoder:
|
|
||||||
# encoder
|
|
||||||
self._check_encoder_hidden_states_for_generate(
|
|
||||||
output.encoder_hidden_states, input_batch_size, config, seq_length
|
|
||||||
)
|
|
||||||
|
|
||||||
# decoder
|
|
||||||
self._check_hidden_states_for_generate(
|
|
||||||
internal_batch_size,
|
|
||||||
output.decoder_hidden_states,
|
|
||||||
min_length=1,
|
|
||||||
max_length=output.sequences.shape[-1],
|
|
||||||
config=config,
|
|
||||||
use_cache=use_cache,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# if use_cache first input is equal to no use_cache, so skip here
|
|
||||||
hidden_states = output.hidden_states if not use_cache else output.hidden_states[1:]
|
|
||||||
min_length = seq_length if not use_cache else seq_length + 1
|
|
||||||
self._check_hidden_states_for_generate(
|
|
||||||
internal_batch_size,
|
|
||||||
hidden_states,
|
|
||||||
min_length=min_length,
|
|
||||||
max_length=output.sequences.shape[-1],
|
|
||||||
config=config,
|
|
||||||
use_cache=use_cache,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Past Key Value States
|
|
||||||
if use_cache:
|
|
||||||
past_key_values = output.past_key_values
|
|
||||||
past_sequence_length = output.sequences.shape[-1] - 1
|
|
||||||
self._check_past_key_values_for_generate(
|
|
||||||
internal_batch_size,
|
|
||||||
past_key_values,
|
|
||||||
seq_length=past_sequence_length,
|
|
||||||
config=config,
|
|
||||||
)
|
|
||||||
|
|
||||||
# overwrite because InstructBLIP cannot generate only from input ids, and requires `pixel` values and `qformer_input_ids` in all cases to be present
|
# overwrite because InstructBLIP cannot generate only from input ids, and requires `pixel` values and `qformer_input_ids` in all cases to be present
|
||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
def test_left_padding_compatibility(self):
|
def test_left_padding_compatibility(self):
|
||||||
|
|||||||
@@ -581,103 +581,12 @@ class InstructBlipVideoForConditionalGenerationDecoderOnlyTest(
|
|||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
# overwrite because InstructBLIPVideo internally calls LM.generate() with embeds thus it cannot operate in no cache format
|
# overwrite because InstructBLIPVideo internally calls LM.generate() with embeds thus it cannot operate in no cache format
|
||||||
def _check_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1):
|
def _check_generate_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1):
|
||||||
use_cache = True # force this to be True in case False is passed
|
use_cache = True # force this to be True in case False is passed
|
||||||
|
super()._check_generate_outputs(
|
||||||
input_batch_size = int(output.sequences.shape[0] / num_return_sequences)
|
output, config, use_cache=use_cache, num_return_sequences=num_return_sequences, num_beams=num_beams
|
||||||
internal_batch_size = (
|
|
||||||
input_batch_size * num_beams if num_beams > 1 else input_batch_size * num_return_sequences
|
|
||||||
)
|
)
|
||||||
|
|
||||||
seq_length = getattr(self.model_tester, "seq_length", None)
|
|
||||||
seq_length = getattr(self.model_tester, "encoder_seq_length", seq_length)
|
|
||||||
seq_length = getattr(self.model_tester, "text_seq_length", seq_length)
|
|
||||||
|
|
||||||
config = config.text_config if hasattr(config, "text_config") else config
|
|
||||||
|
|
||||||
gen_len = (
|
|
||||||
output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length
|
|
||||||
)
|
|
||||||
|
|
||||||
# in some models we subsample the sequence length in inner layers
|
|
||||||
if hasattr(self.model_tester, "get_subsampled_output_lengths"):
|
|
||||||
seq_length = self.model_tester.get_subsampled_output_lengths(seq_length)
|
|
||||||
|
|
||||||
# scores
|
|
||||||
self._check_scores(internal_batch_size, output.scores, length=gen_len, config=config)
|
|
||||||
|
|
||||||
# unprocessed logits
|
|
||||||
self._check_logits(internal_batch_size, output.logits, config=config)
|
|
||||||
|
|
||||||
# Attentions
|
|
||||||
if self.has_attentions:
|
|
||||||
if config.is_encoder_decoder:
|
|
||||||
# encoder
|
|
||||||
self._check_encoder_attention_for_generate(
|
|
||||||
output.encoder_attentions, input_batch_size, config, seq_length
|
|
||||||
)
|
|
||||||
# decoder
|
|
||||||
self._check_attentions_for_generate(
|
|
||||||
internal_batch_size,
|
|
||||||
output.decoder_attentions,
|
|
||||||
min_length=1,
|
|
||||||
max_length=output.sequences.shape[-1],
|
|
||||||
config=config,
|
|
||||||
use_cache=use_cache,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# if use_cache first input is equal to no use_cache, so skip here
|
|
||||||
attentions = output.attentions if not use_cache else output.attentions[1:]
|
|
||||||
min_length = seq_length if not use_cache else seq_length + 1
|
|
||||||
self._check_attentions_for_generate(
|
|
||||||
internal_batch_size,
|
|
||||||
attentions=attentions,
|
|
||||||
min_length=min_length,
|
|
||||||
max_length=output.sequences.shape[-1],
|
|
||||||
config=config,
|
|
||||||
use_cache=use_cache,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Hidden States
|
|
||||||
if config.is_encoder_decoder:
|
|
||||||
# encoder
|
|
||||||
self._check_encoder_hidden_states_for_generate(
|
|
||||||
output.encoder_hidden_states, input_batch_size, config, seq_length
|
|
||||||
)
|
|
||||||
|
|
||||||
# decoder
|
|
||||||
self._check_hidden_states_for_generate(
|
|
||||||
internal_batch_size,
|
|
||||||
output.decoder_hidden_states,
|
|
||||||
min_length=1,
|
|
||||||
max_length=output.sequences.shape[-1],
|
|
||||||
config=config,
|
|
||||||
use_cache=use_cache,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# if use_cache first input is equal to no use_cache, so skip here
|
|
||||||
hidden_states = output.hidden_states if not use_cache else output.hidden_states[1:]
|
|
||||||
min_length = seq_length if not use_cache else seq_length + 1
|
|
||||||
self._check_hidden_states_for_generate(
|
|
||||||
internal_batch_size,
|
|
||||||
hidden_states,
|
|
||||||
min_length=min_length,
|
|
||||||
max_length=output.sequences.shape[-1],
|
|
||||||
config=config,
|
|
||||||
use_cache=use_cache,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Past Key Value States
|
|
||||||
if use_cache:
|
|
||||||
past_key_values = output.past_key_values
|
|
||||||
past_sequence_length = output.sequences.shape[-1] - 1
|
|
||||||
self._check_past_key_values_for_generate(
|
|
||||||
internal_batch_size,
|
|
||||||
past_key_values,
|
|
||||||
seq_length=past_sequence_length,
|
|
||||||
config=config,
|
|
||||||
)
|
|
||||||
|
|
||||||
# overwrite because InstructBLIPVideo cannot generate only from input ids, and requires `pixel` values and `qformer_input_ids` in all cases to be present
|
# overwrite because InstructBLIPVideo cannot generate only from input ids, and requires `pixel` values and `qformer_input_ids` in all cases to be present
|
||||||
@pytest.mark.generate
|
@pytest.mark.generate
|
||||||
def test_left_padding_compatibility(self):
|
def test_left_padding_compatibility(self):
|
||||||
|
|||||||
@@ -468,12 +468,12 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length):
|
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, prompt_length):
|
||||||
# overwrite because LED does not have (bs, num_heads, seq_len, seq_len) shape
|
# overwrite because LED does not have (bs, num_heads, seq_len, seq_len) shape
|
||||||
encoder_expected_shape = (
|
encoder_expected_shape = (
|
||||||
batch_size,
|
batch_size,
|
||||||
config.num_attention_heads,
|
config.num_attention_heads,
|
||||||
seq_length,
|
prompt_length,
|
||||||
self.model_tester.attention_window // 2 * 2 + 1,
|
self.model_tester.attention_window // 2 * 2 + 1,
|
||||||
)
|
)
|
||||||
self.assertIsInstance(attentions, tuple)
|
self.assertIsInstance(attentions, tuple)
|
||||||
|
|||||||
@@ -785,7 +785,7 @@ class LongT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
|
|||||||
[self.model_tester.num_attention_heads, block_len, 3 * block_len],
|
[self.model_tester.num_attention_heads, block_len, 3 * block_len],
|
||||||
)
|
)
|
||||||
|
|
||||||
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length):
|
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, prompt_length):
|
||||||
block_len = getattr(self.model_tester, "block_len", None)
|
block_len = getattr(self.model_tester, "block_len", None)
|
||||||
encoder_expected_shape = (batch_size, 2, config.num_attention_heads, block_len, 3 * block_len)
|
encoder_expected_shape = (batch_size, 2, config.num_attention_heads, block_len, 3 * block_len)
|
||||||
self.assertIsInstance(attentions, tuple)
|
self.assertIsInstance(attentions, tuple)
|
||||||
@@ -920,10 +920,10 @@ class LongT5TGlobalModelTest(LongT5ModelTest):
|
|||||||
[self.model_tester.num_attention_heads, block_len, 3 * block_len + global_seq_len],
|
[self.model_tester.num_attention_heads, block_len, 3 * block_len + global_seq_len],
|
||||||
)
|
)
|
||||||
|
|
||||||
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length):
|
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, prompt_length):
|
||||||
block_len = getattr(self.model_tester, "block_len", None)
|
block_len = getattr(self.model_tester, "block_len", None)
|
||||||
global_block_size = getattr(self.model_tester, "global_block_size", None)
|
global_block_size = getattr(self.model_tester, "global_block_size", None)
|
||||||
global_seq_length = seq_length // global_block_size
|
global_seq_length = prompt_length // global_block_size
|
||||||
encoder_expected_shape = (
|
encoder_expected_shape = (
|
||||||
batch_size,
|
batch_size,
|
||||||
2,
|
2,
|
||||||
|
|||||||
@@ -323,32 +323,37 @@ class MllamaForConditionalGenerationModelTest(ModelTesterMixin, GenerationTester
|
|||||||
torch.testing.assert_close(out_embeds, out_ids)
|
torch.testing.assert_close(out_embeds, out_ids)
|
||||||
|
|
||||||
def _check_attentions_for_generate(
|
def _check_attentions_for_generate(
|
||||||
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
|
||||||
):
|
):
|
||||||
# Mllama has cross attention layers and those have a different shape than normal attention layers
|
# Mllama has cross attention layers and those have a different shape than normal attention layers
|
||||||
self.assertIsInstance(attentions, tuple)
|
self.assertIsInstance(attentions, tuple)
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
|
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
|
||||||
)
|
)
|
||||||
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
|
self.assertEqual(len(attentions), (output_length - prompt_length))
|
||||||
|
|
||||||
cross_attention_layers = self.model_tester.text_config["cross_attention_layers"]
|
cross_attention_layers = self.model_tester.text_config["cross_attention_layers"]
|
||||||
|
use_cache = decoder_past_key_values is not None
|
||||||
|
|
||||||
for idx, iter_attentions in enumerate(attentions):
|
for generated_length, iter_attentions in enumerate(attentions):
|
||||||
tgt_len = min_length + idx if not use_cache else 1
|
# regardless of using cache, the first forward pass will have the full prompt as input
|
||||||
src_len = min_length + idx
|
if use_cache and generated_length > 0:
|
||||||
|
model_input_length = 1
|
||||||
|
else:
|
||||||
|
model_input_length = prompt_length + generated_length
|
||||||
|
query_length = prompt_length + generated_length
|
||||||
|
|
||||||
expected_shape = (
|
expected_shape = (
|
||||||
batch_size * num_beam_groups,
|
batch_size,
|
||||||
config.num_attention_heads,
|
config.num_attention_heads,
|
||||||
tgt_len,
|
model_input_length,
|
||||||
src_len,
|
query_length,
|
||||||
)
|
)
|
||||||
|
|
||||||
expected_shape_cross = (
|
expected_shape_cross = (
|
||||||
batch_size * num_beam_groups,
|
batch_size,
|
||||||
config.num_attention_heads,
|
config.num_attention_heads,
|
||||||
tgt_len,
|
model_input_length,
|
||||||
self.model_tester.image_length,
|
self.model_tester.image_length,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -575,77 +575,12 @@ class MoshiTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
return config, filtered_inputs_dict
|
return config, filtered_inputs_dict
|
||||||
|
|
||||||
def _check_hidden_states_for_generate(
|
def _check_generate_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1):
|
||||||
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
|
||||||
):
|
|
||||||
# Overwrite because the generate method actually always uses `inputs_embeds` so `use_cache` is always `True`
|
# Overwrite because the generate method actually always uses `inputs_embeds` so `use_cache` is always `True`
|
||||||
self.assertIsInstance(hidden_states, tuple)
|
super()._check_generate_outputs(
|
||||||
self.assertListEqual(
|
|
||||||
[isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
|
|
||||||
[True] * len(hidden_states),
|
|
||||||
)
|
|
||||||
self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)
|
|
||||||
|
|
||||||
for idx, iter_hidden_states in enumerate(hidden_states):
|
|
||||||
seq_len = min_length if idx == 0 else 1
|
|
||||||
expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
|
|
||||||
# check hidden size
|
|
||||||
self.assertListEqual(
|
|
||||||
[layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
|
|
||||||
[expected_shape] * len(iter_hidden_states),
|
|
||||||
)
|
|
||||||
|
|
||||||
def _check_outputs(self, output, config, use_cache=False, num_return_sequences=1, num_beams=1):
|
|
||||||
# Overwrite because the generate method actually always uses `inputs_embeds` so `use_cache` is always `True`
|
|
||||||
super()._check_outputs(
|
|
||||||
output, config, use_cache=True, num_return_sequences=num_return_sequences, num_beams=num_beams
|
output, config, use_cache=True, num_return_sequences=num_return_sequences, num_beams=num_beams
|
||||||
)
|
)
|
||||||
|
|
||||||
def _check_hidden_states_for_generate(
|
|
||||||
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
|
||||||
):
|
|
||||||
# Overwrite because the generate method actually always uses `inputs_embeds` so `use_cache` is always `True`
|
|
||||||
self.assertIsInstance(hidden_states, tuple)
|
|
||||||
self.assertListEqual(
|
|
||||||
[isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
|
|
||||||
[True] * len(hidden_states),
|
|
||||||
)
|
|
||||||
self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)
|
|
||||||
|
|
||||||
for idx, iter_hidden_states in enumerate(hidden_states):
|
|
||||||
seq_len = 1
|
|
||||||
expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
|
|
||||||
# check hidden size
|
|
||||||
self.assertListEqual(
|
|
||||||
[layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
|
|
||||||
[expected_shape] * len(iter_hidden_states),
|
|
||||||
)
|
|
||||||
|
|
||||||
def _check_attentions_for_generate(
|
|
||||||
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
|
||||||
):
|
|
||||||
# Overwrite because the generate method actually always uses `inputs_embeds` so `use_cache` is always `True`
|
|
||||||
self.assertIsInstance(attentions, tuple)
|
|
||||||
self.assertListEqual(
|
|
||||||
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
|
|
||||||
)
|
|
||||||
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
|
|
||||||
|
|
||||||
for idx, iter_attentions in enumerate(attentions):
|
|
||||||
tgt_len = 1
|
|
||||||
src_len = min_length + idx
|
|
||||||
|
|
||||||
expected_shape = (
|
|
||||||
batch_size * num_beam_groups,
|
|
||||||
config.num_attention_heads,
|
|
||||||
tgt_len,
|
|
||||||
src_len,
|
|
||||||
)
|
|
||||||
# check attn size
|
|
||||||
self.assertListEqual(
|
|
||||||
[layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_initialization(self):
|
def test_initialization(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
|||||||
@@ -399,11 +399,11 @@ class PegasusXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length):
|
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, prompt_length):
|
||||||
encoder_expected_shape = (
|
encoder_expected_shape = (
|
||||||
batch_size,
|
batch_size,
|
||||||
config.num_attention_heads,
|
config.num_attention_heads,
|
||||||
math.ceil(seq_length / config.block_size),
|
math.ceil(prompt_length / config.block_size),
|
||||||
config.block_size,
|
config.block_size,
|
||||||
config.block_size + config.num_global_tokens,
|
config.block_size + config.num_global_tokens,
|
||||||
)
|
)
|
||||||
@@ -413,8 +413,8 @@ class PegasusXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
|||||||
[encoder_expected_shape] * len(attentions),
|
[encoder_expected_shape] * len(attentions),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, seq_length):
|
def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, prompt_length):
|
||||||
encoder_expected_shape = (batch_size, self.round_up(seq_length, config.block_size), config.hidden_size)
|
encoder_expected_shape = (batch_size, self.round_up(prompt_length, config.block_size), config.hidden_size)
|
||||||
self.assertIsInstance(hidden_states, tuple)
|
self.assertIsInstance(hidden_states, tuple)
|
||||||
# Only the last layer will have the hidden states truncated back to token level
|
# Only the last layer will have the hidden states truncated back to token level
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
@@ -424,7 +424,7 @@ class PegasusXModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterM
|
|||||||
# Only the last layer will have the hidden states truncated back to token level
|
# Only the last layer will have the hidden states truncated back to token level
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
hidden_states[-1][0].shape,
|
hidden_states[-1][0].shape,
|
||||||
(batch_size, seq_length, config.hidden_size),
|
(batch_size, prompt_length, config.hidden_size),
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_hidden_states_output(self):
|
def test_hidden_states_output(self):
|
||||||
|
|||||||
@@ -753,20 +753,20 @@ class Pix2StructModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTeste
|
|||||||
text_config = Pix2StructTextConfig.from_pretrained(tmp_dir_name)
|
text_config = Pix2StructTextConfig.from_pretrained(tmp_dir_name)
|
||||||
self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())
|
self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())
|
||||||
|
|
||||||
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length):
|
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, prompt_length):
|
||||||
# overwrite because pix2struct seq length depends on image inputs
|
# overwrite because pix2struct seq length depends on image inputs
|
||||||
seq_length = self.model_tester.max_patches
|
prompt_length = self.model_tester.max_patches
|
||||||
encoder_expected_shape = (batch_size, config.num_attention_heads, seq_length, seq_length)
|
encoder_expected_shape = (batch_size, config.num_attention_heads, prompt_length, prompt_length)
|
||||||
self.assertIsInstance(attentions, tuple)
|
self.assertIsInstance(attentions, tuple)
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
[layer_attentions.shape for layer_attentions in attentions],
|
[layer_attentions.shape for layer_attentions in attentions],
|
||||||
[encoder_expected_shape] * len(attentions),
|
[encoder_expected_shape] * len(attentions),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, seq_length):
|
def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, prompt_length):
|
||||||
# overwrite because pix2struct seq length depends on image inputs
|
# overwrite because pix2struct seq length depends on image inputs
|
||||||
seq_length = self.model_tester.max_patches
|
prompt_length = self.model_tester.max_patches
|
||||||
encoder_expected_shape = (batch_size, seq_length, config.hidden_size)
|
encoder_expected_shape = (batch_size, prompt_length, config.hidden_size)
|
||||||
self.assertIsInstance(hidden_states, tuple)
|
self.assertIsInstance(hidden_states, tuple)
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
[layer_hidden_states.shape for layer_hidden_states in hidden_states],
|
[layer_hidden_states.shape for layer_hidden_states in hidden_states],
|
||||||
|
|||||||
@@ -367,9 +367,6 @@ class RecurrentGemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
|
|||||||
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
def test_training_gradient_checkpointing_use_reentrant_false(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _check_attentions_for_generate(self, *args, **kwargs):
|
|
||||||
return True # Model does not return attention
|
|
||||||
|
|
||||||
@unittest.skip(reason="Past key values are not returned")
|
@unittest.skip(reason="Past key values are not returned")
|
||||||
def test_prompt_lookup_decoding_matches_greedy_search(self):
|
def test_prompt_lookup_decoding_matches_greedy_search(self):
|
||||||
pass
|
pass
|
||||||
@@ -382,9 +379,6 @@ class RecurrentGemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
|
|||||||
def test_model_parallel_beam_search(self):
|
def test_model_parallel_beam_search(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _check_past_key_values_for_generate(self, *args, **kwargs):
|
|
||||||
return True
|
|
||||||
|
|
||||||
@unittest.skip(reason="Rely on `past_key_values` to crop the assistant pkv. Not supported")
|
@unittest.skip(reason="Rely on `past_key_values` to crop the assistant pkv. Not supported")
|
||||||
def test_assisted_decoding_matches_greedy_search(self):
|
def test_assisted_decoding_matches_greedy_search(self):
|
||||||
pass
|
pass
|
||||||
@@ -397,25 +391,6 @@ class RecurrentGemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
|
|||||||
def test_assisted_decoding_sample(self):
|
def test_assisted_decoding_sample(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _check_hidden_states_for_generate(
|
|
||||||
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
|
||||||
):
|
|
||||||
self.assertIsInstance(hidden_states, tuple)
|
|
||||||
self.assertListEqual(
|
|
||||||
[isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
|
|
||||||
[True] * len(hidden_states),
|
|
||||||
)
|
|
||||||
self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)
|
|
||||||
|
|
||||||
for idx, iter_hidden_states in enumerate(hidden_states):
|
|
||||||
seq_len = min_length + idx if not use_cache else 1
|
|
||||||
expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
|
|
||||||
# check hidden size
|
|
||||||
self.assertListEqual(
|
|
||||||
[layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
|
|
||||||
[expected_shape] * len(iter_hidden_states),
|
|
||||||
)
|
|
||||||
|
|
||||||
@unittest.skip(reason="TODO @arthurzucker not super important and failing.")
|
@unittest.skip(reason="TODO @arthurzucker not super important and failing.")
|
||||||
def test_initialization(self):
|
def test_initialization(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -620,36 +620,42 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, Mod
|
|||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
def _check_attentions_for_generate(
|
def _check_attentions_for_generate(
|
||||||
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
|
||||||
):
|
):
|
||||||
|
# NOTE (joao): this function is substantially different from the original, the attention has different
|
||||||
|
# *number* of shapes in certain conditions
|
||||||
self.assertIsInstance(attentions, tuple)
|
self.assertIsInstance(attentions, tuple)
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
[isinstance(iter_attentions, list) for iter_attentions in attentions], [True] * len(attentions)
|
[isinstance(iter_attentions, list) for iter_attentions in attentions], [True] * len(attentions)
|
||||||
)
|
)
|
||||||
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
|
self.assertEqual(len(attentions), (output_length - prompt_length))
|
||||||
|
|
||||||
for idx, iter_attentions in enumerate(attentions):
|
for generated_length, iter_attentions in enumerate(attentions):
|
||||||
tgt_len = min_length + idx if not use_cache else 1
|
use_cache = decoder_past_key_values is not None and generated_length > 0
|
||||||
num_chunks = tgt_len // config.local_attn_chunk_length + (tgt_len % config.local_attn_chunk_length != 0)
|
|
||||||
tgt_chunk_len = config.local_attn_chunk_length
|
model_input_length = prompt_length + generated_length if not use_cache else 1
|
||||||
src_chunk_len = config.local_attn_chunk_length * (
|
num_chunks = model_input_length // config.local_attn_chunk_length + (
|
||||||
|
model_input_length % config.local_attn_chunk_length != 0
|
||||||
|
)
|
||||||
|
model_input_chunk_len = config.local_attn_chunk_length
|
||||||
|
query_chunk_len = config.local_attn_chunk_length * (
|
||||||
1 + config.local_num_chunks_after + config.local_num_chunks_before
|
1 + config.local_num_chunks_after + config.local_num_chunks_before
|
||||||
)
|
)
|
||||||
|
|
||||||
if use_cache:
|
if use_cache:
|
||||||
expected_shape = (
|
expected_shape = (
|
||||||
batch_size * num_beam_groups,
|
batch_size,
|
||||||
config.num_attention_heads,
|
config.num_attention_heads,
|
||||||
tgt_len,
|
model_input_length,
|
||||||
min_length // config.local_attn_chunk_length + 1 + idx,
|
prompt_length // config.local_attn_chunk_length + generated_length,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
expected_shape = (
|
expected_shape = (
|
||||||
batch_size * num_beam_groups,
|
batch_size,
|
||||||
config.num_attention_heads,
|
config.num_attention_heads,
|
||||||
num_chunks,
|
num_chunks,
|
||||||
tgt_chunk_len,
|
model_input_chunk_len,
|
||||||
src_chunk_len,
|
query_chunk_len,
|
||||||
)
|
)
|
||||||
# check attn size
|
# check attn size
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
@@ -657,25 +663,29 @@ class ReformerLocalAttnModelTest(ReformerTesterMixin, GenerationTesterMixin, Mod
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _check_hidden_states_for_generate(
|
def _check_hidden_states_for_generate(
|
||||||
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
self, batch_size, hidden_states, prompt_length, output_length, config, use_cache=False
|
||||||
):
|
):
|
||||||
|
# NOTE (joao): this function is substantially different from the original, the hidden states have different
|
||||||
|
# length in certain conditions
|
||||||
self.assertIsInstance(hidden_states, tuple)
|
self.assertIsInstance(hidden_states, tuple)
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
[isinstance(iter_hidden_states, list) for iter_hidden_states in hidden_states],
|
[isinstance(iter_hidden_states, list) for iter_hidden_states in hidden_states],
|
||||||
[True] * len(hidden_states),
|
[True] * len(hidden_states),
|
||||||
)
|
)
|
||||||
self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)
|
self.assertEqual(len(hidden_states), (output_length - prompt_length))
|
||||||
|
|
||||||
for idx, iter_hidden_states in enumerate(hidden_states):
|
for generation_length, iter_hidden_states in enumerate(hidden_states):
|
||||||
seq_len = min_length + idx
|
use_cache_this_iter = use_cache and generation_length > 0
|
||||||
seq_len = config.local_attn_chunk_length * (
|
model_input_length = prompt_length + generation_length
|
||||||
seq_len // config.local_attn_chunk_length + (seq_len % config.local_attn_chunk_length != 0)
|
model_output_length = config.local_attn_chunk_length * (
|
||||||
|
model_input_length // config.local_attn_chunk_length
|
||||||
|
+ (model_input_length % config.local_attn_chunk_length != 0)
|
||||||
)
|
)
|
||||||
|
|
||||||
if use_cache:
|
if use_cache_this_iter:
|
||||||
seq_len = 1
|
model_output_length = 1
|
||||||
|
|
||||||
expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
|
expected_shape = (batch_size, model_output_length, config.hidden_size)
|
||||||
# check hidden size
|
# check hidden size
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
[layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
|
[layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
|
||||||
@@ -789,37 +799,42 @@ class ReformerLSHAttnModelTest(
|
|||||||
self.config_tester = ConfigTester(self, config_class=ReformerConfig, hidden_size=37)
|
self.config_tester = ConfigTester(self, config_class=ReformerConfig, hidden_size=37)
|
||||||
|
|
||||||
def _check_attentions_for_generate(
|
def _check_attentions_for_generate(
|
||||||
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
|
||||||
):
|
):
|
||||||
|
# NOTE (joao): this function is substantially different from the original, the attention has different
|
||||||
|
# *number* of shapes in certain conditions
|
||||||
self.assertIsInstance(attentions, tuple)
|
self.assertIsInstance(attentions, tuple)
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
[isinstance(iter_attentions, list) for iter_attentions in attentions], [True] * len(attentions)
|
[isinstance(iter_attentions, list) for iter_attentions in attentions], [True] * len(attentions)
|
||||||
)
|
)
|
||||||
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
|
self.assertEqual(len(attentions), (output_length - prompt_length))
|
||||||
|
|
||||||
for idx, iter_attentions in enumerate(attentions):
|
for generated_length, iter_attentions in enumerate(attentions):
|
||||||
tgt_len = min_length + idx if not use_cache else 1
|
use_cache = decoder_past_key_values is not None and generated_length > 0
|
||||||
num_chunks = tgt_len // config.lsh_attn_chunk_length + (tgt_len % config.lsh_attn_chunk_length != 0)
|
model_input_len = prompt_length + generated_length if not use_cache else 1
|
||||||
tgt_chunk_len = config.lsh_attn_chunk_length
|
num_chunks = model_input_len // config.lsh_attn_chunk_length + (
|
||||||
src_chunk_len = config.lsh_attn_chunk_length * (
|
model_input_len % config.lsh_attn_chunk_length != 0
|
||||||
|
)
|
||||||
|
model_input_chunk_len = config.lsh_attn_chunk_length
|
||||||
|
query_chunk_len = config.lsh_attn_chunk_length * (
|
||||||
1 + config.lsh_num_chunks_after + config.lsh_num_chunks_before
|
1 + config.lsh_num_chunks_after + config.lsh_num_chunks_before
|
||||||
)
|
)
|
||||||
|
|
||||||
if use_cache:
|
if use_cache:
|
||||||
expected_shape = (
|
expected_shape = (
|
||||||
batch_size * num_beam_groups,
|
batch_size,
|
||||||
config.num_attention_heads,
|
config.num_attention_heads,
|
||||||
config.num_hashes,
|
config.num_hashes,
|
||||||
tgt_len,
|
model_input_len,
|
||||||
config.num_hashes * (1 + config.lsh_num_chunks_after + config.lsh_num_chunks_before),
|
config.num_hashes * (1 + config.lsh_num_chunks_after + config.lsh_num_chunks_before),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
expected_shape = (
|
expected_shape = (
|
||||||
batch_size * num_beam_groups,
|
batch_size,
|
||||||
config.num_attention_heads,
|
config.num_attention_heads,
|
||||||
num_chunks * config.num_hashes,
|
num_chunks * config.num_hashes,
|
||||||
tgt_chunk_len,
|
model_input_chunk_len,
|
||||||
src_chunk_len,
|
query_chunk_len,
|
||||||
)
|
)
|
||||||
# check attn size
|
# check attn size
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
@@ -827,25 +842,29 @@ class ReformerLSHAttnModelTest(
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _check_hidden_states_for_generate(
|
def _check_hidden_states_for_generate(
|
||||||
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
self, batch_size, hidden_states, prompt_length, output_length, config, use_cache=False
|
||||||
):
|
):
|
||||||
|
# NOTE (joao): this function is substantially different from the original, the hidden states have different
|
||||||
|
# length in certain conditions
|
||||||
self.assertIsInstance(hidden_states, tuple)
|
self.assertIsInstance(hidden_states, tuple)
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
[isinstance(iter_hidden_states, list) for iter_hidden_states in hidden_states],
|
[isinstance(iter_hidden_states, list) for iter_hidden_states in hidden_states],
|
||||||
[True] * len(hidden_states),
|
[True] * len(hidden_states),
|
||||||
)
|
)
|
||||||
self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)
|
self.assertEqual(len(hidden_states), (output_length - prompt_length))
|
||||||
|
|
||||||
for idx, iter_hidden_states in enumerate(hidden_states):
|
for generation_length, iter_hidden_states in enumerate(hidden_states):
|
||||||
seq_len = min_length + idx if not use_cache else 1
|
use_cache_this_iter = use_cache and generation_length > 0
|
||||||
seq_len = config.lsh_attn_chunk_length * (
|
model_input_length = prompt_length + generation_length
|
||||||
seq_len // config.lsh_attn_chunk_length + (seq_len % config.lsh_attn_chunk_length != 0)
|
model_output_length = config.local_attn_chunk_length * (
|
||||||
|
model_input_length // config.local_attn_chunk_length
|
||||||
|
+ (model_input_length % config.local_attn_chunk_length != 0)
|
||||||
)
|
)
|
||||||
|
|
||||||
if use_cache:
|
if use_cache_this_iter:
|
||||||
seq_len = 1
|
model_output_length = 1
|
||||||
|
|
||||||
expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
|
expected_shape = (batch_size, model_output_length, config.hidden_size)
|
||||||
# check hidden size
|
# check hidden size
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
[layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
|
[layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
|
||||||
|
|||||||
@@ -416,48 +416,6 @@ class TFSpeech2TextModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.T
|
|||||||
def test_generate_without_input_ids(self):
|
def test_generate_without_input_ids(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
|
|
||||||
batch_size, seq_length = input_ids.shape[:2]
|
|
||||||
subsampled_seq_length = self.model_tester.get_subsampled_output_lengths(seq_length)
|
|
||||||
num_sequences_in_output = batch_size * num_return_sequences
|
|
||||||
gen_len = (
|
|
||||||
output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length
|
|
||||||
)
|
|
||||||
|
|
||||||
# scores
|
|
||||||
self._check_scores(num_sequences_in_output, output.scores, length=gen_len, config=config)
|
|
||||||
|
|
||||||
# Attentions
|
|
||||||
# encoder
|
|
||||||
self._check_encoder_attention_for_generate(
|
|
||||||
output.encoder_attentions, batch_size, config, subsampled_seq_length
|
|
||||||
)
|
|
||||||
# decoder
|
|
||||||
self._check_attentions_for_generate(
|
|
||||||
num_sequences_in_output,
|
|
||||||
output.decoder_attentions,
|
|
||||||
min_length=1,
|
|
||||||
max_length=output.sequences.shape[-1],
|
|
||||||
config=config,
|
|
||||||
use_cache=use_cache,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Hidden States
|
|
||||||
# encoder
|
|
||||||
self._check_encoder_hidden_states_for_generate(
|
|
||||||
output.encoder_hidden_states, batch_size, config, subsampled_seq_length
|
|
||||||
)
|
|
||||||
|
|
||||||
# decoder
|
|
||||||
self._check_hidden_states_for_generate(
|
|
||||||
num_sequences_in_output,
|
|
||||||
output.decoder_hidden_states,
|
|
||||||
min_length=1,
|
|
||||||
max_length=output.sequences.shape[-1],
|
|
||||||
config=config,
|
|
||||||
use_cache=use_cache,
|
|
||||||
)
|
|
||||||
|
|
||||||
# overwritten from parent due to the inability to work when non-text inputs are not passed AND because the input is
|
# overwritten from parent due to the inability to work when non-text inputs are not passed AND because the input is
|
||||||
# `input_features`
|
# `input_features`
|
||||||
def test_lm_head_model_random_no_beam_search_generate(self):
|
def test_lm_head_model_random_no_beam_search_generate(self):
|
||||||
|
|||||||
@@ -527,48 +527,6 @@ class TFWhisperModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
|||||||
def test_generate_without_input_ids(self):
|
def test_generate_without_input_ids(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _check_outputs(self, output, input_ids, config, use_cache=False, num_return_sequences=1):
|
|
||||||
batch_size, mel, seq_length = input_ids.shape
|
|
||||||
subsampled_seq_length = self.model_tester.get_subsampled_output_lengths(seq_length)
|
|
||||||
num_sequences_in_output = batch_size * num_return_sequences
|
|
||||||
gen_len = (
|
|
||||||
output.sequences.shape[-1] - 1 if config.is_encoder_decoder else output.sequences.shape[-1] - seq_length
|
|
||||||
)
|
|
||||||
|
|
||||||
# scores
|
|
||||||
self._check_scores(num_sequences_in_output, output.scores, length=gen_len, config=config)
|
|
||||||
|
|
||||||
# Attentions
|
|
||||||
# encoder
|
|
||||||
self._check_encoder_attention_for_generate(
|
|
||||||
output.encoder_attentions, batch_size, config, subsampled_seq_length
|
|
||||||
)
|
|
||||||
# decoder
|
|
||||||
self._check_attentions_for_generate(
|
|
||||||
num_sequences_in_output,
|
|
||||||
output.decoder_attentions,
|
|
||||||
min_length=1,
|
|
||||||
max_length=output.sequences.shape[-1],
|
|
||||||
config=config,
|
|
||||||
use_cache=use_cache,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Hidden States
|
|
||||||
# encoder
|
|
||||||
self._check_encoder_hidden_states_for_generate(
|
|
||||||
output.encoder_hidden_states, batch_size, config, subsampled_seq_length
|
|
||||||
)
|
|
||||||
|
|
||||||
# decoder
|
|
||||||
self._check_hidden_states_for_generate(
|
|
||||||
num_sequences_in_output,
|
|
||||||
output.decoder_hidden_states,
|
|
||||||
min_length=1,
|
|
||||||
max_length=output.sequences.shape[-1],
|
|
||||||
config=config,
|
|
||||||
use_cache=use_cache,
|
|
||||||
)
|
|
||||||
|
|
||||||
# overwritten from parent due to the inability to work when non-text inputs are not passed AND because the input is
|
# overwritten from parent due to the inability to work when non-text inputs are not passed AND because the input is
|
||||||
# `input_features`
|
# `input_features`
|
||||||
def test_lm_head_model_random_no_beam_search_generate(self):
|
def test_lm_head_model_random_no_beam_search_generate(self):
|
||||||
|
|||||||
@@ -1607,6 +1607,11 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
def test_generate_compile_model_forward(self):
|
def test_generate_compile_model_forward(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# TODO (joao, eustache): fix me :)
|
||||||
|
@unittest.skip(reason="A CUDA exception is thrown when storing extra outputs")
|
||||||
|
def test_generate_compilation_all_outputs(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_torchaudio
|
@require_torchaudio
|
||||||
|
|||||||
@@ -473,50 +473,24 @@ class XLMModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
|
|||||||
self.model_tester.create_and_check_xlm_for_multiple_choice(*config_and_inputs)
|
self.model_tester.create_and_check_xlm_for_multiple_choice(*config_and_inputs)
|
||||||
|
|
||||||
def _check_attentions_for_generate(
|
def _check_attentions_for_generate(
|
||||||
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
|
||||||
):
|
):
|
||||||
self.assertIsInstance(attentions, tuple)
|
# adds PAD dummy token, expected shape is off by 1
|
||||||
self.assertListEqual(
|
prompt_length += 1
|
||||||
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
|
output_length += 1
|
||||||
|
super()._check_attentions_for_generate(
|
||||||
|
batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
|
||||||
)
|
)
|
||||||
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
|
|
||||||
|
|
||||||
for idx, iter_attentions in enumerate(attentions):
|
|
||||||
# adds PAD dummy token
|
|
||||||
tgt_len = min_length + idx + 1
|
|
||||||
src_len = min_length + idx + 1
|
|
||||||
|
|
||||||
expected_shape = (
|
|
||||||
batch_size * num_beam_groups,
|
|
||||||
config.num_attention_heads,
|
|
||||||
tgt_len,
|
|
||||||
src_len,
|
|
||||||
)
|
|
||||||
# check attn size
|
|
||||||
self.assertListEqual(
|
|
||||||
[layer_attention.shape for layer_attention in iter_attentions], [expected_shape] * len(iter_attentions)
|
|
||||||
)
|
|
||||||
|
|
||||||
def _check_hidden_states_for_generate(
|
def _check_hidden_states_for_generate(
|
||||||
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
self, batch_size, hidden_states, prompt_length, output_length, config, use_cache=False
|
||||||
):
|
):
|
||||||
self.assertIsInstance(hidden_states, tuple)
|
# adds PAD dummy token, expected shape is off by 1
|
||||||
self.assertListEqual(
|
prompt_length += 1
|
||||||
[isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
|
output_length += 1
|
||||||
[True] * len(hidden_states),
|
super()._check_hidden_states_for_generate(
|
||||||
|
batch_size, hidden_states, prompt_length, output_length, config, use_cache
|
||||||
)
|
)
|
||||||
self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)
|
|
||||||
|
|
||||||
for idx, iter_hidden_states in enumerate(hidden_states):
|
|
||||||
# adds PAD dummy token
|
|
||||||
seq_len = min_length + idx + 1
|
|
||||||
expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
|
|
||||||
# check hidden size
|
|
||||||
self.assertListEqual(
|
|
||||||
[layer_hidden_states.shape for layer_hidden_states in iter_hidden_states],
|
|
||||||
[expected_shape] * len(iter_hidden_states),
|
|
||||||
)
|
|
||||||
pass
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
|
|||||||
@@ -636,57 +636,52 @@ class XLNetModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
weight.data.fill_(3)
|
weight.data.fill_(3)
|
||||||
|
|
||||||
def _check_hidden_states_for_generate(
|
def _check_hidden_states_for_generate(
|
||||||
self, batch_size, hidden_states, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
self, batch_size, hidden_states, prompt_length, output_length, config, use_cache=False
|
||||||
):
|
):
|
||||||
self.assertIsInstance(hidden_states, tuple)
|
self.assertIsInstance(hidden_states, tuple)
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
[isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
|
[isinstance(iter_hidden_states, tuple) for iter_hidden_states in hidden_states],
|
||||||
[True] * len(hidden_states),
|
[True] * len(hidden_states),
|
||||||
)
|
)
|
||||||
self.assertEqual(len(hidden_states), (max_length - min_length) * num_beam_groups)
|
self.assertEqual(len(hidden_states), (output_length - prompt_length))
|
||||||
|
|
||||||
for idx, iter_hidden_states in enumerate(hidden_states):
|
for generated_length, iter_hidden_states in enumerate(hidden_states):
|
||||||
# check hidden size
|
# check hidden size
|
||||||
for i, layer_hidden_states in enumerate(iter_hidden_states):
|
for i, layer_hidden_states in enumerate(iter_hidden_states):
|
||||||
# every 2nd tensor is from extra stream
|
# every 2nd tensor is from extra stream
|
||||||
if i % 2 != 0:
|
if i % 2 != 0:
|
||||||
seq_len = 1
|
model_output_length = 1
|
||||||
else:
|
else:
|
||||||
# for first item dummy PAD token is appended so need one more
|
# for first item dummy PAD token is appended so need one more
|
||||||
# else offset+dummy_token when using cache
|
# else offset+dummy_token when using cache
|
||||||
seq_len = (min_length + 1) if idx == 0 else 3
|
model_output_length = (prompt_length + 1) if generated_length == 0 else 3
|
||||||
|
|
||||||
expected_shape = (batch_size * num_beam_groups, seq_len, config.hidden_size)
|
expected_shape = (batch_size, model_output_length, config.hidden_size)
|
||||||
self.assertEqual(layer_hidden_states.shape, expected_shape)
|
self.assertEqual(layer_hidden_states.shape, expected_shape)
|
||||||
|
|
||||||
def _check_attentions_for_generate(
|
def _check_attentions_for_generate(
|
||||||
self, batch_size, attentions, min_length, max_length, config, use_cache=False, num_beam_groups=1
|
self, batch_size, attentions, prompt_length, output_length, config, decoder_past_key_values
|
||||||
):
|
):
|
||||||
self.assertIsInstance(attentions, tuple)
|
self.assertIsInstance(attentions, tuple)
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
|
[isinstance(iter_attentions, tuple) for iter_attentions in attentions], [True] * len(attentions)
|
||||||
)
|
)
|
||||||
self.assertEqual(len(attentions), (max_length - min_length) * num_beam_groups)
|
self.assertEqual(len(attentions), (output_length - prompt_length))
|
||||||
|
|
||||||
for idx, attentions_item in enumerate(attentions):
|
for generated_length, attentions_item in enumerate(attentions):
|
||||||
for iter_attentions in attentions_item:
|
for iter_attentions in attentions_item:
|
||||||
tgt_len = min_length
|
model_input_length = prompt_length
|
||||||
|
|
||||||
# for first item dummy PAD token is appended so need one more
|
# for first item dummy PAD token is appended so need one more
|
||||||
# every token after consists of offset+dummy_token length when using cache
|
# every token after consists of offset+dummy_token length when using cache
|
||||||
if idx == 0:
|
if generated_length == 0:
|
||||||
tgt_len += 1
|
model_input_length += 1
|
||||||
else:
|
else:
|
||||||
tgt_len = 3
|
model_input_length = 3
|
||||||
|
|
||||||
src_len = min_length + idx + 1
|
query_length = prompt_length + generated_length + 1
|
||||||
|
|
||||||
expected_shape = (
|
expected_shape = (batch_size, config.num_attention_heads, model_input_length, query_length)
|
||||||
batch_size * num_beam_groups,
|
|
||||||
config.num_attention_heads,
|
|
||||||
tgt_len,
|
|
||||||
src_len,
|
|
||||||
)
|
|
||||||
# check attn size
|
# check attn size
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
[layer_attention.shape for layer_attention in iter_attentions],
|
[layer_attention.shape for layer_attention in iter_attentions],
|
||||||
|
|||||||
Reference in New Issue
Block a user