[whisper] static kv cache (#31166)
* make work with cache abstraction * correct for static cache * hacks for compile * make fast * fix * fix pos ids * generate * fix sdpa * fix sdpa cache pos * fix fa2 * clean fa2 * integrate cache into generate * make style * copies * more copies * update eager * update sdpa * update fa2 * simplify * use cache pos * always compute cross-cache for debug * avoid recompiles Co-authored-by: Arthur Zucker <arthur@huggingface.co> * fix fix * fix fix fix * more fix * try encoder-decoder cache (too messy) * revert encoder-decoder cache * check cross-attn cache * use enc-dec dataclass * use richer enc-dec dataclass * clean-up * revert static cache changes * small fixes * revert to cpu flag * fix copies * add static slow test * past k/v docstring * more docstrings * cache_position docstrings * add to docs * add enc-dec cache to docs * make style * fix after rebase * fix beam * style * fix generation strategies * fix most decoder-only tests * style * skip test * more clean up * small docstrings * Apply suggestions from code review Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * add todo * only crop self-attn * check cache in mixin * style * fix re-compile after rebase * move `is_updated` logic to enc-dec wrapper * revert back * revert cache back * finalise design * fix * fix fix * style * Update src/transformers/cache_utils.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * deprecate * updates * final updates * style * style --------- Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
@@ -391,6 +391,12 @@ A [`Constraint`] can be used to force the generation to include specific tokens
|
|||||||
- get_seq_length
|
- get_seq_length
|
||||||
- reset
|
- reset
|
||||||
|
|
||||||
|
[[autodoc]] EncoderDecoderCache
|
||||||
|
- get_seq_length
|
||||||
|
- to_legacy_cache
|
||||||
|
- from_legacy_cache
|
||||||
|
- reset
|
||||||
|
- reorder_cache
|
||||||
|
|
||||||
## Watermark Utils
|
## Watermark Utils
|
||||||
|
|
||||||
|
|||||||
@@ -52,8 +52,6 @@ Here is a step-by-step guide to transcribing an audio sample using a pre-trained
|
|||||||
>>> # Select an audio file and read it:
|
>>> # Select an audio file and read it:
|
||||||
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||||
>>> audio_sample = ds[0]["audio"]
|
>>> audio_sample = ds[0]["audio"]
|
||||||
>>> waveform = audio_sample["array"]
|
|
||||||
>>> sampling_rate = audio_sample["sampling_rate"]
|
|
||||||
|
|
||||||
>>> # Load the Whisper model in Hugging Face format:
|
>>> # Load the Whisper model in Hugging Face format:
|
||||||
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
|
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
|
||||||
@@ -61,7 +59,7 @@ Here is a step-by-step guide to transcribing an audio sample using a pre-trained
|
|||||||
|
|
||||||
>>> # Use the model and processor to transcribe the audio:
|
>>> # Use the model and processor to transcribe the audio:
|
||||||
>>> input_features = processor(
|
>>> input_features = processor(
|
||||||
... waveform, sampling_rate=sampling_rate, return_tensors="pt"
|
... audio_sample["array"], sampling_rate=audio_sample["sampling_rate"], return_tensors="pt"
|
||||||
... ).input_features
|
... ).input_features
|
||||||
|
|
||||||
>>> # Generate token ids
|
>>> # Generate token ids
|
||||||
@@ -74,6 +72,49 @@ Here is a step-by-step guide to transcribing an audio sample using a pre-trained
|
|||||||
' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
|
' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Whisper is compatible with the following optimisations:
|
||||||
|
- [PyTorch Scaled Dot Product Attention (SDPA)](../perf_infer_gpu_one#pytorch-scaled-dot-product-attention): flash attention and memory-efficient attention kernels. Enabled by default for `torch>=2.1.1`.
|
||||||
|
- [Flash Attention 2](../perf_infer_gpu_one#flashattention-2): improved implementation of flash attention through better parallelism and work partitioning.
|
||||||
|
- [torch.compile](../llm_optims#static-kv-cache-and-torchcompile): JIT-compile the forward pass to dispatch to efficient fused kernels.
|
||||||
|
|
||||||
|
As an example, the following codesnippet enables SDPA and `torch.compile` for up to 5x faster inference:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from datasets import load_dataset
|
||||||
|
>>> from transformers import WhisperProcessor, WhisperForConditionalGeneration
|
||||||
|
|
||||||
|
>>> # Select an audio file and read it:
|
||||||
|
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||||
|
>>> audio_sample = ds[0]["audio"]
|
||||||
|
|
||||||
|
>>> # Load the Whisper model with SDPA attention
|
||||||
|
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
|
||||||
|
>>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", attn_implementation="sdpa")
|
||||||
|
|
||||||
|
>>> # Enable static cache and compile the forward pass
|
||||||
|
>>> model.generation_config.cache_implementation = "static"
|
||||||
|
>>> model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
|
||||||
|
|
||||||
|
>>> # Use the model and processor to transcribe the audio:
|
||||||
|
>>> input_features = processor(
|
||||||
|
... audio_sample["array"], sampling_rate=audio_sample["sampling_rate"], return_tensors="pt"
|
||||||
|
... ).input_features
|
||||||
|
|
||||||
|
>>> # Compile the forward pass
|
||||||
|
>>> _ = model.generate(input_features)
|
||||||
|
|
||||||
|
>>> # Generate token ids using compiled graph (fast!)
|
||||||
|
>>> predicted_ids = model.generate(input_features)
|
||||||
|
|
||||||
|
>>> # Decode token ids to text
|
||||||
|
>>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
|
||||||
|
|
||||||
|
>>> transcription[0]
|
||||||
|
' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
|
||||||
|
```
|
||||||
|
|
||||||
|
For more details on each optimisation, refer to the documentation linked above.
|
||||||
|
|
||||||
## Resources
|
## Resources
|
||||||
|
|
||||||
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with Whisper. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
|
A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with Whisper. If you're interested in submitting a resource to be included here, please feel free to open a Pull Request and we'll review it! The resource should ideally demonstrate something new instead of duplicating an existing resource.
|
||||||
|
|||||||
@@ -1212,6 +1212,7 @@ else:
|
|||||||
"Cache",
|
"Cache",
|
||||||
"CacheConfig",
|
"CacheConfig",
|
||||||
"DynamicCache",
|
"DynamicCache",
|
||||||
|
"EncoderDecoderCache",
|
||||||
"HQQQuantizedCache",
|
"HQQQuantizedCache",
|
||||||
"QuantizedCache",
|
"QuantizedCache",
|
||||||
"QuantizedCacheConfig",
|
"QuantizedCacheConfig",
|
||||||
@@ -5895,6 +5896,7 @@ if TYPE_CHECKING:
|
|||||||
Cache,
|
Cache,
|
||||||
CacheConfig,
|
CacheConfig,
|
||||||
DynamicCache,
|
DynamicCache,
|
||||||
|
EncoderDecoderCache,
|
||||||
HQQQuantizedCache,
|
HQQQuantizedCache,
|
||||||
QuantizedCache,
|
QuantizedCache,
|
||||||
QuantizedCacheConfig,
|
QuantizedCacheConfig,
|
||||||
|
|||||||
@@ -858,8 +858,12 @@ class StaticCache(Cache):
|
|||||||
k_out = self.key_cache[layer_idx]
|
k_out = self.key_cache[layer_idx]
|
||||||
v_out = self.value_cache[layer_idx]
|
v_out = self.value_cache[layer_idx]
|
||||||
|
|
||||||
k_out[:, :, cache_position] = key_states
|
if cache_position is None:
|
||||||
v_out[:, :, cache_position] = value_states
|
k_out.copy_(key_states)
|
||||||
|
v_out.copy_(value_states)
|
||||||
|
else:
|
||||||
|
k_out[:, :, cache_position] = key_states
|
||||||
|
v_out[:, :, cache_position] = value_states
|
||||||
|
|
||||||
return k_out, v_out
|
return k_out, v_out
|
||||||
|
|
||||||
@@ -971,6 +975,158 @@ class SlidingWindowCache(StaticCache):
|
|||||||
# no matter how long the sentence is
|
# no matter how long the sentence is
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.key_cache.zero_()
|
||||||
|
self.value_cache.zero_()
|
||||||
|
|
||||||
|
|
||||||
|
class EncoderDecoderCache(Cache):
|
||||||
|
"""
|
||||||
|
Base, abstract class for all encoder-decoder caches. Can be used to hold combinations of self-attention and
|
||||||
|
cross-attention caches.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, self_attention_cache: Cache, cross_attention_cache: Cache):
|
||||||
|
self.self_attention_cache = self_attention_cache
|
||||||
|
self.cross_attention_cache = cross_attention_cache
|
||||||
|
|
||||||
|
self.is_updated = {}
|
||||||
|
for layer_idx in range(len(cross_attention_cache.key_cache)):
|
||||||
|
self.is_updated[layer_idx] = bool(cross_attention_cache.get_seq_length(layer_idx) > 0)
|
||||||
|
|
||||||
|
def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
|
||||||
|
"""
|
||||||
|
Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
|
||||||
|
sequence length.
|
||||||
|
"""
|
||||||
|
if layer_idx < len(self):
|
||||||
|
return (
|
||||||
|
self.self_attention_cache.key_cache[layer_idx],
|
||||||
|
self.self_attention_cache.value_cache[layer_idx],
|
||||||
|
self.cross_attention_cache.key_cache[layer_idx],
|
||||||
|
self.cross_attention_cache.key_cache[layer_idx],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
"""
|
||||||
|
Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
|
||||||
|
to the number of layers in the model.
|
||||||
|
"""
|
||||||
|
return len(self.self_attention_cache)
|
||||||
|
|
||||||
|
def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
|
||||||
|
"""Converts the `EncoderDecoderCache` instance into its equivalent in the legacy cache format."""
|
||||||
|
legacy_cache = ()
|
||||||
|
if len(self.cross_attention_cache) > 0:
|
||||||
|
for self_attn, cross_attn in zip(
|
||||||
|
self.self_attention_cache.to_legacy_cache(), self.cross_attention_cache.to_legacy_cache()
|
||||||
|
):
|
||||||
|
legacy_cache += (self_attn + cross_attn,)
|
||||||
|
else:
|
||||||
|
legacy_cache = self.self_attention_cache.to_legacy_cache()
|
||||||
|
return legacy_cache
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_legacy_cache(
|
||||||
|
cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
|
||||||
|
) -> "EncoderDecoderCache":
|
||||||
|
"""Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`."""
|
||||||
|
cache = cls(self_attention_cache=DynamicCache(), cross_attention_cache=DynamicCache())
|
||||||
|
if past_key_values is not None:
|
||||||
|
for layer_idx in range(len(past_key_values)):
|
||||||
|
key_states, value_states = past_key_values[layer_idx][:2]
|
||||||
|
cache.self_attention_cache.update(key_states, value_states, layer_idx)
|
||||||
|
if len(past_key_values[layer_idx]) > 2:
|
||||||
|
key_states, value_states = past_key_values[layer_idx][2:]
|
||||||
|
cache.cross_attention_cache.update(key_states, value_states, layer_idx)
|
||||||
|
cache.is_updated[layer_idx] = True
|
||||||
|
return cache
|
||||||
|
|
||||||
|
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
|
||||||
|
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
|
||||||
|
if len(self.self_attention_cache.key_cache) <= layer_idx:
|
||||||
|
return 0
|
||||||
|
return (self.self_attention_cache.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
if hasattr(self.self_attention_cache, "reset"):
|
||||||
|
self.self_attention_cache.reset()
|
||||||
|
if hasattr(self.cross_attention_cache, "reset"):
|
||||||
|
self.cross_attention_cache.reset()
|
||||||
|
elif not hasattr(self.self_attention_cache, "reset") and not hasattr(self.cross_attention_cache, "reset"):
|
||||||
|
raise ValueError(
|
||||||
|
"Neither self nor cross-attention cache have valid `.reset()` methods. `.reset()` should "
|
||||||
|
"only be called on compatible cache classes, such as `StaticCache` or `SlidingWindowCache`. "
|
||||||
|
f"Got {self.self_attention_cache.__str__()} for the self attention cache and "
|
||||||
|
f"{self.cross_attention_cache.__str__()} for the cross attention cache."
|
||||||
|
)
|
||||||
|
for layer_idx in self.is_updated:
|
||||||
|
self.is_updated[layer_idx] = False
|
||||||
|
|
||||||
|
def reorder_cache(self, beam_idx: torch.LongTensor):
|
||||||
|
"""Reorders the cache for beam search, given the selected beam indices."""
|
||||||
|
self.self_attention_cache.reorder_cache(beam_idx)
|
||||||
|
self.cross_attention_cache.reorder_cache(beam_idx)
|
||||||
|
|
||||||
|
def check_dynamic_cache(self, method: str):
|
||||||
|
if not (
|
||||||
|
isinstance(self.self_attention_cache, DynamicCache)
|
||||||
|
and isinstance(self.cross_attention_cache, DynamicCache)
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
f"`{method}` is only defined for dynamic cache, got {self.self_attention_cache.__str__()} for the self "
|
||||||
|
f"attention cache and {self.cross_attention_cache.__str__()} for the cross attention cache."
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO(gante, sanchit-gandhi): move following functionality into `.generate`
|
||||||
|
def crop(self, maximum_length: int):
|
||||||
|
"""Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
|
||||||
|
negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search."""
|
||||||
|
self.check_dynamic_cache(self.crop.__name__)
|
||||||
|
self.self_attention_cache.crop(maximum_length)
|
||||||
|
|
||||||
|
def batch_split(self, full_batch_size: int, split_size: int) -> "List[EncoderDecoderCache]":
|
||||||
|
"""Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
|
||||||
|
`_split_model_inputs()` in `generation.utils`"""
|
||||||
|
self.check_dynamic_cache(self.batch_split.__name__)
|
||||||
|
self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size)
|
||||||
|
cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size)
|
||||||
|
|
||||||
|
out = []
|
||||||
|
for self_attn, cross_attn in zip(self_attention_cache, cross_attention_cache):
|
||||||
|
out.append(EncoderDecoderCache(self_attn, cross_attn))
|
||||||
|
return out
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_batch_splits(cls, splits: List["EncoderDecoderCache"]) -> "EncoderDecoderCache":
|
||||||
|
"""This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
|
||||||
|
`generation.utils`"""
|
||||||
|
self_attention_cache = DynamicCache()
|
||||||
|
cross_attention_cache = DynamicCache()
|
||||||
|
for idx in range(len(splits[0])):
|
||||||
|
layer_keys = torch.cat([current.self_attention_cache.key_cache[idx] for current in splits], dim=0)
|
||||||
|
layer_values = torch.cat([current.self_attention_cache.value_cache[idx] for current in splits], dim=0)
|
||||||
|
self_attention_cache.update(layer_keys, layer_values, idx)
|
||||||
|
|
||||||
|
layer_keys = torch.cat([current.cross_attention_cache.key_cache[idx] for current in splits], dim=0)
|
||||||
|
layer_values = torch.cat([current.cross_attention_cache.value_cache[idx] for current in splits], dim=0)
|
||||||
|
cross_attention_cache.update(layer_keys, layer_values, idx)
|
||||||
|
return cls(self_attention_cache, cross_attention_cache)
|
||||||
|
|
||||||
|
def batch_repeat_interleave(self, repeats: int):
|
||||||
|
"""Repeat the cache `repeats` times in the batch dimension. Used in contrastive search."""
|
||||||
|
self.check_dynamic_cache(self.batch_repeat_interleave.__name__)
|
||||||
|
self.self_attention_cache.batch_repeat_interleave(repeats)
|
||||||
|
self.cross_attention_cache.batch_repeat_interleave(repeats)
|
||||||
|
|
||||||
|
def batch_select_indices(self, indices: torch.Tensor):
|
||||||
|
"""Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
|
||||||
|
self.check_dynamic_cache(self.batch_select_indices.__name__)
|
||||||
|
self.self_attention_cache.batch_select_indices(indices)
|
||||||
|
self.cross_attention_cache.batch_select_indices(indices)
|
||||||
|
|
||||||
|
|
||||||
class HybridCache(Cache):
|
class HybridCache(Cache):
|
||||||
def __init__(self, config: PretrainedConfig, max_batch_size, max_cache_len, device="cpu", dtype=None) -> None:
|
def __init__(self, config: PretrainedConfig, max_batch_size, max_cache_len, device="cpu", dtype=None) -> None:
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ from torch import nn
|
|||||||
from ..cache_utils import (
|
from ..cache_utils import (
|
||||||
Cache,
|
Cache,
|
||||||
DynamicCache,
|
DynamicCache,
|
||||||
|
EncoderDecoderCache,
|
||||||
HQQQuantizedCache,
|
HQQQuantizedCache,
|
||||||
HybridCache,
|
HybridCache,
|
||||||
QuantizedCacheConfig,
|
QuantizedCacheConfig,
|
||||||
@@ -1409,7 +1410,7 @@ class GenerationMixin:
|
|||||||
model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device)
|
model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device)
|
||||||
return model_kwargs
|
return model_kwargs
|
||||||
|
|
||||||
def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_len: int) -> Cache:
|
def _get_cache(self, cache_implementation: str, max_batch_size: int, max_cache_len: int, model_kwargs) -> Cache:
|
||||||
"""
|
"""
|
||||||
Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized a
|
Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized a
|
||||||
new `generate` call requires a larger cache.
|
new `generate` call requires a larger cache.
|
||||||
@@ -1417,28 +1418,46 @@ class GenerationMixin:
|
|||||||
Returns the resulting cache object.
|
Returns the resulting cache object.
|
||||||
"""
|
"""
|
||||||
cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation]
|
cache_cls: Cache = NEED_SETUP_CACHE_CLASSES_MAPPING[cache_implementation]
|
||||||
|
requires_cross_attention_cache = (
|
||||||
|
self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
if hasattr(self, "_cache"):
|
||||||
|
cache_to_check = self._cache.self_attention_cache if requires_cross_attention_cache else self._cache
|
||||||
|
|
||||||
if cache_implementation == "sliding_window":
|
if cache_implementation == "sliding_window":
|
||||||
max_cache_len = min(self.config.sliding_window, max_cache_len)
|
max_cache_len = min(self.config.sliding_window, max_cache_len)
|
||||||
|
|
||||||
need_new_cache = (
|
need_new_cache = (
|
||||||
not hasattr(self, "_cache")
|
not hasattr(self, "_cache")
|
||||||
or (not isinstance(self._cache, cache_cls))
|
or (not isinstance(cache_to_check, cache_cls))
|
||||||
or self._cache.max_batch_size != max_batch_size
|
or cache_to_check.max_batch_size != max_batch_size
|
||||||
or self._cache.max_cache_len < max_cache_len
|
or cache_to_check.max_cache_len < max_cache_len
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if requires_cross_attention_cache and hasattr(self, "_cache"):
|
||||||
|
need_new_cache = (
|
||||||
|
need_new_cache
|
||||||
|
or self._cache.cross_attention_cache.max_cache_len != model_kwargs["encoder_outputs"][0].shape[1]
|
||||||
|
)
|
||||||
|
|
||||||
if need_new_cache:
|
if need_new_cache:
|
||||||
if hasattr(self.config, "_pre_quantization_dtype"):
|
if hasattr(self.config, "_pre_quantization_dtype"):
|
||||||
cache_dtype = self.config._pre_quantization_dtype
|
cache_dtype = self.config._pre_quantization_dtype
|
||||||
else:
|
else:
|
||||||
cache_dtype = self.dtype
|
cache_dtype = self.dtype
|
||||||
self._cache = cache_cls(
|
cache_kwargs = {
|
||||||
config=self.config,
|
"config": self.config,
|
||||||
max_batch_size=max_batch_size,
|
"max_batch_size": max_batch_size,
|
||||||
max_cache_len=max_cache_len,
|
"max_cache_len": max_cache_len,
|
||||||
device=self.device,
|
"device": self.device,
|
||||||
dtype=cache_dtype,
|
"dtype": cache_dtype,
|
||||||
)
|
}
|
||||||
|
self._cache = cache_cls(**cache_kwargs)
|
||||||
|
if requires_cross_attention_cache:
|
||||||
|
encoder_kwargs = cache_kwargs.copy()
|
||||||
|
encoder_kwargs["max_cache_len"] = model_kwargs["encoder_outputs"][0].shape[1]
|
||||||
|
self._cache = EncoderDecoderCache(self._cache, cache_cls(**encoder_kwargs))
|
||||||
else:
|
else:
|
||||||
self._cache.reset()
|
self._cache.reset()
|
||||||
return self._cache
|
return self._cache
|
||||||
@@ -1745,6 +1764,7 @@ class GenerationMixin:
|
|||||||
generation_config.cache_implementation,
|
generation_config.cache_implementation,
|
||||||
getattr(generation_config, "num_beams", 1) * batch_size,
|
getattr(generation_config, "num_beams", 1) * batch_size,
|
||||||
generation_config.max_length,
|
generation_config.max_length,
|
||||||
|
model_kwargs,
|
||||||
)
|
)
|
||||||
elif generation_config.cache_implementation == "quantized":
|
elif generation_config.cache_implementation == "quantized":
|
||||||
if not self._supports_quantized_cache:
|
if not self._supports_quantized_cache:
|
||||||
@@ -1776,11 +1796,22 @@ class GenerationMixin:
|
|||||||
# keeps copying the cache thus using much more memory
|
# keeps copying the cache thus using much more memory
|
||||||
elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache():
|
elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache():
|
||||||
past = model_kwargs.get("past_key_values", None)
|
past = model_kwargs.get("past_key_values", None)
|
||||||
|
requires_cross_attention_cache = (
|
||||||
|
self.config.is_encoder_decoder or model_kwargs.get("encoder_outputs") is not None
|
||||||
|
)
|
||||||
if past is None:
|
if past is None:
|
||||||
model_kwargs["past_key_values"] = DynamicCache()
|
model_kwargs["past_key_values"] = (
|
||||||
|
DynamicCache()
|
||||||
|
if not requires_cross_attention_cache
|
||||||
|
else EncoderDecoderCache(DynamicCache(), DynamicCache())
|
||||||
|
)
|
||||||
use_dynamic_cache_by_default = True
|
use_dynamic_cache_by_default = True
|
||||||
elif isinstance(past, tuple):
|
elif isinstance(past, tuple):
|
||||||
model_kwargs["past_key_values"] = DynamicCache.from_legacy_cache(past)
|
model_kwargs["past_key_values"] = (
|
||||||
|
DynamicCache.from_legacy_cache(past)
|
||||||
|
if not requires_cross_attention_cache
|
||||||
|
else EncoderDecoderCache.from_legacy_cache(past)
|
||||||
|
)
|
||||||
use_dynamic_cache_by_default = True
|
use_dynamic_cache_by_default = True
|
||||||
|
|
||||||
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
|
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
|
||||||
@@ -2064,7 +2095,7 @@ class GenerationMixin:
|
|||||||
# Convert to legacy cache if needed
|
# Convert to legacy cache if needed
|
||||||
if use_dynamic_cache_by_default and generation_config.return_legacy_cache:
|
if use_dynamic_cache_by_default and generation_config.return_legacy_cache:
|
||||||
if isinstance(result, ModelOutput) and hasattr(result, "past_key_values"):
|
if isinstance(result, ModelOutput) and hasattr(result, "past_key_values"):
|
||||||
if isinstance(result.past_key_values, DynamicCache):
|
if isinstance(result.past_key_values, (DynamicCache, EncoderDecoderCache)):
|
||||||
result.past_key_values = result.past_key_values.to_legacy_cache()
|
result.past_key_values = result.past_key_values.to_legacy_cache()
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -2234,7 +2265,7 @@ class GenerationMixin:
|
|||||||
# if the first step in the loop, encode all the prefix and obtain: (1) past_key_values;
|
# if the first step in the loop, encode all the prefix and obtain: (1) past_key_values;
|
||||||
# (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step
|
# (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step
|
||||||
if model_kwargs.get("past_key_values") is None or (
|
if model_kwargs.get("past_key_values") is None or (
|
||||||
isinstance(model_kwargs["past_key_values"], Cache)
|
isinstance(model_kwargs["past_key_values"], (Cache, EncoderDecoderCache))
|
||||||
and model_kwargs["past_key_values"].get_seq_length() == 0
|
and model_kwargs["past_key_values"].get_seq_length() == 0
|
||||||
):
|
):
|
||||||
# prepare inputs
|
# prepare inputs
|
||||||
@@ -2323,7 +2354,9 @@ class GenerationMixin:
|
|||||||
# Replicates the new past_key_values to match the `top_k` candidates
|
# Replicates the new past_key_values to match the `top_k` candidates
|
||||||
past = model_kwargs["past_key_values"]
|
past = model_kwargs["past_key_values"]
|
||||||
# If it is a static cache, modify it in-place layer after layer to save memory
|
# If it is a static cache, modify it in-place layer after layer to save memory
|
||||||
if isinstance(past, DynamicCache):
|
if isinstance(past, DynamicCache) or (
|
||||||
|
isinstance(past, EncoderDecoderCache) and isinstance(past.self_attention_cache, DynamicCache)
|
||||||
|
):
|
||||||
past.batch_repeat_interleave(top_k)
|
past.batch_repeat_interleave(top_k)
|
||||||
else:
|
else:
|
||||||
new_key_values = []
|
new_key_values = []
|
||||||
@@ -2350,7 +2383,10 @@ class GenerationMixin:
|
|||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
if isinstance(outputs["past_key_values"], DynamicCache):
|
if isinstance(outputs["past_key_values"], DynamicCache) or (
|
||||||
|
isinstance(outputs["past_key_values"], EncoderDecoderCache)
|
||||||
|
and isinstance(outputs["past_key_values"].self_attention_cache, DynamicCache)
|
||||||
|
):
|
||||||
# Remove past K-V from output since we don't need to stack later
|
# Remove past K-V from output since we don't need to stack later
|
||||||
outputs["past_key_values"] = None
|
outputs["past_key_values"] = None
|
||||||
# Remove last token from past K-V since we don't want to append it at this point
|
# Remove last token from past K-V since we don't want to append it at this point
|
||||||
@@ -2425,7 +2461,10 @@ class GenerationMixin:
|
|||||||
else:
|
else:
|
||||||
_, next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True)
|
_, next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True)
|
||||||
# Do it in-place layer per layer to save memory
|
# Do it in-place layer per layer to save memory
|
||||||
if isinstance(next_past_key_values, DynamicCache):
|
if isinstance(next_past_key_values, DynamicCache) or (
|
||||||
|
isinstance(next_past_key_values, EncoderDecoderCache)
|
||||||
|
and isinstance(next_past_key_values.self_attention_cache, DynamicCache)
|
||||||
|
):
|
||||||
next_past_key_values.batch_select_indices(augmented_idx)
|
next_past_key_values.batch_select_indices(augmented_idx)
|
||||||
else:
|
else:
|
||||||
new_key_values = []
|
new_key_values = []
|
||||||
@@ -2498,7 +2537,10 @@ class GenerationMixin:
|
|||||||
# Contrastive search works by forward looking at the next token, so we need to exclude it from
|
# Contrastive search works by forward looking at the next token, so we need to exclude it from
|
||||||
# `past_key_values` to be consistent with the other decoding methods
|
# `past_key_values` to be consistent with the other decoding methods
|
||||||
if model_kwargs.get("past_key_values") is not None:
|
if model_kwargs.get("past_key_values") is not None:
|
||||||
if isinstance(model_kwargs["past_key_values"], DynamicCache):
|
if isinstance(model_kwargs["past_key_values"], DynamicCache) or (
|
||||||
|
isinstance(model_kwargs["past_key_values"], EncoderDecoderCache)
|
||||||
|
and isinstance(model_kwargs["past_key_values"].self_attention_cache, DynamicCache)
|
||||||
|
):
|
||||||
model_kwargs["past_key_values"].crop(-1)
|
model_kwargs["past_key_values"].crop(-1)
|
||||||
else:
|
else:
|
||||||
past_key_values = []
|
past_key_values = []
|
||||||
@@ -2757,7 +2799,7 @@ class GenerationMixin:
|
|||||||
# Exception 2: models with different cache formats. These are limited to `DynamicCache` until their
|
# Exception 2: models with different cache formats. These are limited to `DynamicCache` until their
|
||||||
# cache format is standardized, to avoid adding complexity to the codebase.
|
# cache format is standardized, to avoid adding complexity to the codebase.
|
||||||
elif "bloom" in model_class or "gptbigcode" in model_class:
|
elif "bloom" in model_class or "gptbigcode" in model_class:
|
||||||
if not isinstance(past_key_values, DynamicCache):
|
if not isinstance(past_key_values, (DynamicCache, EncoderDecoderCache)):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Using an unsupported cache format with {model_class}. Currently, it only supports the "
|
f"Using an unsupported cache format with {model_class}. Currently, it only supports the "
|
||||||
"legacy tuple format or `DynamicCache`"
|
"legacy tuple format or `DynamicCache`"
|
||||||
@@ -3703,8 +3745,12 @@ class GenerationMixin:
|
|||||||
|
|
||||||
# This is needed if return_dict_in_generate is True
|
# This is needed if return_dict_in_generate is True
|
||||||
start_from_empty_dynamic_cache = False
|
start_from_empty_dynamic_cache = False
|
||||||
if isinstance(model_kwargs.get("past_key_values", None), DynamicCache):
|
past_key_values = model_kwargs.get("past_key_values", None)
|
||||||
if len(model_kwargs["past_key_values"]) == 0:
|
if isinstance(past_key_values, DynamicCache) or (
|
||||||
|
isinstance(past_key_values, EncoderDecoderCache)
|
||||||
|
and isinstance(past_key_values.self_attention_cache, DynamicCache)
|
||||||
|
):
|
||||||
|
if len(past_key_values) == 0:
|
||||||
start_from_empty_dynamic_cache = True
|
start_from_empty_dynamic_cache = True
|
||||||
|
|
||||||
this_peer_finished = False
|
this_peer_finished = False
|
||||||
@@ -4022,7 +4068,9 @@ def _split(data, full_batch_size: int, split_size: int = None):
|
|||||||
if isinstance(data, torch.Tensor):
|
if isinstance(data, torch.Tensor):
|
||||||
return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)]
|
return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)]
|
||||||
# New cache format
|
# New cache format
|
||||||
elif isinstance(data, DynamicCache):
|
elif isinstance(data, DynamicCache) or (
|
||||||
|
isinstance(data, EncoderDecoderCache) and isinstance(data.self_attention_cache, DynamicCache)
|
||||||
|
):
|
||||||
return data.batch_split(full_batch_size, split_size)
|
return data.batch_split(full_batch_size, split_size)
|
||||||
elif isinstance(data, tuple):
|
elif isinstance(data, tuple):
|
||||||
# If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
|
# If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
|
||||||
@@ -4130,6 +4178,8 @@ def stack_model_outputs(model_outputs: List[ModelOutput]) -> ModelOutput:
|
|||||||
# New cache format
|
# New cache format
|
||||||
elif isinstance(data[0], DynamicCache):
|
elif isinstance(data[0], DynamicCache):
|
||||||
return DynamicCache.from_batch_splits(data)
|
return DynamicCache.from_batch_splits(data)
|
||||||
|
elif isinstance(data[0], EncoderDecoderCache):
|
||||||
|
return EncoderDecoderCache.from_batch_splits(data)
|
||||||
elif isinstance(data[0], tuple):
|
elif isinstance(data[0], tuple):
|
||||||
# If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
|
# If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
|
||||||
if isinstance(data[0][0], tuple):
|
if isinstance(data[0][0], tuple):
|
||||||
|
|||||||
@@ -189,7 +189,11 @@ class WhisperConfig(PretrainedConfig):
|
|||||||
|
|
||||||
model_type = "whisper"
|
model_type = "whisper"
|
||||||
keys_to_ignore_at_inference = ["past_key_values"]
|
keys_to_ignore_at_inference = ["past_key_values"]
|
||||||
attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
|
attribute_map = {
|
||||||
|
"num_key_value_heads": "encoder_attention_heads",
|
||||||
|
"num_attention_heads": "encoder_attention_heads",
|
||||||
|
"hidden_size": "d_model",
|
||||||
|
}
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -25,7 +25,8 @@ from torch import nn
|
|||||||
from torch.nn import CrossEntropyLoss
|
from torch.nn import CrossEntropyLoss
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
|
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
|
||||||
|
from ...modeling_attn_mask_utils import AttentionMaskConverter
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutput,
|
BaseModelOutput,
|
||||||
BaseModelOutputWithPastAndCrossAttentions,
|
BaseModelOutputWithPastAndCrossAttentions,
|
||||||
@@ -244,6 +245,7 @@ class WhisperAttention(nn.Module):
|
|||||||
is_decoder: bool = False,
|
is_decoder: bool = False,
|
||||||
bias: bool = True,
|
bias: bool = True,
|
||||||
is_causal: bool = False,
|
is_causal: bool = False,
|
||||||
|
layer_idx: Optional[int] = None,
|
||||||
config: Optional[WhisperConfig] = None,
|
config: Optional[WhisperConfig] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -262,6 +264,14 @@ class WhisperAttention(nn.Module):
|
|||||||
self.is_decoder = is_decoder
|
self.is_decoder = is_decoder
|
||||||
self.is_causal = is_causal
|
self.is_causal = is_causal
|
||||||
|
|
||||||
|
if layer_idx is None and is_decoder:
|
||||||
|
logger.warning_once(
|
||||||
|
f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
|
||||||
|
"will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
|
||||||
|
"when creating this class."
|
||||||
|
)
|
||||||
|
self.layer_idx = layer_idx
|
||||||
|
|
||||||
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
|
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
|
||||||
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||||
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
|
||||||
@@ -271,84 +281,56 @@ class WhisperAttention(nn.Module):
|
|||||||
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
||||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
|
||||||
|
|
||||||
# Copied from transformers.models.bart.modeling_bart.BartAttention.forward with BART->whisper
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
key_value_states: Optional[torch.Tensor] = None,
|
key_value_states: Optional[torch.Tensor] = None,
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[EncoderDecoderCache] = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
layer_head_mask: Optional[torch.Tensor] = None,
|
layer_head_mask: Optional[torch.Tensor] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
"""Input shape: Batch x Time x Channel"""
|
"""Input shape: Batch x Time x Channel"""
|
||||||
|
|
||||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||||
# for the decoder
|
# for the decoder
|
||||||
is_cross_attention = key_value_states is not None
|
is_cross_attention = key_value_states is not None
|
||||||
|
|
||||||
bsz, tgt_len, _ = hidden_states.size()
|
bsz, tgt_len, _ = hidden_states.size()
|
||||||
|
|
||||||
# get query proj
|
# get query proj
|
||||||
query_states = self.q_proj(hidden_states) * self.scaling
|
query_states = self._shape(self.q_proj(hidden_states) * self.scaling, tgt_len, bsz)
|
||||||
# get key, value proj
|
|
||||||
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
if past_key_value is not None:
|
||||||
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
is_updated = past_key_value.is_updated.get(self.layer_idx)
|
||||||
# the provided `key_value_states` to support prefix tuning
|
if is_cross_attention:
|
||||||
if (
|
# after the first generated id, we can subsequently re-use all key/value_states from cache
|
||||||
is_cross_attention
|
past_key_value.is_updated[self.layer_idx] = True
|
||||||
and past_key_value is not None
|
past_key_value = past_key_value.cross_attention_cache
|
||||||
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
else:
|
||||||
):
|
past_key_value = past_key_value.self_attention_cache
|
||||||
|
|
||||||
|
# use key_value_states if cross attention
|
||||||
|
current_states = key_value_states if key_value_states is not None else hidden_states
|
||||||
|
if is_cross_attention and past_key_value and is_updated:
|
||||||
# reuse k,v, cross_attentions
|
# reuse k,v, cross_attentions
|
||||||
key_states = past_key_value[0]
|
key_states = past_key_value.key_cache[self.layer_idx]
|
||||||
value_states = past_key_value[1]
|
value_states = past_key_value.value_cache[self.layer_idx]
|
||||||
elif is_cross_attention:
|
|
||||||
# cross_attentions
|
|
||||||
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
|
|
||||||
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
|
|
||||||
elif past_key_value is not None:
|
|
||||||
# reuse k, v, self_attention
|
|
||||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
|
||||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
|
||||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
|
||||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
|
||||||
else:
|
else:
|
||||||
# self_attention
|
key_states = self._shape(self.k_proj(current_states), -1, bsz)
|
||||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
value_states = self._shape(self.v_proj(current_states), -1, bsz)
|
||||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
if past_key_value is not None:
|
||||||
|
# save all key/value_states to cache to be re-used for fast auto-regressive generation
|
||||||
if self.is_decoder:
|
cache_position = cache_position if not is_cross_attention else None
|
||||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
key_states, value_states = past_key_value.update(
|
||||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
|
||||||
# key/value_states (first "if" case)
|
|
||||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
|
||||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
|
||||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
|
||||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
|
||||||
past_key_value = (key_states, value_states)
|
|
||||||
|
|
||||||
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
|
||||||
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
|
|
||||||
key_states = key_states.reshape(*proj_shape)
|
|
||||||
value_states = value_states.reshape(*proj_shape)
|
|
||||||
|
|
||||||
src_len = key_states.size(1)
|
|
||||||
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
|
|
||||||
|
|
||||||
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
|
|
||||||
raise ValueError(
|
|
||||||
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
|
|
||||||
f" {attn_weights.size()}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if attention_mask is not None:
|
|
||||||
if attention_mask.size() != (bsz, 1, tgt_len, src_len):
|
|
||||||
raise ValueError(
|
|
||||||
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
|
|
||||||
)
|
)
|
||||||
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
|
|
||||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
|
||||||
|
|
||||||
|
if attention_mask is not None: # no matter the length, we just slice it
|
||||||
|
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||||
|
attn_weights = attn_weights + causal_mask
|
||||||
|
|
||||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||||
|
|
||||||
@@ -358,42 +340,27 @@ class WhisperAttention(nn.Module):
|
|||||||
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
|
f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
|
||||||
f" {layer_head_mask.size()}"
|
f" {layer_head_mask.size()}"
|
||||||
)
|
)
|
||||||
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights
|
||||||
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
|
||||||
|
|
||||||
if output_attentions:
|
|
||||||
# this operation is a bit awkward, but it's required to
|
|
||||||
# make sure that attn_weights keeps its gradient.
|
|
||||||
# In order to do so, attn_weights have to be reshaped
|
|
||||||
# twice and have to be reused in the following
|
|
||||||
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
|
||||||
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
|
|
||||||
else:
|
|
||||||
attn_weights_reshaped = None
|
|
||||||
|
|
||||||
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
||||||
|
attn_output = torch.matmul(attn_probs, value_states)
|
||||||
|
|
||||||
attn_output = torch.bmm(attn_probs, value_states)
|
if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
|
||||||
|
|
||||||
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
|
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
|
||||||
f" {attn_output.size()}"
|
f" {attn_output.size()}"
|
||||||
)
|
)
|
||||||
|
|
||||||
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
|
|
||||||
attn_output = attn_output.transpose(1, 2)
|
attn_output = attn_output.transpose(1, 2)
|
||||||
|
|
||||||
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
|
# Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
|
||||||
# partitioned across GPUs when using tensor-parallelism.
|
# partitioned across GPUs when using tensor-parallelism.
|
||||||
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
|
attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
|
||||||
|
|
||||||
attn_output = self.out_proj(attn_output)
|
attn_output = self.out_proj(attn_output)
|
||||||
|
|
||||||
return attn_output, attn_weights_reshaped, past_key_value
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->Whisper
|
|
||||||
class WhisperFlashAttention2(WhisperAttention):
|
class WhisperFlashAttention2(WhisperAttention):
|
||||||
"""
|
"""
|
||||||
Whisper flash attention module. This module inherits from `WhisperAttention` as the weights of the module stays
|
Whisper flash attention module. This module inherits from `WhisperAttention` as the weights of the module stays
|
||||||
@@ -410,18 +377,21 @@ class WhisperFlashAttention2(WhisperAttention):
|
|||||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||||
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
||||||
|
|
||||||
def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
|
|
||||||
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
key_value_states: Optional[torch.Tensor] = None,
|
key_value_states: Optional[torch.Tensor] = None,
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[EncoderDecoderCache] = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
layer_head_mask: Optional[torch.Tensor] = None,
|
layer_head_mask: Optional[torch.Tensor] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
if isinstance(past_key_value, StaticCache):
|
||||||
|
raise ValueError(
|
||||||
|
"The `static` cache implementation is not compatible with `attn_implementation='flash_attention_2'`. "
|
||||||
|
"Use `attn_implementation='sdpa'` in the meantime, and open an issue at https://github.com/huggingface/transformers"
|
||||||
|
)
|
||||||
# WhisperFlashAttention2 attention does not support output_attentions
|
# WhisperFlashAttention2 attention does not support output_attentions
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
raise ValueError("WhisperFlashAttention2 attention does not support output_attentions")
|
raise ValueError("WhisperFlashAttention2 attention does not support output_attentions")
|
||||||
@@ -429,51 +399,45 @@ class WhisperFlashAttention2(WhisperAttention):
|
|||||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||||
# for the decoder
|
# for the decoder
|
||||||
is_cross_attention = key_value_states is not None
|
is_cross_attention = key_value_states is not None
|
||||||
|
bsz, tgt_len, _ = hidden_states.size()
|
||||||
bsz, q_len, _ = hidden_states.size()
|
|
||||||
|
|
||||||
# get query proj
|
# get query proj
|
||||||
query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
|
query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz)
|
||||||
# get key, value proj
|
|
||||||
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
|
||||||
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
|
||||||
# the provided `key_value_states` to support prefix tuning
|
|
||||||
if (
|
|
||||||
is_cross_attention
|
|
||||||
and past_key_value is not None
|
|
||||||
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
|
||||||
):
|
|
||||||
# reuse k,v, cross_attentions
|
|
||||||
key_states = past_key_value[0].transpose(1, 2)
|
|
||||||
value_states = past_key_value[1].transpose(1, 2)
|
|
||||||
elif is_cross_attention:
|
|
||||||
# cross_attentions
|
|
||||||
key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
|
|
||||||
value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
|
|
||||||
elif past_key_value is not None:
|
|
||||||
# reuse k, v, self_attention
|
|
||||||
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
|
|
||||||
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
|
|
||||||
key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
|
|
||||||
value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
|
|
||||||
else:
|
|
||||||
# self_attention
|
|
||||||
key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
|
|
||||||
value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
|
|
||||||
|
|
||||||
if self.is_decoder:
|
|
||||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
|
||||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
|
||||||
# key/value_states (first "if" case)
|
|
||||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
|
||||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
|
||||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
|
||||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
|
||||||
past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
|
|
||||||
|
|
||||||
kv_seq_len = key_states.shape[-2]
|
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
kv_seq_len += past_key_value[0].shape[-2]
|
is_updated = past_key_value.is_updated.get(self.layer_idx)
|
||||||
|
if is_cross_attention:
|
||||||
|
# after the first generated id, we can subsequently re-use all key/value_states from cache
|
||||||
|
past_key_value.is_updated[self.layer_idx] = True
|
||||||
|
past_key_value = past_key_value.cross_attention_cache
|
||||||
|
else:
|
||||||
|
past_key_value = past_key_value.self_attention_cache
|
||||||
|
|
||||||
|
# use key_value_states if cross attention
|
||||||
|
current_states = key_value_states if key_value_states is not None else hidden_states
|
||||||
|
if is_cross_attention and past_key_value and is_updated:
|
||||||
|
# reuse k,v, cross_attentions
|
||||||
|
key_states = past_key_value.key_cache[self.layer_idx]
|
||||||
|
value_states = past_key_value.value_cache[self.layer_idx]
|
||||||
|
else:
|
||||||
|
key_states = self._shape(self.k_proj(current_states), -1, bsz)
|
||||||
|
value_states = self._shape(self.v_proj(current_states), -1, bsz)
|
||||||
|
if past_key_value is not None:
|
||||||
|
# save all key/value_states to cache to be re-used for fast auto-regressive generation
|
||||||
|
cache_position = cache_position if not is_cross_attention else None
|
||||||
|
key_states, value_states = past_key_value.update(
|
||||||
|
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]
|
||||||
|
# We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view.
|
||||||
|
query_states = query_states.transpose(1, 2)
|
||||||
|
key_states = key_states.transpose(1, 2)
|
||||||
|
value_states = value_states.transpose(1, 2)
|
||||||
|
|
||||||
|
causal_mask = attention_mask
|
||||||
|
if attention_mask is not None: # no matter the length, we just slice it
|
||||||
|
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||||
|
|
||||||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||||
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||||
@@ -502,10 +466,10 @@ class WhisperFlashAttention2(WhisperAttention):
|
|||||||
value_states = value_states.to(target_dtype)
|
value_states = value_states.to(target_dtype)
|
||||||
|
|
||||||
attn_output = self._flash_attention_forward(
|
attn_output = self._flash_attention_forward(
|
||||||
query_states, key_states, value_states, attention_mask, q_len, dropout=self.dropout
|
query_states, key_states, value_states, causal_mask, tgt_len, dropout=self.dropout
|
||||||
)
|
)
|
||||||
|
|
||||||
attn_output = attn_output.reshape(bsz, q_len, -1)
|
attn_output = attn_output.reshape(bsz, tgt_len, -1)
|
||||||
attn_output = self.out_proj(attn_output)
|
attn_output = self.out_proj(attn_output)
|
||||||
|
|
||||||
if not output_attentions:
|
if not output_attentions:
|
||||||
@@ -614,15 +578,15 @@ class WhisperFlashAttention2(WhisperAttention):
|
|||||||
|
|
||||||
|
|
||||||
class WhisperSdpaAttention(WhisperAttention):
|
class WhisperSdpaAttention(WhisperAttention):
|
||||||
# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with BART->whisper, Bart->Whisper
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
key_value_states: Optional[torch.Tensor] = None,
|
key_value_states: Optional[torch.Tensor] = None,
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[EncoderDecoderCache] = None,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
layer_head_mask: Optional[torch.Tensor] = None,
|
layer_head_mask: Optional[torch.Tensor] = None,
|
||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
"""Input shape: Batch x Time x Channel"""
|
"""Input shape: Batch x Time x Channel"""
|
||||||
if output_attentions or layer_head_mask is not None:
|
if output_attentions or layer_head_mask is not None:
|
||||||
@@ -638,59 +602,50 @@ class WhisperSdpaAttention(WhisperAttention):
|
|||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
layer_head_mask=layer_head_mask,
|
layer_head_mask=layer_head_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
|
cache_position=cache_position,
|
||||||
)
|
)
|
||||||
|
|
||||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||||
# for the decoder
|
# for the decoder
|
||||||
is_cross_attention = key_value_states is not None
|
is_cross_attention = key_value_states is not None
|
||||||
|
|
||||||
bsz, tgt_len, _ = hidden_states.size()
|
bsz, tgt_len, _ = hidden_states.size()
|
||||||
|
|
||||||
# get query proj
|
# get query proj
|
||||||
query_states = self.q_proj(hidden_states)
|
query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz)
|
||||||
# get key, value proj
|
|
||||||
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
|
if past_key_value is not None:
|
||||||
# is checking that the `sequence_length` of the `past_key_value` is the same as
|
is_updated = past_key_value.is_updated.get(self.layer_idx)
|
||||||
# the provided `key_value_states` to support prefix tuning
|
if is_cross_attention:
|
||||||
if (
|
# after the first generated id, we can subsequently re-use all key/value_states from cache
|
||||||
is_cross_attention
|
past_key_value.is_updated[self.layer_idx] = True
|
||||||
and past_key_value is not None
|
past_key_value = past_key_value.cross_attention_cache
|
||||||
and past_key_value[0].shape[2] == key_value_states.shape[1]
|
else:
|
||||||
):
|
past_key_value = past_key_value.self_attention_cache
|
||||||
|
|
||||||
|
# use key_value_states if cross attention
|
||||||
|
current_states = key_value_states if key_value_states is not None else hidden_states
|
||||||
|
if is_cross_attention and past_key_value and is_updated:
|
||||||
# reuse k,v, cross_attentions
|
# reuse k,v, cross_attentions
|
||||||
key_states = past_key_value[0]
|
key_states = past_key_value.key_cache[self.layer_idx]
|
||||||
value_states = past_key_value[1]
|
value_states = past_key_value.value_cache[self.layer_idx]
|
||||||
elif is_cross_attention:
|
|
||||||
# cross_attentions
|
|
||||||
key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
|
|
||||||
value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
|
|
||||||
elif past_key_value is not None:
|
|
||||||
# reuse k, v, self_attention
|
|
||||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
|
||||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
|
||||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
|
||||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
|
||||||
else:
|
else:
|
||||||
# self_attention
|
key_states = self._shape(self.k_proj(current_states), -1, bsz)
|
||||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
value_states = self._shape(self.v_proj(current_states), -1, bsz)
|
||||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
if past_key_value is not None:
|
||||||
|
# save all key/value_states to cache to be re-used for fast auto-regressive generation
|
||||||
|
cache_position = cache_position if not is_cross_attention else None
|
||||||
|
key_states, value_states = past_key_value.update(
|
||||||
|
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
|
||||||
|
)
|
||||||
|
|
||||||
if self.is_decoder:
|
causal_mask = attention_mask
|
||||||
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
if attention_mask is not None: # no matter the length, we just slice it
|
||||||
# Further calls to cross_attention layer can then reuse all cross-attention
|
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
|
||||||
# key/value_states (first "if" case)
|
|
||||||
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
|
||||||
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
|
||||||
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
|
||||||
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
|
||||||
past_key_value = (key_states, value_states)
|
|
||||||
|
|
||||||
query_states = self._shape(query_states, tgt_len, bsz)
|
|
||||||
|
|
||||||
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
# We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
|
||||||
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
|
||||||
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
|
# The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
|
||||||
is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
|
is_causal = True if self.is_causal and causal_mask is None and tgt_len > 1 else False
|
||||||
|
|
||||||
# NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
|
# NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
|
||||||
# but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
|
# but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
|
||||||
@@ -698,7 +653,7 @@ class WhisperSdpaAttention(WhisperAttention):
|
|||||||
query_states,
|
query_states,
|
||||||
key_states,
|
key_states,
|
||||||
value_states,
|
value_states,
|
||||||
attn_mask=attention_mask,
|
attn_mask=causal_mask,
|
||||||
dropout_p=self.dropout if self.training else 0.0,
|
dropout_p=self.dropout if self.training else 0.0,
|
||||||
is_causal=is_causal,
|
is_causal=is_causal,
|
||||||
)
|
)
|
||||||
@@ -798,9 +753,8 @@ class WhisperEncoderLayer(nn.Module):
|
|||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Whisper, MBART->WHISPER
|
|
||||||
class WhisperDecoderLayer(nn.Module):
|
class WhisperDecoderLayer(nn.Module):
|
||||||
def __init__(self, config: WhisperConfig):
|
def __init__(self, config: WhisperConfig, layer_idx: int = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.embed_dim = config.d_model
|
self.embed_dim = config.d_model
|
||||||
|
|
||||||
@@ -810,6 +764,7 @@ class WhisperDecoderLayer(nn.Module):
|
|||||||
dropout=config.attention_dropout,
|
dropout=config.attention_dropout,
|
||||||
is_decoder=True,
|
is_decoder=True,
|
||||||
is_causal=True,
|
is_causal=True,
|
||||||
|
layer_idx=layer_idx,
|
||||||
config=config,
|
config=config,
|
||||||
)
|
)
|
||||||
self.dropout = config.dropout
|
self.dropout = config.dropout
|
||||||
@@ -822,6 +777,7 @@ class WhisperDecoderLayer(nn.Module):
|
|||||||
config.decoder_attention_heads,
|
config.decoder_attention_heads,
|
||||||
dropout=config.attention_dropout,
|
dropout=config.attention_dropout,
|
||||||
is_decoder=True,
|
is_decoder=True,
|
||||||
|
layer_idx=layer_idx,
|
||||||
config=config,
|
config=config,
|
||||||
)
|
)
|
||||||
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
|
||||||
@@ -837,9 +793,10 @@ class WhisperDecoderLayer(nn.Module):
|
|||||||
encoder_attention_mask: Optional[torch.Tensor] = None,
|
encoder_attention_mask: Optional[torch.Tensor] = None,
|
||||||
layer_head_mask: Optional[torch.Tensor] = None,
|
layer_head_mask: Optional[torch.Tensor] = None,
|
||||||
cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
|
cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
|
||||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[EncoderDecoderCache] = None,
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
use_cache: Optional[bool] = True,
|
use_cache: Optional[bool] = True,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -863,41 +820,35 @@ class WhisperDecoderLayer(nn.Module):
|
|||||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||||
|
|
||||||
# Self Attention
|
# Self Attention
|
||||||
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
|
||||||
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
|
||||||
# add present self-attn cache to positions 1,2 of present_key_value tuple
|
|
||||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
past_key_value=self_attn_past_key_value,
|
past_key_value=past_key_value,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
layer_head_mask=layer_head_mask,
|
layer_head_mask=layer_head_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
|
cache_position=cache_position,
|
||||||
)
|
)
|
||||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||||
hidden_states = residual + hidden_states
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
# Cross-Attention Block
|
# Cross-Attention Block
|
||||||
cross_attn_present_key_value = None
|
|
||||||
cross_attn_weights = None
|
cross_attn_weights = None
|
||||||
if encoder_hidden_states is not None:
|
if encoder_hidden_states is not None:
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
hidden_states = self.encoder_attn_layer_norm(hidden_states)
|
||||||
|
|
||||||
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
|
|
||||||
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
|
||||||
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
|
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
key_value_states=encoder_hidden_states,
|
key_value_states=encoder_hidden_states,
|
||||||
attention_mask=encoder_attention_mask,
|
attention_mask=encoder_attention_mask,
|
||||||
layer_head_mask=cross_attn_layer_head_mask,
|
layer_head_mask=cross_attn_layer_head_mask,
|
||||||
past_key_value=cross_attn_past_key_value,
|
past_key_value=past_key_value,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||||
hidden_states = residual + hidden_states
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
# add cross-attn to positions 3,4 of present_key_value tuple
|
# add cross-attn to positions 1 of present_key_value tuple
|
||||||
present_key_value = present_key_value + cross_attn_present_key_value
|
present_key_value = (present_key_value, cross_attn_present_key_value)
|
||||||
|
|
||||||
# Fully Connected
|
# Fully Connected
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
@@ -927,6 +878,8 @@ class WhisperPreTrainedModel(PreTrainedModel):
|
|||||||
_no_split_modules = ["WhisperEncoderLayer", "WhisperDecoderLayer"]
|
_no_split_modules = ["WhisperEncoderLayer", "WhisperDecoderLayer"]
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
|
_supports_cache_class = True
|
||||||
|
_supports_static_cache = True
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
std = self.config.init_std
|
std = self.config.init_std
|
||||||
@@ -1024,14 +977,18 @@ WHISPER_INPUTS_DOCSTRING = r"""
|
|||||||
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
|
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
|
||||||
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
|
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
|
||||||
hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
||||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
past_key_values (`EncoderDecoderCache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
||||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
Pre-computed hidden-states that can be used to speed up auto-regressive (sequential) decoding. There are
|
||||||
|
four sets of pre-computed hidden-states: key and values states in the self-attention blocks (2) and
|
||||||
|
in the cross-attention blocks (2). The `past_key_values` are returned when `use_cache=True` is passed or
|
||||||
|
when `config.use_cache=True`
|
||||||
|
|
||||||
|
Two formats are allowed:
|
||||||
|
- An [`~cache_utils.EncoderDecoderCache`] instance;
|
||||||
|
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
|
||||||
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
`(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
||||||
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
||||||
|
|
||||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
|
||||||
blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
|
||||||
|
|
||||||
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
||||||
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
||||||
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
||||||
@@ -1051,6 +1008,9 @@ WHISPER_INPUTS_DOCSTRING = r"""
|
|||||||
more detail.
|
more detail.
|
||||||
return_dict (`bool`, *optional*):
|
return_dict (`bool`, *optional*):
|
||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
|
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||||
|
Indices depicting the position of the input sequence tokens in the sequence. It is used to update the cache
|
||||||
|
in the correct position and to infer the complete sequence length.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
WHISPER_ENCODER_INPUTS_DOCSTRING = r"""
|
WHISPER_ENCODER_INPUTS_DOCSTRING = r"""
|
||||||
@@ -1256,7 +1216,9 @@ class WhisperDecoder(WhisperPreTrainedModel):
|
|||||||
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
|
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model, self.padding_idx)
|
||||||
self.embed_positions = WhisperPositionalEmbedding(self.max_target_positions, config.d_model)
|
self.embed_positions = WhisperPositionalEmbedding(self.max_target_positions, config.d_model)
|
||||||
|
|
||||||
self.layers = nn.ModuleList([WhisperDecoderLayer(config) for _ in range(config.decoder_layers)])
|
self.layers = nn.ModuleList(
|
||||||
|
[WhisperDecoderLayer(config, layer_idx) for layer_idx in range(config.decoder_layers)]
|
||||||
|
)
|
||||||
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
||||||
self._use_sdpa = config._attn_implementation == "sdpa"
|
self._use_sdpa = config._attn_implementation == "sdpa"
|
||||||
|
|
||||||
@@ -1286,6 +1248,7 @@ class WhisperDecoder(WhisperPreTrainedModel):
|
|||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_dict=None,
|
return_dict=None,
|
||||||
|
cache_position=None,
|
||||||
):
|
):
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
@@ -1320,13 +1283,17 @@ class WhisperDecoder(WhisperPreTrainedModel):
|
|||||||
- 1 indicates the head is **not masked**,
|
- 1 indicates the head is **not masked**,
|
||||||
- 0 indicates the head is **masked**.
|
- 0 indicates the head is **masked**.
|
||||||
|
|
||||||
past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
|
past_key_values (`EncoderDecoderCache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
||||||
Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
Pre-computed hidden-states that can be used to speed up auto-regressive (sequential) decoding. There are
|
||||||
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of
|
four sets of pre-computed hidden-states: key and values states in the self-attention blocks (2) and
|
||||||
shape `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
in the cross-attention blocks (2). The `past_key_values` are returned when `use_cache=True` is passed or
|
||||||
|
when `config.use_cache=True`
|
||||||
|
|
||||||
Contains pre-computed hidden-states (key and values in the self-attention blocks and in the
|
Two formats are allowed:
|
||||||
cross-attention blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
|
- An [`~cache_utils.EncoderDecoderCache`] instance;
|
||||||
|
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
||||||
|
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
|
||||||
|
`(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
|
||||||
|
|
||||||
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
|
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those
|
||||||
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
|
that don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of
|
||||||
@@ -1344,6 +1311,9 @@ class WhisperDecoder(WhisperPreTrainedModel):
|
|||||||
for more detail.
|
for more detail.
|
||||||
return_dict (`bool`, *optional*):
|
return_dict (`bool`, *optional*):
|
||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
|
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||||
|
Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
|
||||||
|
cache in the correct position and to infer the complete sequence length.
|
||||||
"""
|
"""
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||||
output_hidden_states = (
|
output_hidden_states = (
|
||||||
@@ -1363,26 +1333,38 @@ class WhisperDecoder(WhisperPreTrainedModel):
|
|||||||
else:
|
else:
|
||||||
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
|
||||||
|
|
||||||
# past_key_values_length
|
|
||||||
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
|
||||||
|
|
||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
inputs_embeds = self.embed_tokens(input_ids)
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
if self._use_flash_attention_2:
|
return_legacy_cache = False
|
||||||
# 2d mask is passed through the layers
|
return_self_attention_cache = False
|
||||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
if use_cache or past_key_values is not None:
|
||||||
elif self._use_sdpa and head_mask is None and not output_attentions:
|
if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache):
|
||||||
# output_attentions=True & head_mask can not be supported when using SDPA.
|
return_self_attention_cache = True
|
||||||
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
|
||||||
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
elif not isinstance(past_key_values, EncoderDecoderCache):
|
||||||
)
|
return_legacy_cache = True
|
||||||
else:
|
logger.warning_once(
|
||||||
# 4d mask is passed through the layers
|
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. "
|
||||||
attention_mask = _prepare_4d_causal_attention_mask(
|
"You should pass an instance of `EncoderDecoderCache` instead, e.g. "
|
||||||
attention_mask, input_shape, inputs_embeds, past_key_values_length
|
"`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
|
||||||
|
)
|
||||||
|
past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
|
||||||
|
|
||||||
|
past_key_values_length = 0
|
||||||
|
if cache_position is not None:
|
||||||
|
past_key_values_length = cache_position[0]
|
||||||
|
elif past_key_values is not None:
|
||||||
|
past_key_values_length = past_key_values.get_seq_length()
|
||||||
|
|
||||||
|
if cache_position is None:
|
||||||
|
cache_position = torch.arange(
|
||||||
|
past_key_values_length, past_key_values_length + input_shape[1], device=inputs_embeds.device
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if position_ids is None:
|
||||||
|
position_ids = cache_position.unsqueeze(0)
|
||||||
|
|
||||||
# embed positions
|
# embed positions
|
||||||
if input_ids is not None:
|
if input_ids is not None:
|
||||||
positions = self.embed_positions(
|
positions = self.embed_positions(
|
||||||
@@ -1396,6 +1378,14 @@ class WhisperDecoder(WhisperPreTrainedModel):
|
|||||||
hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
|
hidden_states = inputs_embeds + positions.to(inputs_embeds.device)
|
||||||
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
|
||||||
|
|
||||||
|
causal_mask = self._update_causal_mask(
|
||||||
|
attention_mask,
|
||||||
|
inputs_embeds,
|
||||||
|
cache_position,
|
||||||
|
past_key_values.self_attention_cache if past_key_values is not None else None,
|
||||||
|
output_attentions,
|
||||||
|
)
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
if use_cache:
|
if use_cache:
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
@@ -1406,7 +1396,6 @@ class WhisperDecoder(WhisperPreTrainedModel):
|
|||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
all_self_attns = () if output_attentions else None
|
all_self_attns = () if output_attentions else None
|
||||||
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
||||||
next_decoder_cache = () if use_cache else None
|
|
||||||
|
|
||||||
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
|
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
|
||||||
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
|
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
|
||||||
@@ -1424,13 +1413,11 @@ class WhisperDecoder(WhisperPreTrainedModel):
|
|||||||
if dropout_probability < self.layerdrop:
|
if dropout_probability < self.layerdrop:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
layer_outputs = self._gradient_checkpointing_func(
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
decoder_layer.__call__,
|
decoder_layer.__call__,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask,
|
causal_mask,
|
||||||
encoder_hidden_states,
|
encoder_hidden_states,
|
||||||
None, # encoder attention mask
|
None, # encoder attention mask
|
||||||
head_mask[idx] if head_mask is not None else None,
|
head_mask[idx] if head_mask is not None else None,
|
||||||
@@ -1438,25 +1425,24 @@ class WhisperDecoder(WhisperPreTrainedModel):
|
|||||||
None, # past_key_value
|
None, # past_key_value
|
||||||
output_attentions,
|
output_attentions,
|
||||||
use_cache,
|
use_cache,
|
||||||
|
cache_position,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
layer_outputs = decoder_layer(
|
layer_outputs = decoder_layer(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
attention_mask=attention_mask,
|
attention_mask=causal_mask,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
|
||||||
cross_attn_layer_head_mask=(
|
cross_attn_layer_head_mask=(
|
||||||
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
|
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
|
||||||
),
|
),
|
||||||
past_key_value=past_key_value,
|
past_key_value=past_key_values if use_cache else None,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
|
cache_position=cache_position,
|
||||||
)
|
)
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
if use_cache:
|
|
||||||
next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
|
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_self_attns += (layer_outputs[1],)
|
all_self_attns += (layer_outputs[1],)
|
||||||
|
|
||||||
@@ -1468,7 +1454,11 @@ class WhisperDecoder(WhisperPreTrainedModel):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states += (hidden_states,)
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
next_cache = next_decoder_cache if use_cache else None
|
next_cache = past_key_values if use_cache else None
|
||||||
|
if return_self_attention_cache:
|
||||||
|
next_cache = past_key_values.self_attention_cache
|
||||||
|
if return_legacy_cache:
|
||||||
|
next_cache = past_key_values.to_legacy_cache()
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(
|
return tuple(
|
||||||
v
|
v
|
||||||
@@ -1483,6 +1473,87 @@ class WhisperDecoder(WhisperPreTrainedModel):
|
|||||||
cross_attentions=all_cross_attentions,
|
cross_attentions=all_cross_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
|
||||||
|
def _update_causal_mask(
|
||||||
|
self,
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
input_tensor: torch.Tensor,
|
||||||
|
cache_position: torch.Tensor,
|
||||||
|
past_key_values: Cache,
|
||||||
|
output_attentions: bool,
|
||||||
|
):
|
||||||
|
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
|
||||||
|
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
|
||||||
|
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
|
||||||
|
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
|
||||||
|
|
||||||
|
if self.config._attn_implementation == "flash_attention_2":
|
||||||
|
if attention_mask is not None and 0.0 in attention_mask:
|
||||||
|
return attention_mask
|
||||||
|
return None
|
||||||
|
|
||||||
|
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||||
|
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||||
|
# to infer the attention mask.
|
||||||
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
|
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||||
|
|
||||||
|
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||||
|
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
|
||||||
|
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||||
|
attention_mask,
|
||||||
|
inputs_embeds=input_tensor,
|
||||||
|
past_key_values_length=past_seen_tokens,
|
||||||
|
is_training=self.training,
|
||||||
|
):
|
||||||
|
return None
|
||||||
|
|
||||||
|
dtype, device = input_tensor.dtype, input_tensor.device
|
||||||
|
min_dtype = torch.finfo(dtype).min
|
||||||
|
sequence_length = input_tensor.shape[1]
|
||||||
|
if using_static_cache:
|
||||||
|
target_length = past_key_values.get_max_length()
|
||||||
|
else:
|
||||||
|
target_length = (
|
||||||
|
attention_mask.shape[-1]
|
||||||
|
if isinstance(attention_mask, torch.Tensor)
|
||||||
|
else past_seen_tokens + sequence_length + 1
|
||||||
|
)
|
||||||
|
|
||||||
|
if attention_mask is not None and attention_mask.dim() == 4:
|
||||||
|
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
|
||||||
|
if attention_mask.max() != 0:
|
||||||
|
raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
|
||||||
|
causal_mask = attention_mask
|
||||||
|
else:
|
||||||
|
causal_mask = torch.full(
|
||||||
|
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
if sequence_length != 1:
|
||||||
|
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||||
|
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||||
|
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
|
||||||
|
if attention_mask is not None:
|
||||||
|
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||||
|
mask_length = attention_mask.shape[-1]
|
||||||
|
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||||
|
padding_mask = padding_mask == 0
|
||||||
|
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||||
|
padding_mask, min_dtype
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
self.config._attn_implementation == "sdpa"
|
||||||
|
and attention_mask is not None
|
||||||
|
and attention_mask.device.type == "cuda"
|
||||||
|
and not output_attentions
|
||||||
|
):
|
||||||
|
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||||
|
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||||
|
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||||
|
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||||
|
|
||||||
|
return causal_mask
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"The bare Whisper Model outputting raw hidden-states without any specific head on top.",
|
"The bare Whisper Model outputting raw hidden-states without any specific head on top.",
|
||||||
@@ -1571,13 +1642,14 @@ class WhisperModel(WhisperPreTrainedModel):
|
|||||||
decoder_head_mask: Optional[torch.Tensor] = None,
|
decoder_head_mask: Optional[torch.Tensor] = None,
|
||||||
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
||||||
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
past_key_values: Optional[Union[EncoderDecoderCache, Tuple[torch.FloatTensor]]] = None,
|
||||||
decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
|
decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
|
||||||
decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
|
decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
|
||||||
use_cache: Optional[bool] = None,
|
use_cache: Optional[bool] = None,
|
||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
|
) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
|
||||||
r"""
|
r"""
|
||||||
Returns:
|
Returns:
|
||||||
@@ -1637,6 +1709,7 @@ class WhisperModel(WhisperPreTrainedModel):
|
|||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
@@ -1704,7 +1777,7 @@ class WhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTrainedM
|
|||||||
decoder_head_mask: Optional[torch.Tensor] = None,
|
decoder_head_mask: Optional[torch.Tensor] = None,
|
||||||
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
||||||
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
past_key_values: Optional[Union[EncoderDecoderCache, Tuple[torch.FloatTensor]]] = None,
|
||||||
decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
|
decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None,
|
||||||
decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
|
decoder_position_ids: Optional[Tuple[torch.LongTensor]] = None,
|
||||||
labels: Optional[torch.LongTensor] = None,
|
labels: Optional[torch.LongTensor] = None,
|
||||||
@@ -1712,6 +1785,7 @@ class WhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTrainedM
|
|||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
|
) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
|
||||||
r"""
|
r"""
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
@@ -1766,6 +1840,7 @@ class WhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTrainedM
|
|||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
)
|
)
|
||||||
lm_logits = self.proj_out(outputs[0])
|
lm_logits = self.proj_out(outputs[0])
|
||||||
|
|
||||||
@@ -1800,14 +1875,19 @@ class WhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTrainedM
|
|||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
|
cache_position=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
decoder_position_ids = None
|
decoder_position_ids = None
|
||||||
if decoder_attention_mask is not None:
|
if decoder_attention_mask is not None:
|
||||||
decoder_position_ids = (decoder_attention_mask.cumsum(-1) - 1).clamp(min=0)
|
decoder_position_ids = (decoder_attention_mask.cumsum(-1) - 1).clamp(min=0)
|
||||||
|
|
||||||
|
past_length = 0
|
||||||
if past_key_values is not None:
|
if past_key_values is not None:
|
||||||
past_length = past_key_values[0][0].shape[2]
|
if isinstance(past_key_values, EncoderDecoderCache):
|
||||||
|
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
|
||||||
|
else:
|
||||||
|
past_length = past_key_values[0][0].shape[2]
|
||||||
|
|
||||||
# Some generation methods already pass only the last input ID
|
# Some generation methods already pass only the last input ID
|
||||||
if decoder_input_ids.shape[1] > past_length:
|
if decoder_input_ids.shape[1] > past_length:
|
||||||
@@ -1821,6 +1901,13 @@ class WhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTrainedM
|
|||||||
if decoder_position_ids is not None and decoder_position_ids.shape[1] > decoder_input_ids.shape[1]:
|
if decoder_position_ids is not None and decoder_position_ids.shape[1] > decoder_input_ids.shape[1]:
|
||||||
decoder_position_ids = decoder_position_ids[:, remove_prefix_length:]
|
decoder_position_ids = decoder_position_ids[:, remove_prefix_length:]
|
||||||
|
|
||||||
|
if cache_position is None:
|
||||||
|
cache_position = torch.arange(
|
||||||
|
past_length, past_length + decoder_input_ids.shape[1], device=decoder_input_ids.device
|
||||||
|
)
|
||||||
|
elif use_cache:
|
||||||
|
cache_position = cache_position[-decoder_input_ids.shape[1] :]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"encoder_outputs": encoder_outputs,
|
"encoder_outputs": encoder_outputs,
|
||||||
"past_key_values": past_key_values,
|
"past_key_values": past_key_values,
|
||||||
@@ -1828,6 +1915,7 @@ class WhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTrainedM
|
|||||||
"use_cache": use_cache,
|
"use_cache": use_cache,
|
||||||
"decoder_attention_mask": decoder_attention_mask,
|
"decoder_attention_mask": decoder_attention_mask,
|
||||||
"decoder_position_ids": decoder_position_ids,
|
"decoder_position_ids": decoder_position_ids,
|
||||||
|
"cache_position": cache_position,
|
||||||
}
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -1914,6 +2002,7 @@ class WhisperForCausalLM(WhisperPreTrainedModel):
|
|||||||
output_attentions: Optional[bool] = None,
|
output_attentions: Optional[bool] = None,
|
||||||
output_hidden_states: Optional[bool] = None,
|
output_hidden_states: Optional[bool] = None,
|
||||||
return_dict: Optional[bool] = None,
|
return_dict: Optional[bool] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
|
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
@@ -1968,6 +2057,9 @@ class WhisperForCausalLM(WhisperPreTrainedModel):
|
|||||||
for more detail.
|
for more detail.
|
||||||
return_dict (`bool`, *optional*):
|
return_dict (`bool`, *optional*):
|
||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
|
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||||
|
Indices depicting the position of the input sequence tokens in the sequence. It is used to update the cache
|
||||||
|
in the correct position and to infer the complete sequence length.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|
||||||
@@ -2019,6 +2111,7 @@ class WhisperForCausalLM(WhisperPreTrainedModel):
|
|||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
|
cache_position=cache_position,
|
||||||
)
|
)
|
||||||
|
|
||||||
logits = self.proj_out(outputs[0])
|
logits = self.proj_out(outputs[0])
|
||||||
@@ -2049,10 +2142,15 @@ class WhisperForCausalLM(WhisperPreTrainedModel):
|
|||||||
use_cache=None,
|
use_cache=None,
|
||||||
encoder_outputs=None,
|
encoder_outputs=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
|
cache_position=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
past_length = 0
|
||||||
if past_key_values is not None:
|
if past_key_values is not None:
|
||||||
past_length = past_key_values[0][0].shape[2]
|
if isinstance(past_key_values, (Cache, EncoderDecoderCache)):
|
||||||
|
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
|
||||||
|
else:
|
||||||
|
past_length = past_key_values[0][0].shape[2]
|
||||||
|
|
||||||
# Some generation methods already pass only the last input ID
|
# Some generation methods already pass only the last input ID
|
||||||
if input_ids.shape[1] > past_length:
|
if input_ids.shape[1] > past_length:
|
||||||
@@ -2063,12 +2161,18 @@ class WhisperForCausalLM(WhisperPreTrainedModel):
|
|||||||
|
|
||||||
input_ids = input_ids[:, remove_prefix_length:]
|
input_ids = input_ids[:, remove_prefix_length:]
|
||||||
|
|
||||||
|
if cache_position is None:
|
||||||
|
cache_position = torch.arange(past_length, past_length + input_ids.shape[1], device=input_ids.device)
|
||||||
|
elif use_cache:
|
||||||
|
cache_position = cache_position[-input_ids.shape[1] :]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"encoder_outputs": encoder_outputs,
|
"encoder_outputs": encoder_outputs,
|
||||||
"past_key_values": past_key_values,
|
"past_key_values": past_key_values,
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"use_cache": use_cache,
|
"use_cache": use_cache,
|
||||||
"attention_mask": attention_mask,
|
"attention_mask": attention_mask,
|
||||||
|
"cache_position": cache_position,
|
||||||
}
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -37,6 +37,13 @@ class DynamicCache(metaclass=DummyObject):
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class EncoderDecoderCache(metaclass=DummyObject):
|
||||||
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class HQQQuantizedCache(metaclass=DummyObject):
|
class HQQQuantizedCache(metaclass=DummyObject):
|
||||||
_backends = ["torch"]
|
_backends = ["torch"]
|
||||||
|
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ if is_torch_available():
|
|||||||
ImageGPTForCausalImageModeling,
|
ImageGPTForCausalImageModeling,
|
||||||
SpeechEncoderDecoderModel,
|
SpeechEncoderDecoderModel,
|
||||||
)
|
)
|
||||||
from transformers.cache_utils import DynamicCache, QuantoQuantizedCache
|
from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache
|
||||||
from transformers.generation import (
|
from transformers.generation import (
|
||||||
BeamSampleDecoderOnlyOutput,
|
BeamSampleDecoderOnlyOutput,
|
||||||
BeamSampleEncoderDecoderOutput,
|
BeamSampleEncoderDecoderOutput,
|
||||||
@@ -1636,7 +1636,6 @@ class GenerationTesterMixin:
|
|||||||
|
|
||||||
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
config, input_ids, attention_mask = self._get_input_ids_and_config()
|
||||||
config.use_cache = True
|
config.use_cache = True
|
||||||
config.is_decoder = True
|
|
||||||
|
|
||||||
model = model_class(config).to(torch_device).eval()
|
model = model_class(config).to(torch_device).eval()
|
||||||
generation_kwargs = {
|
generation_kwargs = {
|
||||||
@@ -1652,15 +1651,21 @@ class GenerationTesterMixin:
|
|||||||
set_seed(seed)
|
set_seed(seed)
|
||||||
legacy_results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
legacy_results = model.generate(input_ids, attention_mask=attention_mask, **generation_kwargs)
|
||||||
set_seed(seed)
|
set_seed(seed)
|
||||||
|
if config.is_encoder_decoder:
|
||||||
|
cache_cls = EncoderDecoderCache
|
||||||
|
past_key_values = cache_cls(DynamicCache(), DynamicCache())
|
||||||
|
else:
|
||||||
|
cache_cls = DynamicCache
|
||||||
|
past_key_values = cache_cls()
|
||||||
new_results = model.generate(
|
new_results = model.generate(
|
||||||
input_ids, attention_mask=attention_mask, past_key_values=DynamicCache(), **generation_kwargs
|
input_ids, attention_mask=attention_mask, past_key_values=past_key_values, **generation_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
# The two sets of generated sequences must match, despite the cache format between forward passes being
|
# The two sets of generated sequences must match, despite the cache format between forward passes being
|
||||||
# different
|
# different
|
||||||
self.assertListEqual(legacy_results.sequences.tolist(), new_results.sequences.tolist())
|
self.assertListEqual(legacy_results.sequences.tolist(), new_results.sequences.tolist())
|
||||||
self.assertTrue(isinstance(legacy_results.past_key_values, tuple))
|
self.assertTrue(isinstance(legacy_results.past_key_values, tuple))
|
||||||
self.assertTrue(isinstance(new_results.past_key_values, DynamicCache))
|
self.assertTrue(isinstance(new_results.past_key_values, cache_cls))
|
||||||
|
|
||||||
# The contents of the two caches, when converted to the same format (in both directions!), must match
|
# The contents of the two caches, when converted to the same format (in both directions!), must match
|
||||||
legacy_cache = legacy_results.past_key_values
|
legacy_cache = legacy_results.past_key_values
|
||||||
@@ -1675,7 +1680,7 @@ class GenerationTesterMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
new_cache = new_results.past_key_values
|
new_cache = new_results.past_key_values
|
||||||
legacy_cache_converted = DynamicCache.from_legacy_cache(legacy_results.past_key_values)
|
legacy_cache_converted = cache_cls.from_legacy_cache(legacy_results.past_key_values)
|
||||||
for layer_idx in range(len(new_cache)):
|
for layer_idx in range(len(new_cache)):
|
||||||
for kv_idx in range(len(new_cache[layer_idx])):
|
for kv_idx in range(len(new_cache[layer_idx])):
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
|
|||||||
@@ -1539,6 +1539,46 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi
|
|||||||
def test_longform_generate_multi_batch_cond_prev(self):
|
def test_longform_generate_multi_batch_cond_prev(self):
|
||||||
self._check_longform_generate_multi_batch(condition_on_prev_tokens=True)
|
self._check_longform_generate_multi_batch(condition_on_prev_tokens=True)
|
||||||
|
|
||||||
|
def test_custom_4d_attention_mask(self):
|
||||||
|
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
model = WhisperForConditionalGeneration(config).to(device=torch_device, dtype=torch.float32)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
(
|
||||||
|
input_ids,
|
||||||
|
position_ids,
|
||||||
|
input_ids_shared_prefix,
|
||||||
|
mask_shared_prefix,
|
||||||
|
position_ids_shared_prefix,
|
||||||
|
) = self._get_custom_4d_mask_test_data()
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
logits = model.forward(
|
||||||
|
decoder_input_ids=input_ids,
|
||||||
|
input_features=input_dict["input_features"],
|
||||||
|
decoder_position_ids=position_ids,
|
||||||
|
).logits
|
||||||
|
# logits.shape == torch.Size([3, 4, ...])
|
||||||
|
|
||||||
|
logits_shared_prefix = model(
|
||||||
|
decoder_input_ids=input_ids_shared_prefix,
|
||||||
|
input_features=input_dict["input_features"],
|
||||||
|
decoder_attention_mask=mask_shared_prefix,
|
||||||
|
decoder_position_ids=position_ids_shared_prefix,
|
||||||
|
)[0]
|
||||||
|
# logits_shared_prefix.shape == torch.Size([1, 6, ...])
|
||||||
|
|
||||||
|
out_last_tokens = logits[:, -1, :] # last tokens in each batch line
|
||||||
|
out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :] # last three tokens
|
||||||
|
|
||||||
|
# comparing greedily-chosen tokens:
|
||||||
|
assert torch.equal(out_last_tokens.max(axis=1).indices, out_shared_prefix_last_tokens.max(axis=1).indices)
|
||||||
|
|
||||||
|
# comparing softmax-normalized logits:
|
||||||
|
normalized_0 = torch.nn.functional.softmax(out_last_tokens)
|
||||||
|
normalized_1 = torch.nn.functional.softmax(out_shared_prefix_last_tokens)
|
||||||
|
torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_torchaudio
|
@require_torchaudio
|
||||||
@@ -2961,6 +3001,34 @@ class WhisperModelIntegrationTests(unittest.TestCase):
|
|||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
model.generate(**inputs, **gen_kwargs)
|
model.generate(**inputs, **gen_kwargs)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_tiny_static_generation(self):
|
||||||
|
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
|
||||||
|
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
|
||||||
|
model.to(torch_device)
|
||||||
|
|
||||||
|
input_speech = self._load_datasamples(4)
|
||||||
|
input_features = processor(input_speech, return_tensors="pt", sampling_rate=16_000).input_features
|
||||||
|
input_features = input_features.to(torch_device)
|
||||||
|
eager_generated_ids = model.generate(input_features, max_new_tokens=64)
|
||||||
|
|
||||||
|
model.generation_config.cache_implementation = "static"
|
||||||
|
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
|
||||||
|
|
||||||
|
# compile the forward pass and assert equivalence
|
||||||
|
static_generated_ids = model.generate(input_features, max_new_tokens=64)
|
||||||
|
assert (eager_generated_ids == static_generated_ids).all()
|
||||||
|
|
||||||
|
# check the compiled graph can be re-used and that the cache is correctly reset
|
||||||
|
# reverse the ordering of the input features
|
||||||
|
permutation_idx = (
|
||||||
|
torch.arange(input_features.shape[0], 0, step=-1, dtype=torch.long, device=input_features.device) - 1
|
||||||
|
)
|
||||||
|
input_features = input_features[permutation_idx, ...]
|
||||||
|
static_generated_ids = model.generate(input_features, max_new_tokens=64)
|
||||||
|
# assert re-ordered generations match those from eager
|
||||||
|
assert (eager_generated_ids[permutation_idx, :] == static_generated_ids).all()
|
||||||
|
|
||||||
|
|
||||||
def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None):
|
def prepare_whisper_encoder_inputs_dict(config, input_features, head_mask=None):
|
||||||
if head_mask is None:
|
if head_mask is None:
|
||||||
@@ -3564,6 +3632,10 @@ class WhisperStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin,
|
|||||||
config=config, input_ids=inputs_dict["input_ids"]
|
config=config, input_ids=inputs_dict["input_ids"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@unittest.skip(reason="Tested implicitly through the encoder-decoder tests")
|
||||||
|
def test_custom_4d_attention_mask(self):
|
||||||
|
pass
|
||||||
|
|
||||||
@unittest.skip(reason="Generate needs input ids")
|
@unittest.skip(reason="Generate needs input ids")
|
||||||
def test_generate_without_input_ids(self):
|
def test_generate_without_input_ids(self):
|
||||||
# generate only works with input ids for whisper
|
# generate only works with input ids for whisper
|
||||||
|
|||||||
Reference in New Issue
Block a user