From 9d37c56bab8f7f1f1aa0b65be039516072254e77 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 17 Jul 2020 16:17:42 +0200 Subject: [PATCH] [Reformer] - Cache hidden states and buckets to speed up inference (#5578) * fix merge rebase * add intermediate reformer code * save intermediate caching results * save intermediate * save intermediate results * save intermediate * upload next step * fix generate tests * make tests work * add named tuple output * Apply suggestions from code review * fix use_cache for False case * fix tensor to gpu * fix tensor to gpu * refactor * refactor and make style --- src/transformers/modeling_reformer.py | 674 ++++++++++++++++++++++---- src/transformers/modeling_xlnet.py | 6 +- tests/test_modeling_reformer.py | 105 +++- 3 files changed, 685 insertions(+), 100 deletions(-) diff --git a/src/transformers/modeling_reformer.py b/src/transformers/modeling_reformer.py index 1ccc04ffab..6beed9df7a 100644 --- a/src/transformers/modeling_reformer.py +++ b/src/transformers/modeling_reformer.py @@ -18,8 +18,10 @@ import logging import sys from collections import namedtuple +from dataclasses import dataclass from functools import reduce from operator import mul +from typing import List, Optional, Tuple import numpy as np import torch @@ -32,6 +34,7 @@ from .configuration_reformer import ReformerConfig from .file_utils import ( DUMMY_INPUTS, DUMMY_MASK, + ModelOutput, add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable, @@ -80,7 +83,18 @@ ReformerOutput = namedtuple("ReformerOutput", ["hidden_states", "attn_output", " ReformerBackwardOutput = namedtuple( "ReformerBackwardOutput", ["attn_output", "hidden_states", "grad_attn_output", "grad_hidden_states"] ) -ReformerEncoderOutput = namedtuple("ReformerEncoderOutput", ["hidden_states", "all_hidden_states", "all_attentions"]) +ReformerEncoderOutput = namedtuple( + "ReformerEncoderOutput", ["hidden_states", "all_hidden_states", "all_attentions", 
"past_buckets_states"], +) + + +def _stable_argsort(vector, dim): + # this function scales the vector so that torch.argsort is stable. + # torch.argsort is not stable on its own + scale_offset = torch.arange(vector.shape[dim], device=vector.device).view(1, 1, -1) + scale_offset = scale_offset.expand(vector.shape) + scaled_vector = vector.shape[dim] * vector + (scale_offset % vector.shape[dim]) + return torch.argsort(scaled_vector, dim=dim) def _get_least_common_mult_chunk_len(config): @@ -100,6 +114,23 @@ def _get_least_common_mult_chunk_len(config): ) +def _get_min_chunk_len(config): + attn_types = config.attn_layers + attn_types_set = set(attn_types) + if len(attn_types_set) == 1 and attn_types[0] == "lsh": + return config.lsh_attn_chunk_length + elif len(attn_types_set) == 1 and attn_types[0] == "local": + return config.local_attn_chunk_length + elif len(attn_types_set) == 2 and attn_types_set == set(["lsh", "local"]): + return min(config.lsh_attn_chunk_length, config.local_attn_chunk_length) + else: + raise NotImplementedError( + "Only attn layer types 'lsh' and 'local' exist, but `config.attn_layers`: {}. Select attn layer types from ['lsh', 'local'] only.".format( + config.attn_layers + ) + ) + + class AxialPositionEmbeddings(nn.Module): """Constructs axial position embeddings. Useful for very long input sequences to save memory and time. 
@@ -171,15 +202,23 @@ class AxialPositionEmbeddings(nn.Module): ) # compute how many columns are needed - required_pos_encodings_columns = -(-sequence_length // self.axial_pos_shape[1]) + max_position_id = position_ids.max().item() + required_pos_encodings_columns = -(-(max_position_id + 1) // self.axial_pos_shape[1]) # cut to columns that are needed position_encodings = torch.cat( [weight[:, :required_pos_encodings_columns] for weight in broadcasted_weights], dim=-1 ) - position_encodings = torch.reshape(position_encodings, (batch_size, -1, position_encodings.shape[-1]))[ - :, :sequence_length - ] + position_encodings = torch.reshape(position_encodings, (batch_size, -1, position_encodings.shape[-1])) + + # select correct position encodings + position_encodings = torch.cat( + [ + torch.index_select(position_encodings[i], 0, position_ids[i]).unsqueeze(0) + for i in range(batch_size) + ], + dim=0, + ) return position_encodings @@ -213,7 +252,7 @@ class ReformerEmbeddings(nn.Module): AxialPositionEmbeddings(config) if config.axial_pos_embds else PositionEmbeddings(config) ) - def forward(self, input_ids=None, position_ids=None, inputs_embeds=None): + def forward(self, input_ids=None, position_ids=None, inputs_embeds=None, start_idx_pos_encodings=0): if input_ids is not None: input_shape = input_ids.size() device = input_ids.device @@ -223,7 +262,9 @@ class ReformerEmbeddings(nn.Module): seq_length = input_shape[1] if position_ids is None: - position_ids = torch.arange(seq_length, dtype=torch.long, device=device) + position_ids = torch.arange( + start_idx_pos_encodings, start_idx_pos_encodings + seq_length, dtype=torch.long, device=device + ) position_ids = position_ids.unsqueeze(0).expand(input_shape) if inputs_embeds is None: @@ -339,8 +380,10 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): attention_mask=None, head_mask=None, num_hashes=None, - output_attentions=False, buckets=None, + past_buckets_states=None, + use_cache=False, + 
output_attentions=False, **kwargs ): sequence_length = hidden_states.shape[1] @@ -349,18 +392,73 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): # num hashes can optionally be overwritten by user num_hashes = num_hashes if num_hashes is not None else self.num_hashes - # project hidden_states to query_key and value - query_key_vectors = self.query_key(hidden_states) - value_vectors = self.value(hidden_states) + do_cached_attention = use_cache and past_buckets_states[1] is not None + + # check if cache shall be used and that hidden states are already cached + if do_cached_attention: + assert ( + sequence_length == 1 + ), f"At the moment, auto-regressive language generation is only possible one word at a time. Make sure that input sequence length {sequence_length} equals 1, when `past_buckets_states` is passed." + past_buckets = past_buckets_states[0] + past_states = past_buckets_states[1] + + # get query vector + query_vectors = self.query_key(hidden_states) + query_vectors = self._split_hidden_size_dim( + query_vectors, self.num_attention_heads, self.attention_head_size + ) + + if past_buckets is not None: + key_value_hidden_states, sorted_bucket_idx, buckets = self._get_relevant_hid_states_and_buckets( + query_vectors=query_vectors, + attention_mask=attention_mask, + num_hashes=num_hashes, + hidden_states=hidden_states, + past_states=past_states, + past_buckets=past_buckets, + ) + + query_key_vectors = self._query_per_attn_head(key_value_hidden_states) + value_vectors = self._value_per_attn_head(key_value_hidden_states) + + # split key & value vectors by num hashes to apply + # self attention on each separately + query_key_vectors = self._split_seq_length_dim_to( + query_key_vectors, num_hashes, -1, self.num_attention_heads, self.attention_head_size, + ) + value_vectors = self._split_seq_length_dim_to( + value_vectors, num_hashes, -1, self.num_attention_heads, self.attention_head_size, + ) + # repeat query vectors across hash dimension + 
query_vectors = query_vectors.unsqueeze(2).repeat(1, 1, num_hashes, 1, 1) + else: + key_value_hidden_states = torch.cat([past_states, hidden_states], dim=1) + + query_key_vectors = self.query_key(key_value_hidden_states) + value_vectors = self.value(key_value_hidden_states) + + else: + # project hidden_states to query_key and value + query_vectors = None + query_key_vectors = self.query_key(hidden_states) + value_vectors = self.value(hidden_states) + + # if query key is not already split + if not do_cached_attention or past_buckets is None: + query_key_vectors = self._split_hidden_size_dim( + query_key_vectors, self.num_attention_heads, self.attention_head_size + ) + value_vectors = self._split_hidden_size_dim( + value_vectors, self.num_attention_heads, self.attention_head_size + ) + + # cache buckets for next incremental decoding + if do_cached_attention and past_buckets is None and key_value_hidden_states.shape[1] >= self.chunk_length: + buckets = self._hash_vectors(query_key_vectors, num_hashes, attention_mask) # free memory del hidden_states - query_key_vectors = self._split_hidden_size_dim( - query_key_vectors, self.num_attention_heads, self.attention_head_size - ) - value_vectors = self._split_hidden_size_dim(value_vectors, self.num_attention_heads, self.attention_head_size) - assert ( query_key_vectors.shape[-1] == self.attention_head_size ), "last dim of query_key_vectors is {} but should be {}.".format( @@ -372,8 +470,11 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): value_vectors.shape[-1], self.attention_head_size ) + do_standard_self_attention = (sequence_length <= self.chunk_length) or ( + use_cache and past_buckets_states[1] is not None + ) # LSH attention only makes sense if chunked attention should be performed - if self.chunk_length < sequence_length: + if not do_standard_self_attention: # set `num_buckets` on the fly, recommended way to do it if self.num_buckets is None: self._set_num_buckets(sequence_length) @@ -382,6 +483,9 @@ 
class LSHSelfAttention(nn.Module, EfficientAttentionMixin): if buckets is None: # hash query key vectors into buckets buckets = self._hash_vectors(query_key_vectors, num_hashes, attention_mask) + else: + # make sure buckets has correct shape for LSH attention + buckets = buckets.view(batch_size, self.num_attention_heads, num_hashes * sequence_length) assert ( int(buckets.shape[-1]) == num_hashes * sequence_length @@ -397,7 +501,6 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): # cluster query key value vectors according to hashed buckets query_key_vectors = self._gather_by_expansion(query_key_vectors, sorted_bucket_idx_per_hash, num_hashes) value_vectors = self._gather_by_expansion(value_vectors, sorted_bucket_idx_per_hash, num_hashes) - query_key_vectors = self._split_seq_length_dim_to( query_key_vectors, -1, self.chunk_length, self.num_attention_heads, self.attention_head_size, ) @@ -409,6 +512,9 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): assert ( self.num_chunks_before == 0 and self.num_chunks_after == 0 ), "If `config.chunk_length` is `None`, make sure `config.num_chunks_after` and `config.num_chunks_before` are set to 0." 
+ elif do_cached_attention and past_buckets is not None: + # use max sequence length + sorted_bucket_idx_per_hash = sorted_bucket_idx else: # get sequence length indices sorted_bucket_idx_per_hash = torch.arange(sequence_length, device=query_key_vectors.device).repeat( @@ -418,25 +524,33 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): # scale key vectors key_vectors = self._len_and_dim_norm(query_key_vectors) + # set query_vectors to query key vectors if LSH self attention + query_vectors = query_vectors if query_vectors is not None else query_key_vectors + + # free memory + del query_key_vectors + # get attention probs out_vectors, logits, attention_probs = self._attend( - query_vectors=query_key_vectors, + query_vectors=query_vectors, key_vectors=key_vectors, value_vectors=value_vectors, sorted_bucket_idx_per_hash=sorted_bucket_idx_per_hash, attention_mask=attention_mask, head_mask=head_mask, - sequence_length=sequence_length, + do_standard_self_attention=do_standard_self_attention, + do_cached_attention=do_cached_attention, ) # free memory - del query_key_vectors, key_vectors, value_vectors + del key_vectors, value_vectors # re-order out_vectors and logits - if self.chunk_length < sequence_length: + if not do_standard_self_attention: # sort clusters back to correct ordering out_vectors, logits = ReverseSort.apply(out_vectors, logits, sorted_bucket_idx, undo_sorted_bucket_idx) + if not do_standard_self_attention or (do_cached_attention and past_buckets is not None): # sum up all hash rounds if num_hashes > 1: out_vectors = self._split_seq_length_dim_to( @@ -466,9 +580,28 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): if output_attentions is False: attention_probs = () + if buckets is not None: + buckets = buckets.view(batch_size, self.num_attention_heads, num_hashes, -1) + return LSHSelfAttentionOutput(hidden_states=out_vectors, attention_probs=attention_probs, buckets=buckets) - def _hash_vectors(self, vectors, num_hashes, 
attention_mask): + def _query_per_attn_head(self, hidden_states): + per_head_query_key = self.query_key.weight.reshape( + self.num_attention_heads, self.attention_head_size, self.hidden_size + ).transpose(-2, -1) + # only relevant for inference and no bias => we can use einsum here + query_key_vectors = torch.einsum("balh,ahr->balr", hidden_states, per_head_query_key) + return query_key_vectors + + def _value_per_attn_head(self, hidden_states): + per_head_value = self.value.weight.reshape( + self.num_attention_heads, self.attention_head_size, self.hidden_size + ).transpose(-2, -1) + # only relevant for inference and no bias => we can use einsum here + value_vectors = torch.einsum("balh,ahr->balr", hidden_states, per_head_value) + return value_vectors + + def _hash_vectors(self, vectors, num_hashes, attention_mask, increase_num_buckets=False): batch_size = vectors.shape[0] # See https://arxiv.org/pdf/1509.02897.pdf @@ -514,7 +647,6 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): rotated_vectors_factor = rotated_vectors[..., cur_sum : cur_sum + (bucket_factor // 2)] cur_sum = cur_sum + bucket_factor // 2 rotated_vectors_factor = torch.cat([rotated_vectors_factor, -rotated_vectors_factor], dim=-1) - if buckets is None: buckets = torch.argmax(rotated_vectors_factor, dim=-1) else: @@ -522,7 +654,7 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): cur_product = cur_product * bucket_factor - if attention_mask is not None: + if attention_mask is not None and (attention_mask.sum().item() < batch_size * attention_mask.shape[-1]): # add an extra bucket for padding tokens only num_buckets = num_buckets + 1 # assign padding tokens extra bucket @@ -530,6 +662,8 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): buckets = torch.where( buckets_mask, buckets, torch.tensor(num_buckets - 1, dtype=torch.long, device=buckets.device) ) + elif increase_num_buckets: + num_buckets = num_buckets + 1 # buckets is now (Batch_size x Num_Attn_Heads x 
Num_Hashes x Seq_Len). # Next we add offsets so that bucket numbers from different hashing rounds don't overlap. @@ -545,20 +679,8 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): def _get_sorted_bucket_idx_and_undo_sorted_bucket_idx(self, sequence_length, buckets, num_hashes): # no gradients are needed with torch.no_grad(): - batch_size = buckets.shape[0] - - # arange and expand - orig_indices = torch.arange(num_hashes * sequence_length, device=buckets.device).view(1, 1, -1) - orig_indices = orig_indices.expand(batch_size, self.num_attention_heads, orig_indices.shape[-1]) - - # scale buckets - scaled_buckets = sequence_length * buckets + (orig_indices % sequence_length) - - # remove gradient - scaled_buckets = scaled_buckets.detach() - - # Hash-based sort - sorted_bucket_idx = torch.argsort(scaled_buckets, dim=-1) + # hash-based sort + sorted_bucket_idx = _stable_argsort(buckets, dim=-1) # create simple indices to scatter to, to have undo sort indices = ( @@ -600,26 +722,37 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): sorted_bucket_idx_per_hash, attention_mask, head_mask, - sequence_length, + do_standard_self_attention, + do_cached_attention, ): - # look at previous and following chunks if chunked attention - if self.chunk_length < sequence_length: + if not do_standard_self_attention: key_vectors = self._look_adjacent(key_vectors, self.num_chunks_before, self.num_chunks_after) value_vectors = self._look_adjacent(value_vectors, self.num_chunks_before, self.num_chunks_after) # get logits and dots + # (BS, NumAttn, NumHash x NumChunk, Chunk_L x Hidden),(BS, NumAttn, NumHash x NumChunk, Chunk_L * (Num_bef + Num_aft + 1) x Hidden) -> (BS, NumAttn, NumHash x NumChunk, Chunk_L, Chunk_L * (1 + Num_bef + Num_aft)) query_key_dots = torch.matmul(query_vectors, key_vectors.transpose(-1, -2)) # free memory del query_vectors, key_vectors # if chunked attention split bucket idxs to query and key - if self.chunk_length < sequence_length: + if not 
do_standard_self_attention: query_bucket_idx = self._split_seq_length_dim_to( sorted_bucket_idx_per_hash, -1, self.chunk_length, self.num_attention_heads ) key_value_bucket_idx = self._look_adjacent(query_bucket_idx, self.num_chunks_before, self.num_chunks_after) + elif do_cached_attention and query_key_dots.ndim > 4: + key_value_bucket_idx = sorted_bucket_idx_per_hash + query_bucket_idx = ( + key_value_bucket_idx.new_ones(key_value_bucket_idx.shape[:-1] + (1,)) * key_value_bucket_idx.max() + ) + elif do_cached_attention and query_key_dots.ndim <= 4: + query_bucket_idx = (query_key_dots.shape[-1] - 1) * torch.ones_like(query_key_dots)[:, :, :, -1] + key_value_bucket_idx = torch.arange( + query_key_dots.shape[-1], dtype=torch.long, device=query_key_dots.device + )[None, None, :].expand(query_bucket_idx.shape[:2] + (-1,)) else: query_bucket_idx = key_value_bucket_idx = sorted_bucket_idx_per_hash @@ -631,15 +764,20 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): self_mask_value = self.self_mask_value_float32 mask_value = self.mask_value_float32 - mask = self._compute_attn_mask( - query_bucket_idx, key_value_bucket_idx, attention_mask, query_key_dots.shape, sequence_length - ) + if not do_cached_attention: + mask = self._compute_attn_mask( + query_bucket_idx, + key_value_bucket_idx, + attention_mask, + query_key_dots.shape, + do_standard_self_attention, + ) - if mask is not None: - query_key_dots = torch.where(mask, query_key_dots, mask_value) + if mask is not None: + query_key_dots = torch.where(mask, query_key_dots, mask_value) - # free memory - del mask + # free memory + del mask # Self mask is ALWAYS applied. 
# From the reformer paper (https://arxiv.org/pdf/2001.04451.pdf): @@ -682,19 +820,20 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): del value_vectors # merge chunk length - if self.chunk_length < sequence_length: + if out_vectors.ndim > 4: logits = logits.flatten(start_dim=2, end_dim=3).squeeze(-1) out_vectors = out_vectors.flatten(start_dim=2, end_dim=3) return out_vectors, logits, attention_probs - def _compute_attn_mask(self, query_indices, key_indices, attention_mask, query_key_dot_shape, sequence_length): - + def _compute_attn_mask( + self, query_indices, key_indices, attention_mask, query_key_dot_shape, do_standard_self_attention + ): # attention mask for LSH if attention_mask is not None: # if chunked attention, the attention mask has to correspond to LSH order attention_mask = attention_mask.to(torch.uint8)[:, None, :] - if sequence_length > self.chunk_length: + if not do_standard_self_attention: # expand attn_mask to fit with key_value_bucket_idx shape attention_mask = attention_mask[:, None, :] attention_mask = attention_mask.expand(query_indices.shape[:-1] + (-1,)) @@ -715,6 +854,102 @@ class LSHSelfAttention(nn.Module, EfficientAttentionMixin): return attention_mask + def _get_relevant_hid_states_and_buckets( + self, query_vectors, attention_mask, num_hashes, hidden_states, past_states, past_buckets + ): + # concat hidden states + hidden_states = torch.cat([past_states, hidden_states], dim=1) + + # batch_size hidden + batch_size = hidden_states.shape[0] + sequence_length = hidden_states.shape[1] + + # check if cached buckets include pad bucket + max_bucket = self.num_buckets if isinstance(self.num_buckets, int) else reduce(mul, self.num_buckets) + + # if pad bucket was cached => need to increase num buckets for caching + increase_num_buckets = past_buckets.max() > num_hashes * max_bucket - 1 + + # retrieve query buckets + query_buckets = self._hash_vectors( + query_vectors, num_hashes, attention_mask, 
increase_num_buckets=increase_num_buckets + ) + + # concat buckets + concat_buckets = torch.cat([past_buckets, query_buckets.unsqueeze(-1)], dim=-1) + + # hash-based sort + bucket_idx = _stable_argsort(concat_buckets, dim=-1) + + # bucket_idx has shape: BatchSize x NumAttnHeads x NumHashes x SequenceLength + assert bucket_idx.shape == ( + batch_size, + self.num_attention_heads, + num_hashes, + sequence_length, + ), f"bucket_idx should have shape {(batch_size, self.num_attention_heads, num_hashes, sequence_length)}, but has shape {bucket_idx.shape}." + + # find indices of new bucket indices + relevant_bucket_idx = (bucket_idx == (bucket_idx.shape[-1] - 1)).nonzero() + + # expand relevant bucket indices to its chunks + relevant_bucket_idx_chunk = self._expand_to_indices_in_relevant_chunk(relevant_bucket_idx, sequence_length) + relevant_bucket_idx_chunk = bucket_idx[tuple(relevant_bucket_idx_chunk.transpose(0, 1))] + + # adapt bucket_idx for batch and hidden states for index select + bucket_idx_batch_offset = sequence_length * ( + batch_size + * torch.arange(relevant_bucket_idx_chunk.shape[-1], device=hidden_states.device, dtype=torch.long) + // relevant_bucket_idx_chunk.shape[-1] + ) + + # add batch offset + relevant_bucket_idx_chunk_all_batch = relevant_bucket_idx_chunk + bucket_idx_batch_offset + hidden_states = hidden_states.reshape((-1, self.hidden_size)) + + # select all relevant hidden states + relevant_hidden_states = hidden_states.index_select(0, relevant_bucket_idx_chunk_all_batch) + + # reshape hidden states and bucket_idx to correct output + relevant_hidden_states = relevant_hidden_states.reshape( + batch_size, self.num_attention_heads, -1, self.hidden_size + ) + relevant_bucket_idx_chunk = relevant_bucket_idx_chunk.reshape( + batch_size, self.num_attention_heads, num_hashes, -1 + ) + + assert ( + relevant_hidden_states.shape[2] + == (self.num_chunks_before + self.num_chunks_after + 1) * self.chunk_length * num_hashes + ), f"There should be 
{(self.num_chunks_before + self.num_chunks_after + 1) * self.chunk_length * num_hashes} `hidden_states`, there are {relevant_hidden_states.shape[2]} `hidden_states`." + + assert ( + relevant_bucket_idx_chunk.shape[-1] + == (self.num_chunks_before + self.num_chunks_after + 1) * self.chunk_length + ), f"There should be {(self.num_chunks_before + self.num_chunks_after + 1) * self.chunk_length} `hidden_states`, there are {relevant_bucket_idx_chunk.shape[-1]} `bucket_idx`." + + return relevant_hidden_states, relevant_bucket_idx_chunk, query_buckets + + def _expand_to_indices_in_relevant_chunk(self, indices, sequence_length): + # get relevant indices of where chunk starts and its size + start_indices_chunk = ((indices[:, -1] // self.chunk_length) - self.num_chunks_before) * self.chunk_length + total_chunk_size = self.chunk_length * (1 + self.num_chunks_before + self.num_chunks_after) + + # expand start indices and add correct chunk offset via arange + expanded_start_indices = start_indices_chunk.unsqueeze(-1).expand(indices.shape[0], total_chunk_size) + chunk_sequence_indices = expanded_start_indices + torch.arange( + total_chunk_size, device=indices.device, dtype=torch.long + ).unsqueeze(0).expand(indices.shape[0], total_chunk_size) + + # make sure that circular logic holds via % seq len + chunk_sequence_indices = chunk_sequence_indices.flatten() % sequence_length + + # expand indices and set indices correctly + indices = indices.unsqueeze(1).expand((indices.shape[0], total_chunk_size, -1)).flatten(0, 1).clone() + indices[:, -1] = chunk_sequence_indices + + return indices + def _len_and_dim_norm(self, vectors): """ length and attention head size dim normalization @@ -803,14 +1038,42 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin): self.register_buffer("mask_value_float16", torch.tensor(-1e4)) self.register_buffer("mask_value_float32", torch.tensor(-1e9)) - def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False, 
**kwargs): + def forward( + self, + hidden_states, + attention_mask=None, + head_mask=None, + past_buckets_states=None, + use_cache=False, + output_attentions=False, + **kwargs + ): sequence_length = hidden_states.shape[1] batch_size = hidden_states.shape[0] - # project hidden_states to query, key and value - query_vectors = self.query(hidden_states) - key_vectors = self.key(hidden_states) - value_vectors = self.value(hidden_states) + # check if cache shall be used and that hidden states are already cached + if use_cache and past_buckets_states[1] is not None: + assert ( + past_buckets_states[0] is None + ), "LocalSelfAttention should not make use of `buckets`. There seems to be an error when caching hidden_states_and_buckets." + key_value_hidden_states = self._retrieve_relevant_hidden_states( + past_buckets_states[1], self.chunk_length, self.num_chunks_before + ) + key_value_hidden_states = torch.cat([key_value_hidden_states, hidden_states], dim=1) + + # only query vector for last token + query_vectors = self.query(hidden_states) + # compute key and value for relevant chunk + key_vectors = self.key(key_value_hidden_states) + value_vectors = self.value(key_value_hidden_states) + + # free memory + del key_value_hidden_states + else: + # project hidden_states to query, key and value + query_vectors = self.query(hidden_states) + key_vectors = self.key(hidden_states) + value_vectors = self.value(hidden_states) # split last dim into `config.num_attention_heads` and `config.attention_head_size` query_vectors = self._split_hidden_size_dim(query_vectors, self.num_attention_heads, self.attention_head_size) @@ -848,8 +1111,11 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin): batch_size, self.num_attention_heads, 1 ) + # if one should do normal n^2 self-attention + do_standard_self_attention = sequence_length <= self.chunk_length + # if input should be chunked - if self.chunk_length < sequence_length: + if not do_standard_self_attention: # chunk vectors # B x 
Num_Attn_Head x Seq_Len // chunk_len x chunk_len x attn_head_size query_vectors = self._split_seq_length_dim_to( @@ -880,7 +1146,7 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin): del query_vectors, key_vectors mask = self._compute_attn_mask( - query_indices, key_indices, attention_mask, query_key_dots.shape, sequence_length + query_indices, key_indices, attention_mask, query_key_dots.shape, do_standard_self_attention ) if mask is not None: @@ -916,7 +1182,7 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin): del value_vectors # merge chunk length - if self.chunk_length < sequence_length: + if not do_standard_self_attention: out_vectors = out_vectors.flatten(start_dim=2, end_dim=3) assert out_vectors.shape == (batch_size, self.num_attention_heads, sequence_length, self.attention_head_size,) @@ -928,13 +1194,15 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin): return LocalSelfAttentionOutput(hidden_states=out_vectors, attention_probs=attention_probs) - def _compute_attn_mask(self, query_indices, key_indices, attention_mask, query_key_dots_shape, sequence_length): + def _compute_attn_mask( + self, query_indices, key_indices, attention_mask, query_key_dots_shape, do_standard_self_attention + ): # chunk attention mask and look before and after if attention_mask is not None: attention_mask = attention_mask.to(torch.uint8)[:, None, :] - if self.chunk_length < sequence_length: + if not do_standard_self_attention: attention_mask = self._split_seq_length_dim_to(attention_mask, -1, self.chunk_length, 1) attention_mask = self._look_adjacent(attention_mask, self.num_chunks_before, self.num_chunks_after) # create attn_mask @@ -952,6 +1220,11 @@ class LocalSelfAttention(nn.Module, EfficientAttentionMixin): return attention_mask + @staticmethod + def _retrieve_relevant_hidden_states(previous_hidden_states, chunk_length, num_chunks_before): + start_position = ((previous_hidden_states.shape[1] // chunk_length) - num_chunks_before) * 
chunk_length + return previous_hidden_states[:, start_position:] + class ReformerSelfOutput(nn.Module): def __init__(self, config): @@ -999,21 +1272,31 @@ class ReformerAttention(nn.Module): attention_mask=None, head_mask=None, num_hashes=None, + past_buckets_states=None, + use_cache=False, + orig_sequence_length=None, output_attentions=False, buckets=None, ): hidden_states = self.layer_norm(hidden_states) + # make sure cached hidden states is set to None for backward pass + if past_buckets_states is not None: + past_buckets_states_layer = past_buckets_states[self.layer_id] + else: + past_buckets_states_layer = None + # use cached buckets for backprob if buckets not None for LSHSelfAttention self_attention_outputs = self.self_attention( hidden_states=hidden_states, head_mask=head_mask, attention_mask=attention_mask, num_hashes=num_hashes, + past_buckets_states=past_buckets_states_layer, + use_cache=use_cache, output_attentions=output_attentions, buckets=buckets, ) - attention_output = self.output(self_attention_outputs.hidden_states) # add buckets if necessary if hasattr(self_attention_outputs, "buckets"): @@ -1021,6 +1304,28 @@ class ReformerAttention(nn.Module): else: buckets = None + # cache hidden states for future use + if use_cache: + if past_buckets_states[self.layer_id][0] is None: + # padded input should not be cached + past_buckets = ( + buckets[:, :, :, :orig_sequence_length] + if (buckets is not None and orig_sequence_length > 1) + else buckets + ) + else: + past_buckets = torch.cat([past_buckets_states[self.layer_id][0], buckets], dim=-1) + + if past_buckets_states[self.layer_id][1] is None: + # padded input should not be cached + past_states = hidden_states[:, :orig_sequence_length] + else: + past_states = torch.cat([past_buckets_states[self.layer_id][1], hidden_states], dim=1) + + past_buckets_states[self.layer_id] = (past_buckets, past_states) + # compute attention feed forward output + attention_output = 
self.output(self_attention_outputs.hidden_states) + return AttentionOutput( hidden_states=attention_output, attention_probs=self_attention_outputs.attention_probs, buckets=buckets, ) @@ -1137,6 +1442,9 @@ class ReformerLayer(nn.Module): attention_mask=None, head_mask=None, num_hashes=None, + past_buckets_states=None, + use_cache=False, + orig_sequence_length=None, output_attentions=False, ): with torch.no_grad(): @@ -1149,6 +1457,9 @@ class ReformerLayer(nn.Module): head_mask=head_mask, attention_mask=attention_mask, num_hashes=num_hashes, + past_buckets_states=past_buckets_states, + use_cache=use_cache, + orig_sequence_length=orig_sequence_length, output_attentions=output_attentions, ) attn_output = attn_outputs.hidden_states @@ -1254,6 +1565,9 @@ class _ReversibleFunction(Function): num_hashes, all_hidden_states, all_attentions, + past_buckets_states, + use_cache, + orig_sequence_length, output_hidden_states, output_attentions, ): @@ -1262,7 +1576,7 @@ class _ReversibleFunction(Function): # split duplicated tensor hidden_states, attn_output = torch.chunk(hidden_states, 2, dim=-1) - for layer, layer_head_mask in zip(layers, head_mask): + for layer_id, (layer, layer_head_mask) in enumerate(zip(layers, head_mask)): if output_hidden_states is True: all_hidden_states.append(hidden_states) @@ -1272,8 +1586,12 @@ class _ReversibleFunction(Function): attention_mask=attention_mask, head_mask=layer_head_mask, num_hashes=num_hashes, + past_buckets_states=past_buckets_states, + use_cache=use_cache, + orig_sequence_length=orig_sequence_length, output_attentions=output_attentions, ) + attn_output = layer_outputs.attn_output hidden_states = layer_outputs.hidden_states all_buckets = all_buckets + (layer_outputs.buckets,) @@ -1339,7 +1657,7 @@ class _ReversibleFunction(Function): # num of return vars has to match num of forward() args # return gradient for hidden_states arg and None for other args - return grad_hidden_states, None, None, None, None, None, None, None, None + 
return grad_hidden_states, None, None, None, None, None, None, None, None, None, None, None class ReformerEncoder(nn.Module): @@ -1358,6 +1676,9 @@ class ReformerEncoder(nn.Module): attention_mask=None, head_mask=None, num_hashes=None, + past_buckets_states=None, + use_cache=False, + orig_sequence_length=None, output_hidden_states=False, output_attentions=False, ): @@ -1365,6 +1686,10 @@ class ReformerEncoder(nn.Module): all_hidden_states = [] all_attentions = [] + # init cached hidden states if necessary + if past_buckets_states is None: + past_buckets_states = [((None), (None)) for i in range(len(self.layers))] + # concat same tensor for reversible ResNet hidden_states = torch.cat([hidden_states, hidden_states], dim=-1) hidden_states = _ReversibleFunction.apply( @@ -1375,6 +1700,9 @@ class ReformerEncoder(nn.Module): num_hashes, all_hidden_states, all_attentions, + past_buckets_states, + use_cache, + orig_sequence_length, output_hidden_states, output_attentions, ) @@ -1386,7 +1714,10 @@ class ReformerEncoder(nn.Module): hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) return ReformerEncoderOutput( - hidden_states=hidden_states, all_hidden_states=all_hidden_states, all_attentions=all_attentions + hidden_states=hidden_states, + all_hidden_states=all_hidden_states, + all_attentions=all_attentions, + past_buckets_states=past_buckets_states, ) @@ -1448,6 +1779,85 @@ class ReformerPreTrainedModel(PreTrainedModel): module.bias.data.zero_() +@dataclass +class ReformerModelOutput(ModelOutput): + """ + Output type of :class:`~transformers.ReformerModel`. + + Args: + last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_predict, hidden_size)`): + Sequence of hidden-states at the last layer of the model. + + ``num_predict`` corresponds to ``target_mapping.shape[1]``. If ``target_mapping`` is ``None``, then + ``num_predict`` corresponds to ``sequence_length``. 
+ past_buckets_states (:obj:`List[Tuple(torch.LongTensor, torch.FloatTensor)]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): + List of :obj:`tuple(torch.LongTensor, torch.FloatTensor` of length :obj:`config.n_layers`, with :obj:`tuple(0)` being the previous `buckets` of shape + :obj:`(batch_size, num_heads, num_hashes, sequence_length)`) + and :obj:`tuple(1)` being the previous `hidden_states` of shape + :obj:`(batch_size, sequence_length, hidden_size)`). + + Contains pre-computed buckets and hidden-states that can be used (see + ``past_buckets_states`` input) to speed up sequential decoding. + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: torch.FloatTensor + past_buckets_states: Optional[List[Tuple[torch.LongTensor, torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class ReformerModelWithLMHeadOutput(ModelOutput): + """ + Output type of :class:`~transformers.ReformerModelWithLMHead`. 
+ + Args: + loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when ``labels`` is provided) + Language modeling loss (for next-token prediction). + logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_predict, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + + ``num_predict`` corresponds to ``target_mapping.shape[1]``. If ``target_mapping`` is ``None``, then + ``num_predict`` corresponds to ``sequence_length``. + past_buckets_states (:obj:`List[Tuple(torch.LongTensor, torch.FloatTensor)]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): + List of :obj:`tuple(torch.LongTensor, torch.FloatTensor` of length :obj:`config.n_layers`, with :obj:`tuple(0)` being the previous `buckets` of shape + :obj:`(batch_size, num_heads, num_hashes, sequence_length)`) + and :obj:`tuple(1)` being the previous `hidden_states` of shape + :obj:`(batch_size, sequence_length, hidden_size)`). + + Contains pre-computed buckets and hidden-states that can be used (see + ``past_buckets_states`` input) to speed up sequential decoding. + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. 
+ + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + loss: Optional[torch.FloatTensor] + logits: torch.FloatTensor + past_buckets_states: Optional[List[Tuple[torch.LongTensor, torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + REFORMER_START_DOCSTRING = r""" Reformer was proposed in `Reformer: The Efficient Transformer <https://arxiv.org/abs/2001.04451>`__ by Nikita Kitaev, Łukasz Kaiser, Anselm Levskaya. @@ -1499,6 +1909,15 @@ REFORMER_INPUTS_DOCSTRING = r""" bucketing. Setting `num_hashes` overwrites the default `num_hashes` defined in `config.num_hashes`. For more information, see `num_hashes` in :class:`transformers.ReformerConfig`. + past_buckets_states (:obj:`List[Tuple(torch.LongTensor, torch.FloatTensor)]`, `optional`, defaults to :obj:`None`): + List of :obj:`tuple(torch.LongTensor, torch.FloatTensor)` of length :obj:`config.n_layers`, with :obj:`tuple(0)` being the previous `buckets` of shape + :obj:`(batch_size, num_heads, num_hashes, sequence_length)` + and :obj:`tuple(1)` being the previous `hidden_states` of shape + :obj:`(batch_size, sequence_length, hidden_size)`. + + List of tuples that contains all previously computed hidden states and buckets (only relevant for LSH Self-Attention). Can be used to speed up sequential decoding. + use_cache (:obj:`bool`, `optional`, defaults to :obj:`None`): + If set to ``True``, the ``past_buckets_states`` of all attention layers are returned. output_attentions (:obj:`bool`, `optional`, defaults to :obj:`None`): If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`): @@ -1554,10 +1973,39 @@ class ReformerModel(ReformerPreTrainedModel): head_mask=None, inputs_embeds=None, num_hashes=None, + past_buckets_states=None, + use_cache=None, output_hidden_states=None, output_attentions=None, return_tuple=None, ): + r""" + Return: + :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs: + last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + past_buckets_states (:obj:`List[Tuple(torch.LongTensor, torch.FloatTensor)]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``): + List of :obj:`tuple(torch.LongTensor, torch.FloatTensor` of length :obj:`config.n_layers`, with :obj:`tuple(0)` being the previous `buckets` of shape + :obj:`(batch_size, num_heads, num_hashes, sequence_length)`) + and :obj:`tuple(1)` being the previous `hidden_states` of shape + :obj:`(batch_size, sequence_length, hidden_size)`). + + Contains pre-computed buckets and hidden-states that can be used (see + ``past_buckets_states`` input) to speed up sequential decoding. + all_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. 
+ all_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + use_cache = use_cache if use_cache is not None else self.config.use_cache output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -1579,6 +2027,9 @@ class ReformerModel(ReformerPreTrainedModel): len(input_shape) == 2 ), "`input_ids` have be of shape `[batch_size, sequence_length]`, but got shape: {}".format(input_shape) + if past_buckets_states is not None: + assert not self.training, "`past_buckets_states` can only be used for inference, not for training`." 
+ # prepare head mask head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers, is_attention_chunked=True) @@ -1587,8 +2038,12 @@ class ReformerModel(ReformerPreTrainedModel): # if needs padding least_common_mult_chunk_length = _get_least_common_mult_chunk_len(self.config) + min_chunk_length = _get_min_chunk_len(self.config) + must_pad_to_match_chunk_length = ( - input_shape[-1] % least_common_mult_chunk_length != 0 and input_shape[-1] > least_common_mult_chunk_length + input_shape[-1] % least_common_mult_chunk_length != 0 + and input_shape[-1] > min_chunk_length + and past_buckets_states is None ) if must_pad_to_match_chunk_length: @@ -1613,13 +2068,27 @@ class ReformerModel(ReformerPreTrainedModel): device=device, ) - embedding_output = self.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=inputs_embeds) + # start index for position encoding depends on incremental decoding + if past_buckets_states is not None: + start_idx_pos_encodings = past_buckets_states[0][1].shape[1] + else: + start_idx_pos_encodings = 0 + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + start_idx_pos_encodings=start_idx_pos_encodings, + ) encoder_outputs = self.encoder( hidden_states=embedding_output, head_mask=head_mask, attention_mask=attention_mask, num_hashes=num_hashes, + past_buckets_states=past_buckets_states, + use_cache=use_cache, + orig_sequence_length=orig_sequence_length, output_hidden_states=output_hidden_states, output_attentions=output_attentions, ) @@ -1629,12 +2098,18 @@ class ReformerModel(ReformerPreTrainedModel): if must_pad_to_match_chunk_length: sequence_output = sequence_output[:, :orig_sequence_length] + past_buckets_states = encoder_outputs.past_buckets_states if use_cache else None hidden_states = encoder_outputs.all_hidden_states if output_hidden_states else None attentions = encoder_outputs.all_attentions if output_attentions else None if return_tuple:
- return tuple(v for v in [sequence_output, hidden_states, attentions] if v is not None) - return BaseModelOutput(last_hidden_state=sequence_output, hidden_states=hidden_states, attentions=attentions) + return tuple(v for v in [sequence_output, past_buckets_states, hidden_states, attentions] if v is not None) + return ReformerModelOutput( + last_hidden_state=sequence_output, + past_buckets_states=past_buckets_states, + hidden_states=hidden_states, + attentions=attentions, + ) def _pad_to_mult_of_chunk_length( self, @@ -1659,13 +2134,9 @@ class ReformerModel(ReformerPreTrainedModel): # Extend `attention_mask` if attention_mask is not None: - attention_mask = torch.cat( - [ - attention_mask, - torch.zeros(input_shape[0], padding_length, device=device, dtype=attention_mask.dtype,), - ], - dim=-1, - ) + pad_attention_mask = torch.zeros(input_shape[0], padding_length, device=device, dtype=attention_mask.dtype) + + attention_mask = torch.cat([attention_mask, pad_attention_mask], dim=-1) else: attention_mask = torch.cat( [ @@ -1698,7 +2169,14 @@ class ReformerModel(ReformerPreTrainedModel): class ReformerModelWithLMHead(ReformerPreTrainedModel): def __init__(self, config): super().__init__(config) - assert config.is_decoder, "If you want to use `ReformerLMHeadModel` make sure that `is_decoder=True`." + assert config.is_decoder, "If you want to use `ReformerModelWithLMHead` make sure that `is_decoder=True`." + assert ( + "local" not in self.config.attn_layers or config.local_num_chunks_after == 0 + ), f"If causal mask is enabled, make sure that `config.local_num_chunks_after` is set to 0 and not {config.local_num_chunks_after}." + assert ( + "lsh" not in self.config.attn_layers or config.lsh_num_chunks_after == 0 + ), f"If causal mask is enabled, make sure that `config.lsh_num_chunks_after` is set to 1 and not {config.lsh_num_chunks_after}." 
+ self.reformer = ReformerModel(config) self.lm_head = ReformerOnlyLMHead(config) @@ -1726,10 +2204,12 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel): head_mask=None, inputs_embeds=None, num_hashes=None, - labels=None, + past_buckets_states=None, + use_cache=None, output_hidden_states=None, output_attentions=None, return_tuple=None, + labels=None, ): r""" labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): @@ -1747,6 +2227,8 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel): head_mask=head_mask, inputs_embeds=inputs_embeds, num_hashes=num_hashes, + past_buckets_states=past_buckets_states, + use_cache=use_cache, output_hidden_states=output_hidden_states, output_attentions=output_attentions, return_tuple=return_tuple, @@ -1768,22 +2250,44 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel): output = (logits,) + reformer_outputs[1:] return ((loss,) + output) if loss is not None else output - return CausalLMOutput( + return ReformerModelWithLMHeadOutput( loss=loss, logits=logits, + past_buckets_states=reformer_outputs.past_buckets_states, hidden_states=reformer_outputs.hidden_states, attentions=reformer_outputs.attentions, ) def prepare_inputs_for_generation(self, input_ids, past, **kwargs): - # TODO(PVP): Add smart caching - inputs_dict = {"input_ids": input_ids} + # only last token for inputs_ids if past is defined in kwargs + if past is not None: + input_ids = input_ids[:, -1:] + + inputs_dict = { + "input_ids": input_ids, + "past_buckets_states": past, + "use_cache": kwargs["use_cache"], + } if "num_hashes" in kwargs: inputs_dict["num_hashes"] = kwargs["num_hashes"] return inputs_dict + def _reorder_cache(self, past, beam_idx): + reord_past_buckets_states = [] + for layer_past in past: + # buckets + if layer_past[0] is not None: + reord_buckets = layer_past[0].index_select(0, beam_idx) + else: + reord_buckets = None + + # hidden states + reord_hidden_states = layer_past[1].index_select(0, 
beam_idx) + reord_past_buckets_states.append((reord_buckets, reord_hidden_states)) + return reord_past_buckets_states + @add_start_docstrings("""Reformer Model with a `language modeling` head on top. """, REFORMER_START_DOCSTRING) class ReformerForMaskedLM(ReformerPreTrainedModel): @@ -1839,6 +2343,7 @@ class ReformerForMaskedLM(ReformerPreTrainedModel): head_mask=head_mask, inputs_embeds=inputs_embeds, num_hashes=num_hashes, + use_cache=False, # no causal mask output_hidden_states=output_hidden_states, output_attentions=output_attentions, return_tuple=return_tuple, @@ -2027,6 +2532,7 @@ class ReformerForQuestionAnswering(ReformerPreTrainedModel): head_mask=head_mask, inputs_embeds=inputs_embeds, num_hashes=num_hashes, + use_cache=False, # no causal mask output_hidden_states=output_hidden_states, output_attentions=output_attentions, return_tuple=return_tuple, diff --git a/src/transformers/modeling_xlnet.py b/src/transformers/modeling_xlnet.py index 81527c27ba..f599522f04 100644 --- a/src/transformers/modeling_xlnet.py +++ b/src/transformers/modeling_xlnet.py @@ -600,7 +600,7 @@ class XLNetModelOutput(ModelOutput): @dataclass class XLNetLMHeadModelOutput(ModelOutput): """ - Output type of :class:`~transformers.XLNetModel`. + Output type of :class:`~transformers.XLNetLMHeadModel`. Args: loss (:obj:`torch.FloatTensor` of shape `(1,)`, `optional`, returned when ``labels`` is provided) @@ -637,7 +637,7 @@ class XLNetLMHeadModelOutput(ModelOutput): @dataclass class XLNetForSequenceClassificationOutput(ModelOutput): """ - Base class for outputs of sentence classification models. + Output type of :class:`~transformers.XLNetForSequenceClassification`. Args: loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`label` is provided): @@ -671,7 +671,7 @@ class XLNetForSequenceClassificationOutput(ModelOutput): @dataclass class XLNetForTokenClassificationOutput(ModelOutput): """ - Base class for outputs of token classification models. 
+ Output type of :class:`~transformers.XLNetForTokenClassification`. Args: loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided) : diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index fbe8425b82..b70ec98c8b 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -181,8 +181,8 @@ class ReformerModelTester: model = ReformerModel(config=config) model.to(torch_device) model.eval() - (sequence_output,) = model(input_ids, attention_mask=input_mask) - (sequence_output,) = model(input_ids) + sequence_output, _ = model(input_ids, attention_mask=input_mask) + sequence_output, _ = model(input_ids) result = { "sequence_output": sequence_output, @@ -193,17 +193,21 @@ class ReformerModelTester: ) def create_and_check_reformer_model_with_lm_backward(self, config, input_ids, input_mask, choice_labels): - model = ReformerModelWithLMHead(config=config) + config.is_decoder = False + config.lsh_num_chunks_after = 1 + model = ReformerForMaskedLM(config=config) model.to(torch_device) model.eval() loss = model(input_ids, attention_mask=input_mask, labels=input_ids)[0] loss.backward() def create_and_check_reformer_with_lm(self, config, input_ids, input_mask, choice_labels): + config.lsh_num_chunks_after = 0 + config.is_decoder = True model = ReformerModelWithLMHead(config=config) model.to(torch_device) model.eval() - loss, prediction_scores = model(input_ids, attention_mask=input_mask, labels=input_ids) + loss, prediction_scores, _ = model(input_ids, attention_mask=input_mask, labels=input_ids) result = { "loss": loss, "prediction_scores": prediction_scores, @@ -332,9 +336,11 @@ class ReformerModelTester: config.hidden_dropout_prob = 0 config.local_attention_probs_dropout_prob = 0 config.lsh_attention_probs_dropout_prob = 0 + config.lsh_num_chunks_after = 1 + config.is_decoder = False torch.manual_seed(0) - model = ReformerModelWithLMHead(config=config) + model = 
ReformerForMaskedLM(config=config) model.to(torch_device) model.train() model.zero_grad() @@ -348,7 +354,7 @@ class ReformerModelTester: config.chunk_size_feed_forward = 1 torch.manual_seed(0) - model = ReformerModelWithLMHead(config=config) + model = ReformerForMaskedLM(config=config) model.to(torch_device) model.train() model.zero_grad() @@ -405,7 +411,22 @@ class ReformerModelTester: output = model(input_ids, attention_mask=input_mask)[0] self.parent.assertFalse(torch.isnan(output).any().item()) + def create_and_check_reformer_model_generate(self, config, input_ids, input_mask, choice_labels): + config.is_decoder = True + config.lsh_num_chunks_after = 0 + config.bos_token_id = 0 + config.eos_token_id = None + config.max_length = 20 + + model = ReformerModelWithLMHead(config=config) + model.to(torch_device) + model.eval() + output = model.generate() + self.parent.assertIsNotNone(output) + def create_and_check_reformer_model_fp16_generate(self, config, input_ids, input_mask, choice_labels): + config.is_decoder = True + config.lsh_num_chunks_after = 0 model = ReformerModelWithLMHead(config=config) model.to(torch_device) model.half() @@ -418,13 +439,15 @@ class ReformerModelTester: # force chunk length to be bigger than input_ids config.lsh_attn_chunk_length = 2 * input_ids.shape[-1] config.local_attn_chunk_length = 2 * input_ids.shape[-1] - model = ReformerModelWithLMHead(config=config) + config.lsh_num_chunks_after = 1 + config.is_decoder = False + model = ReformerForMaskedLM(config=config) model.to(torch_device) model.eval() output_logits = model(input_ids, attention_mask=input_mask)[0] self.parent.assertTrue(output_logits.shape[1] == input_ids.shape[-1]) - def create_and_check_longformer_for_question_answering(self, config, input_ids, input_mask, choice_labels): + def create_and_check_reformer_for_question_answering(self, config, input_ids, input_mask, choice_labels): model = ReformerForQuestionAnswering(config=config) model.to(torch_device) model.eval() @@ 
-440,6 +463,33 @@ class ReformerModelTester: self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length]) self.check_loss_output(result) + def create_and_check_past_buckets_states(self, config, input_ids, input_mask, choice_labels): + config.is_decoder = True + config.lsh_num_chunks_before = 1 + config.lsh_num_chunks_after = 0 + model = ReformerModelWithLMHead(config=config) + model.to(torch_device) + model.eval() + input_ids_first = input_ids[:, :-1] + input_ids_second = input_ids[:, -1:] + + # return saved cache + _, past_buckets_states = model(input_ids_first, use_cache=True) + + # calculate last output with and without cache + outputs_with_cache, _ = model(input_ids_second, past_buckets_states=past_buckets_states, use_cache=True) + outputs_without_cache = model(input_ids)[0][:, -1] + + # select random slice idx + random_slice_idx = torch.randint(outputs_without_cache.shape[-1], (1, 1), device=torch_device).item() + + # outputs should be similar within range + self.parent.assertTrue( + torch.allclose( + outputs_with_cache[:, 0, random_slice_idx], outputs_without_cache[:, random_slice_idx], atol=1e-2 + ) + ) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() (config, input_ids, input_mask, choice_labels) = config_and_inputs @@ -509,6 +559,18 @@ class ReformerTesterMixin: config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_reformer_no_chunking(*config_and_inputs) + def test_reformer_qa_answering(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_reformer_for_question_answering(*config_and_inputs) + + def test_reformer_cached_inference(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_past_buckets_states(*config_and_inputs) + + def test_reformer_cached_generate(self): + config_and_inputs = 
self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_reformer_model_generate(*config_and_inputs) + @slow def test_dropout_random_seed_is_changing(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() @@ -621,8 +683,8 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.T "num_buckets": 2, "num_hashes": 4, "lsh_attn_chunk_length": 4, - "lsh_num_chunks_before": 2, - "lsh_num_chunks_after": 3, + "lsh_num_chunks_before": 1, + "lsh_num_chunks_after": 0, "chunk_size_lm_head": 5, "chunk_size_feed_forward": 6, "feed_forward_size": 32, @@ -636,7 +698,9 @@ class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.T "axial_pos_embds": True, "axial_pos_shape": [4, 8], "axial_pos_embds_dim": [16, 48], - "attn_layers": ["lsh", "lsh", "lsh", "lsh"], + # NOTE(review): reduced from four LSH layers to one; previous setting kept below - confirm this is intended + # "attn_layers": ["lsh", "lsh", "lsh", "lsh"], + "attn_layers": ["lsh"], "pad_token_id": 0, "eos_token_id": 2, "scope": None, @@ -1049,8 +1113,23 @@ class ReformerIntegrationTests(unittest.TestCase): output_ids = model.generate( input_ids, max_length=50, num_beams=4, early_stopping=True, do_sample=False, num_hashes=8 ) - output_text = tokenizer.decode(output_ids[0]) + output = tokenizer.decode(output_ids[0]) + self.assertEqual( - output_text, + output, "A few months later state expression in his ideas, at the first entrance. 
He was positively for an inst", ) + + @slow + def test_pretrained_generate_use_cache_equality(self): + model = ReformerModelWithLMHead.from_pretrained("google/reformer-crime-and-punishment").to(torch_device) + tokenizer = ReformerTokenizer.from_pretrained("google/reformer-crime-and-punishment") + model.eval() + input_ids = tokenizer.encode("A few months later", return_tensors="pt").to(torch_device) + output_ids_with_cache = model.generate(input_ids, max_length=130, num_hashes=8, use_cache=False) + output_ids_without_cache = model.generate(input_ids, max_length=130, num_hashes=8, use_cache=True) + + output_with_cache = tokenizer.decode(output_ids_with_cache[0]) + output_without_cache = tokenizer.decode(output_ids_without_cache[0]) + + self.assertEqual(output_with_cache, output_without_cache)