Reformer (#3351)
* first copy & past commit from Bert and morgans LSH code
* add easy way to compare to trax original code
* translate most of function
* make trax lsh self attention deterministic with numpy seed + copy paste code
* add same config
* add same config
* make layer init work
* implemented hash_vectors function for lsh attention
* continue reformer translation
* hf LSHSelfAttentionLayer gives same output as trax layer
* refactor code
* refactor code
* refactor code
* refactor
* refactor + add reformer config
* delete bogus file
* split reformer attention layer into two layers
* save intermediate step
* save intermediate step
* make test work
* add complete reformer block layer
* finish reformer layer
* implement causal and self mask
* clean reformer test and refactor code
* fix merge conflicts
* fix merge conflicts
* update init
* fix device for GPU
* fix chunk length init for tests
* include morgans optimization
* improve memory a bit
* improve comment
* factorize num_buckets
* better testing parameters
* make whole model work
* make lm model work
* add t5 copy paste tokenizer
* add chunking feed forward
* clean config
* add improved assert statements
* make tokenizer work
* improve test
* correct typo
* extend config
* add complexer test
* add new axial position embeddings
* add local block attention layer
* clean tests
* refactor
* better testing
* save intermediate progress
* clean test file
* make shorter input length work for model
* allow variable input length
* refactor
* make forward pass for pretrained model work
* add generation possibility
* finish dropout and init
* make style
* refactor
* add first version of RevNet Layers
* make forward pass work and add convert file
* make uploaded model forward pass work
* make uploaded model forward pass work
* refactor code
* add namedtuples and cache buckets
* correct head masks
* refactor
* made reformer more flexible
* make style
* remove set max length
* add attention masks
* fix up tests
* fix lsh attention mask
* make random seed optional for the moment
* improve memory in reformer
* add tests
* make style
* make sure masks work correctly
* detach gradients
* save intermediate
* correct backprob through gather
* make style
* change back num hashes
* rename to labels
* fix rotation shape
* fix detach
* update
* fix trainer
* fix backward dropout
* make reformer more flexible
* fix conflict
* fix
* fix
* add tests for fixed seed in reformer layer
* fix trainer typo
* fix typo in activations
* add fp16 tests
* add fp16 training
* support fp16
* correct gradient bug in reformer
* add fast gelu
* re-add dropout for embedding dropout
* better naming
* better naming
* renaming
* finalize test branch
* finalize tests
* add more tests
* finish tests
* fix
* fix type trainer
* fix fp16 tests
* fix tests
* fix tests
* fix tests
* fix issue with dropout
* fix dropout seeds
* correct random seed on gpu
* finalize random seed for dropout
* finalize random seed for dropout
* remove duplicate line
* correct half precision bug
* make style
* refactor
* refactor
* docstring
* remove sinusoidal position encodings for reformer
* move chunking to modeling_utils
* make style
* clean config
* make style
* fix tests
* fix auto tests
* pretrained models
* fix docstring
* update conversion file
* Update pretrained_models.rst
* fix rst
* fix rst
* update copyright
* fix test path
* fix test path
* fix small issue in test
* include reformer in generation tests
* add docs for axial position encoding
* finish docs
* Update convert_reformer_trax_checkpoint_to_pytorch.py
* remove isort
* include sams comments
* remove wrong comment in utils
* correct typos
* fix typo
* Update reformer.rst
* applied morgans optimization
* make style
* make gpu compatible
* remove bogus file
* big test refactor
* add example for chunking
* fix typo
* add to README
committed by GitHub
parent 877fc56410
commit dca34695d0
@@ -13,8 +13,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """PyTorch BERT model."""
 
 import inspect
 import logging
 import os
 from typing import Callable, Tuple
@@ -175,7 +175,7 @@ class ModuleUtilsMixin:
         extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
         return extended_attention_mask
 
-    def get_head_mask(self, head_mask, num_hidden_layers):
+    def get_head_mask(self, head_mask, num_hidden_layers, is_attention_chunked=False):
         """
         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
@@ -189,6 +189,8 @@ class ModuleUtilsMixin:
         """
         if head_mask is not None:
             head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
+            if is_attention_chunked is True:
+                head_mask = head_mask.unsqueeze(-1)
         else:
             head_mask = [None] * num_hidden_layers
 
@@ -786,6 +788,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         attention_mask=None,
         decoder_start_token_id=None,
         use_cache=None,
+        **model_specific_kwargs
     ):
         r""" Generates sequences for models with a LM head. The method currently supports greedy decoding, beam-search decoding, sampling with temperature, sampling with top-k or nucleus sampling.
 
@@ -863,6 +866,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
             use_cache: (`optional`) bool
                 If `use_cache` is True, past key values are used to speed up decoding if applicable to model. Defaults to `True`.
 
+            model_specific_kwargs: (`optional`) dict
+                Additional model specific kwargs will be forwarded to the `forward` function of the model.
+
         Return:
 
             output: `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`
@@ -1116,6 +1122,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 encoder_outputs=encoder_outputs,
                 attention_mask=attention_mask,
                 use_cache=use_cache,
+                model_specific_kwargs=model_specific_kwargs,
             )
         else:
             output = self._generate_no_beam_search(
@@ -1138,6 +1145,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
                 encoder_outputs=encoder_outputs,
                 attention_mask=attention_mask,
                 use_cache=use_cache,
+                model_specific_kwargs=model_specific_kwargs,
             )
 
         return output
@@ -1163,6 +1171,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         encoder_outputs,
         attention_mask,
         use_cache,
+        model_specific_kwargs,
     ):
         """ Generate sequences for each example without beam search (num_beams == 1).
             All returned sequence are generated independantly.
@@ -1175,7 +1184,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
 
         while cur_len < max_length:
             model_inputs = self.prepare_inputs_for_generation(
-                input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache
+                input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs
             )
 
             outputs = self(**model_inputs)
@@ -1288,6 +1297,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
         encoder_outputs,
         attention_mask,
         use_cache,
+        model_specific_kwargs,
     ):
         """ Generate sequences for each example with beam search.
         """
@@ -1314,7 +1324,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
 
         while cur_len < max_length:
             model_inputs = self.prepare_inputs_for_generation(
-                input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache
+                input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_specific_kwargs
             )
             outputs = self(**model_inputs)  # (batch_size * num_beams, cur_len, vocab_size)
             next_token_logits = outputs[0][:, -1, :]  # (batch_size * num_beams, vocab_size)
@@ -2087,3 +2097,66 @@ def prune_layer(layer, index, dim=None):
         return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
     else:
         raise ValueError("Can't prune layer of class {}".format(layer.__class__))
+
+
+def apply_chunking_to_forward(
+    chunk_size: int, chunk_dim: int, forward_fn: Callable[..., torch.Tensor], *input_tensors
+) -> torch.Tensor:
+    """
+    This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension `chunk_dim`.
+    It then applies a layer `forward_fn` to each chunk independently to save memory.
+    If the `forward_fn` is independent across the `chunk_dim` this function will yield the
+    same result as not applying it.
+
+    Args:
+        chunk_size: int - the chunk size of a chunked tensor. `num_chunks` = `len(input_tensors[0]) / chunk_size`
+        chunk_dim: int - the dimension over which the input_tensors should be chunked
+        forward_fn: fn - the forward fn of the model
+        input_tensors: tuple(torch.Tensor) - the input tensors of `forward_fn` which are chunked
+    Returns:
+        a Tensor with the same shape the foward_fn would have given if applied
+
+
+    Examples::
+
+        # rename the usual forward() fn to forward_chunk()
+        def forward_chunk(self, hidden_states):
+            hidden_states = self.decoder(hidden_states)
+            return hidden_states
+
+        # implement a chunked forward function
+        def forward(self, hidden_states):
+            return apply_chunking_to_forward(self.chunk_size_lm_head, self.seq_len_dim, self.forward_chunk, hidden_states)
+    """
+
+    assert len(input_tensors) > 0, "{} has to be a tuple/list of tensors".format(input_tensors)
+    tensor_shape = input_tensors[0].shape
+    assert all(
+        input_tensor.shape == tensor_shape for input_tensor in input_tensors
+    ), "All input tenors have to be of the same shape"
+
+    # inspect.signature exist since python 3.5 and is a python method -> no problem with backward compability
+    num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters)
+    assert num_args_in_forward_chunk_fn == len(
+        input_tensors
+    ), "forward_chunk_fn expects {} arguments, but only {} input tensors are given".format(
+        num_args_in_forward_chunk_fn, len(input_tensors)
+    )
+
+    if chunk_size > 0:
+        assert (
+            input_tensors[0].shape[chunk_dim] % chunk_size == 0
+        ), "The dimension to be chunked {} has to be a multiple of the chunk size {}".format(
+            input_tensors[0][chunk_dim], chunk_size
+        )
+
+        num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size
+
+        # chunk input tensor into tuples
+        input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors)
+        # apply forward fn to every tuple
+        output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks))
+        # concatenate output at same dimension
+        return torch.cat(output_chunks, dim=chunk_dim)
+
+    return forward_fn(*input_tensors)
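
The docstring of `apply_chunking_to_forward` states that chunking gives the same result as a single call only when `forward_fn` is independent across `chunk_dim`. The following is a minimal sketch of that property, not the code from this commit: plain Python lists stand in for tensors along one dimension so that it runs without PyTorch, and the names `chunked_apply`, `scale_and_add`, and `normalize` are hypothetical helpers introduced only for illustration.

```python
from typing import Callable, List, Sequence


def chunked_apply(
    chunk_size: int, forward_fn: Callable[..., List[float]], *inputs: Sequence[float]
) -> List[float]:
    """List-based stand-in (an assumption) for chunking along a single dimension."""
    if not inputs:
        raise ValueError("at least one input sequence is required")
    length = len(inputs[0])
    if any(len(seq) != length for seq in inputs):
        raise ValueError("all inputs must have the same length")
    if chunk_size <= 0:
        # As in the diff, a non-positive chunk size means no chunking at all.
        return list(forward_fn(*inputs))
    if length % chunk_size != 0:
        raise ValueError(
            "length {} is not a multiple of chunk size {}".format(length, chunk_size)
        )
    output: List[float] = []
    for start in range(0, length, chunk_size):
        # Slice every input at the same positions, then apply forward_fn to the slices.
        chunk_args = [seq[start:start + chunk_size] for seq in inputs]
        output.extend(forward_fn(*chunk_args))
    return output


def scale_and_add(xs, ys):
    # Position-wise: each output depends only on the inputs at the same index.
    return [2.0 * x + y for x, y in zip(xs, ys)]


def normalize(xs):
    # Not position-wise: every output depends on the sum over the whole input.
    total = sum(xs)
    return [x / total for x in xs]


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
ys = [0.5, 0.5, 0.5, 0.5, 0.5, 0.5]

chunked = chunked_apply(2, scale_and_add, xs, ys)
unchunked = chunked_apply(0, scale_and_add, xs, ys)
print(chunked == unchunked)  # True: the function is independent across positions

# Chunking changes the result when the function mixes positions.
print(chunked_apply(2, normalize, xs) != chunked_apply(0, normalize, xs))  # True
```

In this sketch the position-wise function is safe to chunk, while the normalizing function is not, which is why chunking suits layers such as a feed-forward block applied independently at each sequence position.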