Feed forward chunking others (#6365)
* Feed forward chunking for Distilbert & Albert
* Added ff chunking for many other models
* Change model signature
* Added chunking for XLM
* Cleaned up by removing some variables.
* remove test_chunking flag

Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
8
src/transformers/modeling_utils.py
Normal file → Executable file
8
src/transformers/modeling_utils.py
Normal file → Executable file
@@ -1519,7 +1519,7 @@ def prune_layer(
|
||||
|
||||
|
||||
def apply_chunking_to_forward(
|
||||
chunk_size: int, chunk_dim: int, forward_fn: Callable[..., torch.Tensor], *input_tensors
|
||||
forward_fn: Callable[..., torch.Tensor], chunk_size: int, chunk_dim: int, *input_tensors
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This function chunks the :obj:`input_tensors` into smaller input tensor parts of size :obj:`chunk_size` over the
|
||||
@@ -1529,12 +1529,12 @@ def apply_chunking_to_forward(
|
||||
directly applying :obj:`forward_fn` to :obj:`input_tensors`.
|
||||
|
||||
Args:
|
||||
forward_fn (:obj:`Callable[..., torch.Tensor]`):
|
||||
The forward function of the model.
|
||||
chunk_size (:obj:`int`):
|
||||
The chunk size of a chunked tensor: :obj:`num_chunks = len(input_tensors[0]) / chunk_size`.
|
||||
chunk_dim (:obj:`int`):
|
||||
The dimension over which the :obj:`input_tensors` should be chunked.
|
||||
forward_fn (:obj:`Callable[..., torch.Tensor]`):
|
||||
The forward function of the model.
|
||||
input_tensors (:obj:`Tuple[torch.Tensor]`):
|
||||
The input tensors of ``forward_fn`` which will be chunked.
|
||||
Returns:
|
||||
@@ -1550,7 +1550,7 @@ def apply_chunking_to_forward(
|
||||
|
||||
# implement a chunked forward function
|
||||
def forward(self, hidden_states):
|
||||
return apply_chunking_to_forward(self.chunk_size_lm_head, self.seq_len_dim, self.forward_chunk, hidden_states)
|
||||
return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states)
|
||||
"""
|
||||
|
||||
assert len(input_tensors) > 0, "{} has to be a tuple/list of tensors".format(input_tensors)
|
||||
|
||||
Reference in New Issue
Block a user