From 3223d49354e41dfa44649a9829c7b09013ad096e Mon Sep 17 00:00:00 2001 From: Patrick Deutschmann Date: Thu, 25 Aug 2022 08:34:42 +0200 Subject: [PATCH] Add ONNX support for Longformer (#17176) * Implement ONNX support for Longformer Fix repo consistency check complaints Fix value mismatches Add pooler output for default model Increase validation atol to accommodate multiple-choice error Fix copies Fix chunking for longer sequence lengths Add future comment * Fix issue in mask_invalid_locations * Remove torch imports in configuration_longformer * Change config access to fix LED * Push opset version to support tril * Work in review comments (mostly style) * Add Longformer to ONNX tests --- docs/source/en/serialization.mdx | 1 + src/transformers/models/led/modeling_led.py | 87 +++++++++++++----- .../longformer/configuration_longformer.py | 77 +++++++++++++++- .../models/longformer/modeling_longformer.py | 89 +++++++++++++------ src/transformers/onnx/features.py | 9 ++ tests/onnx/test_onnx_v2.py | 1 + 6 files changed, 212 insertions(+), 52 deletions(-) diff --git a/docs/source/en/serialization.mdx b/docs/source/en/serialization.mdx index 0aacdf76f7..89b73df4f5 100644 --- a/docs/source/en/serialization.mdx +++ b/docs/source/en/serialization.mdx @@ -74,6 +74,7 @@ Ready-made configurations include the following architectures: - LayoutLM - LayoutLMv3 - LeViT +- Longformer - LongT5 - M2M100 - Marian diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index 0837ac2bc4..ff79c0cad4 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -160,6 +160,8 @@ class LEDEncoderSelfAttention(nn.Module): self.one_sided_attn_window_size = attention_window // 2 + self.config = config + def forward( self, hidden_states, @@ -389,24 +391,48 @@ class LEDEncoderSelfAttention(nn.Module): return chunked_hidden_states @staticmethod - def _chunk(hidden_states, window_overlap): + def _chunk(hidden_states, window_overlap, onnx_export=False): """convert into overlapping chunks. Chunk size = 2w, overlap size = w""" + if not onnx_export: + # non-overlapping chunks of size = 2w + hidden_states = hidden_states.view( + hidden_states.size(0), + torch.div(hidden_states.size(1), (window_overlap * 2), rounding_mode="trunc"), + window_overlap * 2, + hidden_states.size(2), + ) + # use `as_strided` to make the chunks overlap with an overlap size = window_overlap + chunk_size = list(hidden_states.size()) + chunk_size[1] = chunk_size[1] * 2 - 1 - # non-overlapping chunks of size = 2w - hidden_states = hidden_states.view( - hidden_states.size(0), - hidden_states.size(1) // (window_overlap * 2), - window_overlap * 2, - hidden_states.size(2), - ) + chunk_stride = list(hidden_states.stride()) + chunk_stride[1] = chunk_stride[1] // 2 + return hidden_states.as_strided(size=chunk_size, stride=chunk_stride) - # use `as_strided` to make the chunks overlap with an overlap size = window_overlap - chunk_size = list(hidden_states.size()) - chunk_size[1] = chunk_size[1] * 2 - 1 + # When exporting to ONNX, use this separate logic + if hidden_states.size(1) == window_overlap * 2: + # simplest case + return hidden_states.unsqueeze(1) + else: + # have to use slow implementation since as_strided, unfold and 2d-tensor indexing aren't supported (yet) in ONNX export - chunk_stride = list(hidden_states.stride()) - chunk_stride[1] = chunk_stride[1] // 2 - return hidden_states.as_strided(size=chunk_size, stride=chunk_stride) + # TODO replace this with + # > return hidden_states.unfold(dimension=1, size=window_overlap * 2, step=window_overlap).transpose(2, 3) + # once `unfold` is supported + + chunk_size = [ + hidden_states.size(0), + hidden_states.size(1) // window_overlap - 1, + window_overlap * 2, + hidden_states.size(2), + ] + + overlapping_chunks = torch.empty(chunk_size) + for chunk in range(chunk_size[1]): + overlapping_chunks[:, chunk, :, :] = hidden_states[ + :, chunk * window_overlap : chunk * window_overlap + 2 * window_overlap, : + ] + return overlapping_chunks @staticmethod def _mask_invalid_locations(input_tensor, affected_seq_len) -> torch.Tensor: @@ -415,10 +441,14 @@ class LEDEncoderSelfAttention(nn.Module): ending_mask = beginning_mask.flip(dims=(1, 3)) beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1] beginning_mask = beginning_mask.expand(beginning_input.size()) - beginning_input.masked_fill_(beginning_mask == 1, -float("inf")) # `== 1` converts to bool or uint8 + input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1] = torch.full_like( + beginning_input, -float("inf") + ).where(beginning_mask.bool(), beginning_input) ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :] ending_mask = ending_mask.expand(ending_input.size()) - ending_input.masked_fill_(ending_mask == 1, -float("inf")) # `== 1` converts to bool or uint8 + input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :] = torch.full_like( + ending_input, -float("inf") + ).where(ending_mask.bool(), ending_input) def _sliding_chunks_query_key_matmul(self, query: torch.Tensor, key: torch.Tensor, window_overlap: int): """ @@ -432,14 +462,14 @@ class LEDEncoderSelfAttention(nn.Module): ), f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}" assert query.size() == key.size() - chunks_count = seq_len // window_overlap - 1 + chunks_count = torch.div(seq_len, window_overlap, rounding_mode="trunc") - 1 # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window_overlap * 2 query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim) key = key.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim) - query = self._chunk(query, window_overlap) - key = self._chunk(key, window_overlap) + query = self._chunk(query, window_overlap, self.config.__dict__.get("onnx_export", False)) + key = self._chunk(key, window_overlap, self.config.__dict__.get("onnx_export", False)) # matrix multiplication # bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim @@ -457,7 +487,7 @@ class LEDEncoderSelfAttention(nn.Module): # window_overlap previous words). The following column is attention score from each word to itself, then # followed by window_overlap columns for the upper triangle. - diagonal_attention_scores = diagonal_chunked_attention_scores.new_empty( + diagonal_attention_scores = diagonal_chunked_attention_scores.new_zeros( (batch_size * num_heads, chunks_count + 1, window_overlap, window_overlap * 2 + 1) ) @@ -498,11 +528,14 @@ class LEDEncoderSelfAttention(nn.Module): assert seq_len % (window_overlap * 2) == 0 assert attn_probs.size()[:3] == value.size()[:3] assert attn_probs.size(3) == 2 * window_overlap + 1 - chunks_count = seq_len // window_overlap - 1 + chunks_count = torch.div(seq_len, window_overlap, rounding_mode="trunc") - 1 # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size 2 window overlap chunked_attn_probs = attn_probs.transpose(1, 2).reshape( - batch_size * num_heads, seq_len // window_overlap, window_overlap, 2 * window_overlap + 1 + batch_size * num_heads, + torch.div(seq_len, window_overlap, rounding_mode="trunc"), + window_overlap, + 2 * window_overlap + 1, ) # group batch_size and num_heads dimensions into one @@ -577,9 +610,12 @@ class LEDEncoderSelfAttention(nn.Module): # (batch_size, seq_len, num_heads, max_num_global_attn_indices) attn_probs_from_global_key = torch.einsum("blhd,bshd->blhs", (query_vectors, key_vectors_only_global)) + # need to transpose since ONNX export only supports consecutive indexing: https://pytorch.org/docs/stable/onnx.html#writes-sets + attn_probs_from_global_key = attn_probs_from_global_key.transpose(1, 3) attn_probs_from_global_key[ - is_local_index_no_global_attn_nonzero[0], :, :, is_local_index_no_global_attn_nonzero[1] + is_local_index_no_global_attn_nonzero[0], is_local_index_no_global_attn_nonzero[1], :, : ] = torch.finfo(attn_probs_from_global_key.dtype).min + attn_probs_from_global_key = attn_probs_from_global_key.transpose(1, 3) return attn_probs_from_global_key @@ -673,9 +709,12 @@ class LEDEncoderSelfAttention(nn.Module): global_attn_scores = global_attn_scores.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len) + # need to transpose since ONNX export only supports consecutive indexing: https://pytorch.org/docs/stable/onnx.html#writes-sets + global_attn_scores = global_attn_scores.transpose(1, 2) global_attn_scores[ - is_local_index_no_global_attn_nonzero[0], :, is_local_index_no_global_attn_nonzero[1], : + is_local_index_no_global_attn_nonzero[0], is_local_index_no_global_attn_nonzero[1], :, : ] = torch.finfo(global_attn_scores.dtype).min + global_attn_scores = global_attn_scores.transpose(1, 2) global_attn_scores = global_attn_scores.masked_fill( is_index_masked[:, None, None, :], diff --git a/src/transformers/models/longformer/configuration_longformer.py b/src/transformers/models/longformer/configuration_longformer.py index 53ceeafb64..977ca3e639 100644 --- a/src/transformers/models/longformer/configuration_longformer.py +++ b/src/transformers/models/longformer/configuration_longformer.py @@ -13,12 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. """ Longformer configuration""" -from typing import List, Union +from collections import OrderedDict +from typing import TYPE_CHECKING, Any, List, Mapping, Optional, Union -from ...utils import logging +from ...onnx import OnnxConfig +from ...utils import TensorType, logging from ..roberta.configuration_roberta import RobertaConfig +if TYPE_CHECKING: + from ...configuration_utils import PretrainedConfig + from ...onnx.config import PatchingSpec + from ...tokenization_utils_base import PreTrainedTokenizerBase + + logger = logging.get_logger(__name__) LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = { @@ -71,6 +79,69 @@ class LongformerConfig(RobertaConfig): ```""" model_type = "longformer" - def __init__(self, attention_window: Union[List[int], int] = 512, sep_token_id: int = 2, **kwargs): + def __init__( + self, attention_window: Union[List[int], int] = 512, sep_token_id: int = 2, onnx_export: bool = False, **kwargs + ): super().__init__(sep_token_id=sep_token_id, **kwargs) self.attention_window = attention_window + self.onnx_export = onnx_export + + +class LongformerOnnxConfig(OnnxConfig): + def __init__(self, config: "PretrainedConfig", task: str = "default", patching_specs: "List[PatchingSpec]" = None): + super().__init__(config, task, patching_specs) + config.onnx_export = True + + @property + def inputs(self) -> Mapping[str, Mapping[int, str]]: + if self.task == "multiple-choice": + dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"} + else: + dynamic_axis = {0: "batch", 1: "sequence"} + return OrderedDict( + [ + ("input_ids", dynamic_axis), + ("attention_mask", dynamic_axis), + ("global_attention_mask", dynamic_axis), + ] + ) + + @property + def outputs(self) -> Mapping[str, Mapping[int, str]]: + outputs = super().outputs + if self.task == "default": + outputs["pooler_output"] = {0: "batch"} + return outputs + + @property + def atol_for_validation(self) -> float: + """ + What absolute tolerance value to use during model conversion validation. + + Returns: + Float absolute tolerance value. + """ + return 1e-4 + + @property + def default_onnx_opset(self) -> int: + # needs to be >= 14 to support tril operator + return max(super().default_onnx_opset, 14) + + def generate_dummy_inputs( + self, + tokenizer: "PreTrainedTokenizerBase", + batch_size: int = -1, + seq_length: int = -1, + is_pair: bool = False, + framework: Optional[TensorType] = None, + ) -> Mapping[str, Any]: + inputs = super().generate_dummy_inputs( + preprocessor=tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework + ) + import torch + + inputs["global_attention_mask"] = torch.zeros_like(inputs["input_ids"]) + # make every second token global + inputs["global_attention_mask"][:, ::2] = 1 + return inputs diff --git a/src/transformers/models/longformer/modeling_longformer.py b/src/transformers/models/longformer/modeling_longformer.py index 7661f90bfb..00cd227a68 100755 --- a/src/transformers/models/longformer/modeling_longformer.py +++ b/src/transformers/models/longformer/modeling_longformer.py @@ -532,6 +532,8 @@ class LongformerSelfAttention(nn.Module): self.one_sided_attn_window_size = attention_window // 2 + self.config = config + def forward( self, hidden_states, @@ -761,24 +763,48 @@ class LongformerSelfAttention(nn.Module): return chunked_hidden_states @staticmethod - def _chunk(hidden_states, window_overlap): + def _chunk(hidden_states, window_overlap, onnx_export=False): """convert into overlapping chunks. Chunk size = 2w, overlap size = w""" + if not onnx_export: + # non-overlapping chunks of size = 2w + hidden_states = hidden_states.view( + hidden_states.size(0), + torch.div(hidden_states.size(1), (window_overlap * 2), rounding_mode="trunc"), + window_overlap * 2, + hidden_states.size(2), + ) + # use `as_strided` to make the chunks overlap with an overlap size = window_overlap + chunk_size = list(hidden_states.size()) + chunk_size[1] = chunk_size[1] * 2 - 1 - # non-overlapping chunks of size = 2w - hidden_states = hidden_states.view( - hidden_states.size(0), - hidden_states.size(1) // (window_overlap * 2), - window_overlap * 2, - hidden_states.size(2), - ) + chunk_stride = list(hidden_states.stride()) + chunk_stride[1] = chunk_stride[1] // 2 + return hidden_states.as_strided(size=chunk_size, stride=chunk_stride) - # use `as_strided` to make the chunks overlap with an overlap size = window_overlap - chunk_size = list(hidden_states.size()) - chunk_size[1] = chunk_size[1] * 2 - 1 + # When exporting to ONNX, use this separate logic + if hidden_states.size(1) == window_overlap * 2: + # simplest case + return hidden_states.unsqueeze(1) + else: + # have to use slow implementation since as_strided, unfold and 2d-tensor indexing aren't supported (yet) in ONNX export - chunk_stride = list(hidden_states.stride()) - chunk_stride[1] = chunk_stride[1] // 2 - return hidden_states.as_strided(size=chunk_size, stride=chunk_stride) + # TODO replace this with + # > return hidden_states.unfold(dimension=1, size=window_overlap * 2, step=window_overlap).transpose(2, 3) + # once `unfold` is supported + + chunk_size = [ + hidden_states.size(0), + hidden_states.size(1) // window_overlap - 1, + window_overlap * 2, + hidden_states.size(2), + ] + + overlapping_chunks = torch.empty(chunk_size) + for chunk in range(chunk_size[1]): + overlapping_chunks[:, chunk, :, :] = hidden_states[ + :, chunk * window_overlap : chunk * window_overlap + 2 * window_overlap, : + ] + return overlapping_chunks @staticmethod def _mask_invalid_locations(input_tensor, affected_seq_len) -> torch.Tensor: @@ -787,10 +813,14 @@ class LongformerSelfAttention(nn.Module): ending_mask = beginning_mask.flip(dims=(1, 3)) beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1] beginning_mask = beginning_mask.expand(beginning_input.size()) - beginning_input.masked_fill_(beginning_mask == 1, -float("inf")) # `== 1` converts to bool or uint8 + input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1] = torch.full_like( + beginning_input, -float("inf") + ).where(beginning_mask.bool(), beginning_input) ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :] ending_mask = ending_mask.expand(ending_input.size()) - ending_input.masked_fill_(ending_mask == 1, -float("inf")) # `== 1` converts to bool or uint8 + input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :] = torch.full_like( + ending_input, -float("inf") + ).where(ending_mask.bool(), ending_input) def _sliding_chunks_query_key_matmul(self, query: torch.Tensor, key: torch.Tensor, window_overlap: int): """ @@ -804,14 +834,14 @@ class LongformerSelfAttention(nn.Module): ), f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}" assert query.size() == key.size() - chunks_count = seq_len // window_overlap - 1 + chunks_count = torch.div(seq_len, window_overlap, rounding_mode="trunc") - 1 # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window_overlap * 2 query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim) key = key.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim) - query = self._chunk(query, window_overlap) - key = self._chunk(key, window_overlap) + query = self._chunk(query, window_overlap, self.config.__dict__.get("onnx_export", False)) + key = self._chunk(key, window_overlap, self.config.__dict__.get("onnx_export", False)) # matrix multiplication # bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim @@ -829,7 +859,7 @@ class LongformerSelfAttention(nn.Module): # window_overlap previous words). The following column is attention score from each word to itself, then # followed by window_overlap columns for the upper triangle. - diagonal_attention_scores = diagonal_chunked_attention_scores.new_empty( + diagonal_attention_scores = diagonal_chunked_attention_scores.new_zeros( (batch_size * num_heads, chunks_count + 1, window_overlap, window_overlap * 2 + 1) ) @@ -870,11 +900,14 @@ class LongformerSelfAttention(nn.Module): assert seq_len % (window_overlap * 2) == 0 assert attn_probs.size()[:3] == value.size()[:3] assert attn_probs.size(3) == 2 * window_overlap + 1 - chunks_count = seq_len // window_overlap - 1 + chunks_count = torch.div(seq_len, window_overlap, rounding_mode="trunc") - 1 # group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size 2 window overlap chunked_attn_probs = attn_probs.transpose(1, 2).reshape( - batch_size * num_heads, seq_len // window_overlap, window_overlap, 2 * window_overlap + 1 + batch_size * num_heads, + torch.div(seq_len, window_overlap, rounding_mode="trunc"), + window_overlap, + 2 * window_overlap + 1, ) # group batch_size and num_heads dimensions into one @@ -949,9 +982,12 @@ class LongformerSelfAttention(nn.Module): # (batch_size, seq_len, num_heads, max_num_global_attn_indices) attn_probs_from_global_key = torch.einsum("blhd,bshd->blhs", (query_vectors, key_vectors_only_global)) + # need to transpose since ONNX export only supports consecutive indexing: https://pytorch.org/docs/stable/onnx.html#writes-sets + attn_probs_from_global_key = attn_probs_from_global_key.transpose(1, 3) attn_probs_from_global_key[ - is_local_index_no_global_attn_nonzero[0], :, :, is_local_index_no_global_attn_nonzero[1] + is_local_index_no_global_attn_nonzero[0], is_local_index_no_global_attn_nonzero[1], :, : ] = torch.finfo(attn_probs_from_global_key.dtype).min + attn_probs_from_global_key = attn_probs_from_global_key.transpose(1, 3) return attn_probs_from_global_key @@ -1045,9 +1081,12 @@ class LongformerSelfAttention(nn.Module): global_attn_scores = global_attn_scores.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len) + # need to transpose since ONNX export only supports consecutive indexing: https://pytorch.org/docs/stable/onnx.html#writes-sets + global_attn_scores = global_attn_scores.transpose(1, 2) global_attn_scores[ - is_local_index_no_global_attn_nonzero[0], :, is_local_index_no_global_attn_nonzero[1], : + is_local_index_no_global_attn_nonzero[0], is_local_index_no_global_attn_nonzero[1], :, : ] = torch.finfo(global_attn_scores.dtype).min + global_attn_scores = global_attn_scores.transpose(1, 2) global_attn_scores = global_attn_scores.masked_fill( is_index_masked[:, None, None, :], @@ -1588,7 +1627,7 @@ class LongformerModel(LongformerPreTrainedModel): inputs_embeds = torch.cat([inputs_embeds, inputs_embeds_padding], dim=-2) attention_mask = nn.functional.pad( - attention_mask, (0, padding_len), value=False + attention_mask, (0, padding_len), value=0 ) # no attention on the padding tokens token_type_ids = nn.functional.pad(token_type_ids, (0, padding_len), value=0) # pad with token_type_id = 0 diff --git a/src/transformers/onnx/features.py b/src/transformers/onnx/features.py index 3596fe1840..3f18c36983 100644 --- a/src/transformers/onnx/features.py +++ b/src/transformers/onnx/features.py @@ -358,6 +358,15 @@ class FeaturesManager: "seq2seq-lm-with-past", onnx_config_cls="models.longt5.LongT5OnnxConfig", ), + "longformer": supported_features_mapping( + "default", + "masked-lm", + "multiple-choice", + "question-answering", + "sequence-classification", + "token-classification", + onnx_config_cls="models.longformer.LongformerOnnxConfig", + ), "marian": supported_features_mapping( "default", "default-with-past", diff --git a/tests/onnx/test_onnx_v2.py b/tests/onnx/test_onnx_v2.py index 829c7ec0a4..52ced984ca 100644 --- a/tests/onnx/test_onnx_v2.py +++ b/tests/onnx/test_onnx_v2.py @@ -212,6 +212,7 @@ PYTORCH_EXPORT_MODELS = { ("data2vec-vision", "facebook/data2vec-vision-base"), ("perceiver", "deepmind/language-perceiver", ("masked-lm", "sequence-classification")), ("perceiver", "deepmind/vision-perceiver-conv", ("image-classification",)), + ("longformer", "allenai/longformer-base-4096"), ("yolos", "hustvl/yolos-tiny"), }