Add ONNX support for Longformer (#17176)
* Implement ONNX support for Longformer Fix repo consistency check complaints Fix value mismatches Add pooler output for default model Increase validation atol to accommodate multiple-choice error Fix copies Fix chunking for longer sequence lengths Add future comment * Fix issue in mask_invalid_locations * Remove torch imports in configuration_longformer * Change config access to fix LED * Push opset version to support tril * Work in review comments (mostly style) * Add Longformer to ONNX tests
This commit is contained in:
committed by
GitHub
parent
c55d6e4e10
commit
3223d49354
@@ -74,6 +74,7 @@ Ready-made configurations include the following architectures:
|
|||||||
- LayoutLM
|
- LayoutLM
|
||||||
- LayoutLMv3
|
- LayoutLMv3
|
||||||
- LeViT
|
- LeViT
|
||||||
|
- Longformer
|
||||||
- LongT5
|
- LongT5
|
||||||
- M2M100
|
- M2M100
|
||||||
- Marian
|
- Marian
|
||||||
|
|||||||
@@ -160,6 +160,8 @@ class LEDEncoderSelfAttention(nn.Module):
|
|||||||
|
|
||||||
self.one_sided_attn_window_size = attention_window // 2
|
self.one_sided_attn_window_size = attention_window // 2
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
@@ -389,17 +391,16 @@ class LEDEncoderSelfAttention(nn.Module):
|
|||||||
return chunked_hidden_states
|
return chunked_hidden_states
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _chunk(hidden_states, window_overlap):
|
def _chunk(hidden_states, window_overlap, onnx_export=False):
|
||||||
"""convert into overlapping chunks. Chunk size = 2w, overlap size = w"""
|
"""convert into overlapping chunks. Chunk size = 2w, overlap size = w"""
|
||||||
|
if not onnx_export:
|
||||||
# non-overlapping chunks of size = 2w
|
# non-overlapping chunks of size = 2w
|
||||||
hidden_states = hidden_states.view(
|
hidden_states = hidden_states.view(
|
||||||
hidden_states.size(0),
|
hidden_states.size(0),
|
||||||
hidden_states.size(1) // (window_overlap * 2),
|
torch.div(hidden_states.size(1), (window_overlap * 2), rounding_mode="trunc"),
|
||||||
window_overlap * 2,
|
window_overlap * 2,
|
||||||
hidden_states.size(2),
|
hidden_states.size(2),
|
||||||
)
|
)
|
||||||
|
|
||||||
# use `as_strided` to make the chunks overlap with an overlap size = window_overlap
|
# use `as_strided` to make the chunks overlap with an overlap size = window_overlap
|
||||||
chunk_size = list(hidden_states.size())
|
chunk_size = list(hidden_states.size())
|
||||||
chunk_size[1] = chunk_size[1] * 2 - 1
|
chunk_size[1] = chunk_size[1] * 2 - 1
|
||||||
@@ -408,6 +409,31 @@ class LEDEncoderSelfAttention(nn.Module):
|
|||||||
chunk_stride[1] = chunk_stride[1] // 2
|
chunk_stride[1] = chunk_stride[1] // 2
|
||||||
return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
|
return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
|
||||||
|
|
||||||
|
# When exporting to ONNX, use this separate logic
|
||||||
|
if hidden_states.size(1) == window_overlap * 2:
|
||||||
|
# simplest case
|
||||||
|
return hidden_states.unsqueeze(1)
|
||||||
|
else:
|
||||||
|
# have to use slow implementation since as_strided, unfold and 2d-tensor indexing aren't supported (yet) in ONNX export
|
||||||
|
|
||||||
|
# TODO replace this with
|
||||||
|
# > return hidden_states.unfold(dimension=1, size=window_overlap * 2, step=window_overlap).transpose(2, 3)
|
||||||
|
# once `unfold` is supported
|
||||||
|
|
||||||
|
chunk_size = [
|
||||||
|
hidden_states.size(0),
|
||||||
|
hidden_states.size(1) // window_overlap - 1,
|
||||||
|
window_overlap * 2,
|
||||||
|
hidden_states.size(2),
|
||||||
|
]
|
||||||
|
|
||||||
|
overlapping_chunks = torch.empty(chunk_size)
|
||||||
|
for chunk in range(chunk_size[1]):
|
||||||
|
overlapping_chunks[:, chunk, :, :] = hidden_states[
|
||||||
|
:, chunk * window_overlap : chunk * window_overlap + 2 * window_overlap, :
|
||||||
|
]
|
||||||
|
return overlapping_chunks
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _mask_invalid_locations(input_tensor, affected_seq_len) -> torch.Tensor:
|
def _mask_invalid_locations(input_tensor, affected_seq_len) -> torch.Tensor:
|
||||||
beginning_mask_2d = input_tensor.new_ones(affected_seq_len, affected_seq_len + 1).tril().flip(dims=[0])
|
beginning_mask_2d = input_tensor.new_ones(affected_seq_len, affected_seq_len + 1).tril().flip(dims=[0])
|
||||||
@@ -415,10 +441,14 @@ class LEDEncoderSelfAttention(nn.Module):
|
|||||||
ending_mask = beginning_mask.flip(dims=(1, 3))
|
ending_mask = beginning_mask.flip(dims=(1, 3))
|
||||||
beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1]
|
beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1]
|
||||||
beginning_mask = beginning_mask.expand(beginning_input.size())
|
beginning_mask = beginning_mask.expand(beginning_input.size())
|
||||||
beginning_input.masked_fill_(beginning_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
|
input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1] = torch.full_like(
|
||||||
|
beginning_input, -float("inf")
|
||||||
|
).where(beginning_mask.bool(), beginning_input)
|
||||||
ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :]
|
ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :]
|
||||||
ending_mask = ending_mask.expand(ending_input.size())
|
ending_mask = ending_mask.expand(ending_input.size())
|
||||||
ending_input.masked_fill_(ending_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
|
input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :] = torch.full_like(
|
||||||
|
ending_input, -float("inf")
|
||||||
|
).where(ending_mask.bool(), ending_input)
|
||||||
|
|
||||||
def _sliding_chunks_query_key_matmul(self, query: torch.Tensor, key: torch.Tensor, window_overlap: int):
|
def _sliding_chunks_query_key_matmul(self, query: torch.Tensor, key: torch.Tensor, window_overlap: int):
|
||||||
"""
|
"""
|
||||||
@@ -432,14 +462,14 @@ class LEDEncoderSelfAttention(nn.Module):
|
|||||||
), f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}"
|
), f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}"
|
||||||
assert query.size() == key.size()
|
assert query.size() == key.size()
|
||||||
|
|
||||||
chunks_count = seq_len // window_overlap - 1
|
chunks_count = torch.div(seq_len, window_overlap, rounding_mode="trunc") - 1
|
||||||
|
|
||||||
# group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window_overlap * 2
|
# group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window_overlap * 2
|
||||||
query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
|
query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
|
||||||
key = key.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
|
key = key.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
|
||||||
|
|
||||||
query = self._chunk(query, window_overlap)
|
query = self._chunk(query, window_overlap, self.config.__dict__.get("onnx_export", False))
|
||||||
key = self._chunk(key, window_overlap)
|
key = self._chunk(key, window_overlap, self.config.__dict__.get("onnx_export", False))
|
||||||
|
|
||||||
# matrix multiplication
|
# matrix multiplication
|
||||||
# bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim
|
# bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim
|
||||||
@@ -457,7 +487,7 @@ class LEDEncoderSelfAttention(nn.Module):
|
|||||||
# window_overlap previous words). The following column is attention score from each word to itself, then
|
# window_overlap previous words). The following column is attention score from each word to itself, then
|
||||||
# followed by window_overlap columns for the upper triangle.
|
# followed by window_overlap columns for the upper triangle.
|
||||||
|
|
||||||
diagonal_attention_scores = diagonal_chunked_attention_scores.new_empty(
|
diagonal_attention_scores = diagonal_chunked_attention_scores.new_zeros(
|
||||||
(batch_size * num_heads, chunks_count + 1, window_overlap, window_overlap * 2 + 1)
|
(batch_size * num_heads, chunks_count + 1, window_overlap, window_overlap * 2 + 1)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -498,11 +528,14 @@ class LEDEncoderSelfAttention(nn.Module):
|
|||||||
assert seq_len % (window_overlap * 2) == 0
|
assert seq_len % (window_overlap * 2) == 0
|
||||||
assert attn_probs.size()[:3] == value.size()[:3]
|
assert attn_probs.size()[:3] == value.size()[:3]
|
||||||
assert attn_probs.size(3) == 2 * window_overlap + 1
|
assert attn_probs.size(3) == 2 * window_overlap + 1
|
||||||
chunks_count = seq_len // window_overlap - 1
|
chunks_count = torch.div(seq_len, window_overlap, rounding_mode="trunc") - 1
|
||||||
# group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size 2 window overlap
|
# group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size 2 window overlap
|
||||||
|
|
||||||
chunked_attn_probs = attn_probs.transpose(1, 2).reshape(
|
chunked_attn_probs = attn_probs.transpose(1, 2).reshape(
|
||||||
batch_size * num_heads, seq_len // window_overlap, window_overlap, 2 * window_overlap + 1
|
batch_size * num_heads,
|
||||||
|
torch.div(seq_len, window_overlap, rounding_mode="trunc"),
|
||||||
|
window_overlap,
|
||||||
|
2 * window_overlap + 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
# group batch_size and num_heads dimensions into one
|
# group batch_size and num_heads dimensions into one
|
||||||
@@ -577,9 +610,12 @@ class LEDEncoderSelfAttention(nn.Module):
|
|||||||
# (batch_size, seq_len, num_heads, max_num_global_attn_indices)
|
# (batch_size, seq_len, num_heads, max_num_global_attn_indices)
|
||||||
attn_probs_from_global_key = torch.einsum("blhd,bshd->blhs", (query_vectors, key_vectors_only_global))
|
attn_probs_from_global_key = torch.einsum("blhd,bshd->blhs", (query_vectors, key_vectors_only_global))
|
||||||
|
|
||||||
|
# need to transpose since ONNX export only supports consecutive indexing: https://pytorch.org/docs/stable/onnx.html#writes-sets
|
||||||
|
attn_probs_from_global_key = attn_probs_from_global_key.transpose(1, 3)
|
||||||
attn_probs_from_global_key[
|
attn_probs_from_global_key[
|
||||||
is_local_index_no_global_attn_nonzero[0], :, :, is_local_index_no_global_attn_nonzero[1]
|
is_local_index_no_global_attn_nonzero[0], is_local_index_no_global_attn_nonzero[1], :, :
|
||||||
] = torch.finfo(attn_probs_from_global_key.dtype).min
|
] = torch.finfo(attn_probs_from_global_key.dtype).min
|
||||||
|
attn_probs_from_global_key = attn_probs_from_global_key.transpose(1, 3)
|
||||||
|
|
||||||
return attn_probs_from_global_key
|
return attn_probs_from_global_key
|
||||||
|
|
||||||
@@ -673,9 +709,12 @@ class LEDEncoderSelfAttention(nn.Module):
|
|||||||
|
|
||||||
global_attn_scores = global_attn_scores.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
|
global_attn_scores = global_attn_scores.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
|
||||||
|
|
||||||
|
# need to transpose since ONNX export only supports consecutive indexing: https://pytorch.org/docs/stable/onnx.html#writes-sets
|
||||||
|
global_attn_scores = global_attn_scores.transpose(1, 2)
|
||||||
global_attn_scores[
|
global_attn_scores[
|
||||||
is_local_index_no_global_attn_nonzero[0], :, is_local_index_no_global_attn_nonzero[1], :
|
is_local_index_no_global_attn_nonzero[0], is_local_index_no_global_attn_nonzero[1], :, :
|
||||||
] = torch.finfo(global_attn_scores.dtype).min
|
] = torch.finfo(global_attn_scores.dtype).min
|
||||||
|
global_attn_scores = global_attn_scores.transpose(1, 2)
|
||||||
|
|
||||||
global_attn_scores = global_attn_scores.masked_fill(
|
global_attn_scores = global_attn_scores.masked_fill(
|
||||||
is_index_masked[:, None, None, :],
|
is_index_masked[:, None, None, :],
|
||||||
|
|||||||
@@ -13,12 +13,20 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" Longformer configuration"""
|
""" Longformer configuration"""
|
||||||
from typing import List, Union
|
from collections import OrderedDict
|
||||||
|
from typing import TYPE_CHECKING, Any, List, Mapping, Optional, Union
|
||||||
|
|
||||||
from ...utils import logging
|
from ...onnx import OnnxConfig
|
||||||
|
from ...utils import TensorType, logging
|
||||||
from ..roberta.configuration_roberta import RobertaConfig
|
from ..roberta.configuration_roberta import RobertaConfig
|
||||||
|
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ...configuration_utils import PretrainedConfig
|
||||||
|
from ...onnx.config import PatchingSpec
|
||||||
|
from ...tokenization_utils_base import PreTrainedTokenizerBase
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP = {
|
||||||
@@ -71,6 +79,69 @@ class LongformerConfig(RobertaConfig):
|
|||||||
```"""
|
```"""
|
||||||
model_type = "longformer"
|
model_type = "longformer"
|
||||||
|
|
||||||
def __init__(self, attention_window: Union[List[int], int] = 512, sep_token_id: int = 2, **kwargs):
|
def __init__(
|
||||||
|
self, attention_window: Union[List[int], int] = 512, sep_token_id: int = 2, onnx_export: bool = False, **kwargs
|
||||||
|
):
|
||||||
super().__init__(sep_token_id=sep_token_id, **kwargs)
|
super().__init__(sep_token_id=sep_token_id, **kwargs)
|
||||||
self.attention_window = attention_window
|
self.attention_window = attention_window
|
||||||
|
self.onnx_export = onnx_export
|
||||||
|
|
||||||
|
|
||||||
|
class LongformerOnnxConfig(OnnxConfig):
|
||||||
|
def __init__(self, config: "PretrainedConfig", task: str = "default", patching_specs: "List[PatchingSpec]" = None):
|
||||||
|
super().__init__(config, task, patching_specs)
|
||||||
|
config.onnx_export = True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
if self.task == "multiple-choice":
|
||||||
|
dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
|
||||||
|
else:
|
||||||
|
dynamic_axis = {0: "batch", 1: "sequence"}
|
||||||
|
return OrderedDict(
|
||||||
|
[
|
||||||
|
("input_ids", dynamic_axis),
|
||||||
|
("attention_mask", dynamic_axis),
|
||||||
|
("global_attention_mask", dynamic_axis),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def outputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||||
|
outputs = super().outputs
|
||||||
|
if self.task == "default":
|
||||||
|
outputs["pooler_output"] = {0: "batch"}
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
@property
|
||||||
|
def atol_for_validation(self) -> float:
|
||||||
|
"""
|
||||||
|
What absolute tolerance value to use during model conversion validation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Float absolute tolerance value.
|
||||||
|
"""
|
||||||
|
return 1e-4
|
||||||
|
|
||||||
|
@property
|
||||||
|
def default_onnx_opset(self) -> int:
|
||||||
|
# needs to be >= 14 to support tril operator
|
||||||
|
return max(super().default_onnx_opset, 14)
|
||||||
|
|
||||||
|
def generate_dummy_inputs(
|
||||||
|
self,
|
||||||
|
tokenizer: "PreTrainedTokenizerBase",
|
||||||
|
batch_size: int = -1,
|
||||||
|
seq_length: int = -1,
|
||||||
|
is_pair: bool = False,
|
||||||
|
framework: Optional[TensorType] = None,
|
||||||
|
) -> Mapping[str, Any]:
|
||||||
|
inputs = super().generate_dummy_inputs(
|
||||||
|
preprocessor=tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
|
||||||
|
)
|
||||||
|
import torch
|
||||||
|
|
||||||
|
inputs["global_attention_mask"] = torch.zeros_like(inputs["input_ids"])
|
||||||
|
# make every second token global
|
||||||
|
inputs["global_attention_mask"][:, ::2] = 1
|
||||||
|
return inputs
|
||||||
|
|||||||
@@ -532,6 +532,8 @@ class LongformerSelfAttention(nn.Module):
|
|||||||
|
|
||||||
self.one_sided_attn_window_size = attention_window // 2
|
self.one_sided_attn_window_size = attention_window // 2
|
||||||
|
|
||||||
|
self.config = config
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
@@ -761,17 +763,16 @@ class LongformerSelfAttention(nn.Module):
|
|||||||
return chunked_hidden_states
|
return chunked_hidden_states
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _chunk(hidden_states, window_overlap):
|
def _chunk(hidden_states, window_overlap, onnx_export=False):
|
||||||
"""convert into overlapping chunks. Chunk size = 2w, overlap size = w"""
|
"""convert into overlapping chunks. Chunk size = 2w, overlap size = w"""
|
||||||
|
if not onnx_export:
|
||||||
# non-overlapping chunks of size = 2w
|
# non-overlapping chunks of size = 2w
|
||||||
hidden_states = hidden_states.view(
|
hidden_states = hidden_states.view(
|
||||||
hidden_states.size(0),
|
hidden_states.size(0),
|
||||||
hidden_states.size(1) // (window_overlap * 2),
|
torch.div(hidden_states.size(1), (window_overlap * 2), rounding_mode="trunc"),
|
||||||
window_overlap * 2,
|
window_overlap * 2,
|
||||||
hidden_states.size(2),
|
hidden_states.size(2),
|
||||||
)
|
)
|
||||||
|
|
||||||
# use `as_strided` to make the chunks overlap with an overlap size = window_overlap
|
# use `as_strided` to make the chunks overlap with an overlap size = window_overlap
|
||||||
chunk_size = list(hidden_states.size())
|
chunk_size = list(hidden_states.size())
|
||||||
chunk_size[1] = chunk_size[1] * 2 - 1
|
chunk_size[1] = chunk_size[1] * 2 - 1
|
||||||
@@ -780,6 +781,31 @@ class LongformerSelfAttention(nn.Module):
|
|||||||
chunk_stride[1] = chunk_stride[1] // 2
|
chunk_stride[1] = chunk_stride[1] // 2
|
||||||
return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
|
return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
|
||||||
|
|
||||||
|
# When exporting to ONNX, use this separate logic
|
||||||
|
if hidden_states.size(1) == window_overlap * 2:
|
||||||
|
# simplest case
|
||||||
|
return hidden_states.unsqueeze(1)
|
||||||
|
else:
|
||||||
|
# have to use slow implementation since as_strided, unfold and 2d-tensor indexing aren't supported (yet) in ONNX export
|
||||||
|
|
||||||
|
# TODO replace this with
|
||||||
|
# > return hidden_states.unfold(dimension=1, size=window_overlap * 2, step=window_overlap).transpose(2, 3)
|
||||||
|
# once `unfold` is supported
|
||||||
|
|
||||||
|
chunk_size = [
|
||||||
|
hidden_states.size(0),
|
||||||
|
hidden_states.size(1) // window_overlap - 1,
|
||||||
|
window_overlap * 2,
|
||||||
|
hidden_states.size(2),
|
||||||
|
]
|
||||||
|
|
||||||
|
overlapping_chunks = torch.empty(chunk_size)
|
||||||
|
for chunk in range(chunk_size[1]):
|
||||||
|
overlapping_chunks[:, chunk, :, :] = hidden_states[
|
||||||
|
:, chunk * window_overlap : chunk * window_overlap + 2 * window_overlap, :
|
||||||
|
]
|
||||||
|
return overlapping_chunks
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _mask_invalid_locations(input_tensor, affected_seq_len) -> torch.Tensor:
|
def _mask_invalid_locations(input_tensor, affected_seq_len) -> torch.Tensor:
|
||||||
beginning_mask_2d = input_tensor.new_ones(affected_seq_len, affected_seq_len + 1).tril().flip(dims=[0])
|
beginning_mask_2d = input_tensor.new_ones(affected_seq_len, affected_seq_len + 1).tril().flip(dims=[0])
|
||||||
@@ -787,10 +813,14 @@ class LongformerSelfAttention(nn.Module):
|
|||||||
ending_mask = beginning_mask.flip(dims=(1, 3))
|
ending_mask = beginning_mask.flip(dims=(1, 3))
|
||||||
beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1]
|
beginning_input = input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1]
|
||||||
beginning_mask = beginning_mask.expand(beginning_input.size())
|
beginning_mask = beginning_mask.expand(beginning_input.size())
|
||||||
beginning_input.masked_fill_(beginning_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
|
input_tensor[:, :affected_seq_len, :, : affected_seq_len + 1] = torch.full_like(
|
||||||
|
beginning_input, -float("inf")
|
||||||
|
).where(beginning_mask.bool(), beginning_input)
|
||||||
ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :]
|
ending_input = input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :]
|
||||||
ending_mask = ending_mask.expand(ending_input.size())
|
ending_mask = ending_mask.expand(ending_input.size())
|
||||||
ending_input.masked_fill_(ending_mask == 1, -float("inf")) # `== 1` converts to bool or uint8
|
input_tensor[:, -affected_seq_len:, :, -(affected_seq_len + 1) :] = torch.full_like(
|
||||||
|
ending_input, -float("inf")
|
||||||
|
).where(ending_mask.bool(), ending_input)
|
||||||
|
|
||||||
def _sliding_chunks_query_key_matmul(self, query: torch.Tensor, key: torch.Tensor, window_overlap: int):
|
def _sliding_chunks_query_key_matmul(self, query: torch.Tensor, key: torch.Tensor, window_overlap: int):
|
||||||
"""
|
"""
|
||||||
@@ -804,14 +834,14 @@ class LongformerSelfAttention(nn.Module):
|
|||||||
), f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}"
|
), f"Sequence length should be multiple of {window_overlap * 2}. Given {seq_len}"
|
||||||
assert query.size() == key.size()
|
assert query.size() == key.size()
|
||||||
|
|
||||||
chunks_count = seq_len // window_overlap - 1
|
chunks_count = torch.div(seq_len, window_overlap, rounding_mode="trunc") - 1
|
||||||
|
|
||||||
# group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window_overlap * 2
|
# group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size window_overlap * 2
|
||||||
query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
|
query = query.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
|
||||||
key = key.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
|
key = key.transpose(1, 2).reshape(batch_size * num_heads, seq_len, head_dim)
|
||||||
|
|
||||||
query = self._chunk(query, window_overlap)
|
query = self._chunk(query, window_overlap, self.config.__dict__.get("onnx_export", False))
|
||||||
key = self._chunk(key, window_overlap)
|
key = self._chunk(key, window_overlap, self.config.__dict__.get("onnx_export", False))
|
||||||
|
|
||||||
# matrix multiplication
|
# matrix multiplication
|
||||||
# bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim
|
# bcxd: batch_size * num_heads x chunks x 2window_overlap x head_dim
|
||||||
@@ -829,7 +859,7 @@ class LongformerSelfAttention(nn.Module):
|
|||||||
# window_overlap previous words). The following column is attention score from each word to itself, then
|
# window_overlap previous words). The following column is attention score from each word to itself, then
|
||||||
# followed by window_overlap columns for the upper triangle.
|
# followed by window_overlap columns for the upper triangle.
|
||||||
|
|
||||||
diagonal_attention_scores = diagonal_chunked_attention_scores.new_empty(
|
diagonal_attention_scores = diagonal_chunked_attention_scores.new_zeros(
|
||||||
(batch_size * num_heads, chunks_count + 1, window_overlap, window_overlap * 2 + 1)
|
(batch_size * num_heads, chunks_count + 1, window_overlap, window_overlap * 2 + 1)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -870,11 +900,14 @@ class LongformerSelfAttention(nn.Module):
|
|||||||
assert seq_len % (window_overlap * 2) == 0
|
assert seq_len % (window_overlap * 2) == 0
|
||||||
assert attn_probs.size()[:3] == value.size()[:3]
|
assert attn_probs.size()[:3] == value.size()[:3]
|
||||||
assert attn_probs.size(3) == 2 * window_overlap + 1
|
assert attn_probs.size(3) == 2 * window_overlap + 1
|
||||||
chunks_count = seq_len // window_overlap - 1
|
chunks_count = torch.div(seq_len, window_overlap, rounding_mode="trunc") - 1
|
||||||
# group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size 2 window overlap
|
# group batch_size and num_heads dimensions into one, then chunk seq_len into chunks of size 2 window overlap
|
||||||
|
|
||||||
chunked_attn_probs = attn_probs.transpose(1, 2).reshape(
|
chunked_attn_probs = attn_probs.transpose(1, 2).reshape(
|
||||||
batch_size * num_heads, seq_len // window_overlap, window_overlap, 2 * window_overlap + 1
|
batch_size * num_heads,
|
||||||
|
torch.div(seq_len, window_overlap, rounding_mode="trunc"),
|
||||||
|
window_overlap,
|
||||||
|
2 * window_overlap + 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
# group batch_size and num_heads dimensions into one
|
# group batch_size and num_heads dimensions into one
|
||||||
@@ -949,9 +982,12 @@ class LongformerSelfAttention(nn.Module):
|
|||||||
# (batch_size, seq_len, num_heads, max_num_global_attn_indices)
|
# (batch_size, seq_len, num_heads, max_num_global_attn_indices)
|
||||||
attn_probs_from_global_key = torch.einsum("blhd,bshd->blhs", (query_vectors, key_vectors_only_global))
|
attn_probs_from_global_key = torch.einsum("blhd,bshd->blhs", (query_vectors, key_vectors_only_global))
|
||||||
|
|
||||||
|
# need to transpose since ONNX export only supports consecutive indexing: https://pytorch.org/docs/stable/onnx.html#writes-sets
|
||||||
|
attn_probs_from_global_key = attn_probs_from_global_key.transpose(1, 3)
|
||||||
attn_probs_from_global_key[
|
attn_probs_from_global_key[
|
||||||
is_local_index_no_global_attn_nonzero[0], :, :, is_local_index_no_global_attn_nonzero[1]
|
is_local_index_no_global_attn_nonzero[0], is_local_index_no_global_attn_nonzero[1], :, :
|
||||||
] = torch.finfo(attn_probs_from_global_key.dtype).min
|
] = torch.finfo(attn_probs_from_global_key.dtype).min
|
||||||
|
attn_probs_from_global_key = attn_probs_from_global_key.transpose(1, 3)
|
||||||
|
|
||||||
return attn_probs_from_global_key
|
return attn_probs_from_global_key
|
||||||
|
|
||||||
@@ -1045,9 +1081,12 @@ class LongformerSelfAttention(nn.Module):
|
|||||||
|
|
||||||
global_attn_scores = global_attn_scores.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
|
global_attn_scores = global_attn_scores.view(batch_size, self.num_heads, max_num_global_attn_indices, seq_len)
|
||||||
|
|
||||||
|
# need to transpose since ONNX export only supports consecutive indexing: https://pytorch.org/docs/stable/onnx.html#writes-sets
|
||||||
|
global_attn_scores = global_attn_scores.transpose(1, 2)
|
||||||
global_attn_scores[
|
global_attn_scores[
|
||||||
is_local_index_no_global_attn_nonzero[0], :, is_local_index_no_global_attn_nonzero[1], :
|
is_local_index_no_global_attn_nonzero[0], is_local_index_no_global_attn_nonzero[1], :, :
|
||||||
] = torch.finfo(global_attn_scores.dtype).min
|
] = torch.finfo(global_attn_scores.dtype).min
|
||||||
|
global_attn_scores = global_attn_scores.transpose(1, 2)
|
||||||
|
|
||||||
global_attn_scores = global_attn_scores.masked_fill(
|
global_attn_scores = global_attn_scores.masked_fill(
|
||||||
is_index_masked[:, None, None, :],
|
is_index_masked[:, None, None, :],
|
||||||
@@ -1588,7 +1627,7 @@ class LongformerModel(LongformerPreTrainedModel):
|
|||||||
inputs_embeds = torch.cat([inputs_embeds, inputs_embeds_padding], dim=-2)
|
inputs_embeds = torch.cat([inputs_embeds, inputs_embeds_padding], dim=-2)
|
||||||
|
|
||||||
attention_mask = nn.functional.pad(
|
attention_mask = nn.functional.pad(
|
||||||
attention_mask, (0, padding_len), value=False
|
attention_mask, (0, padding_len), value=0
|
||||||
) # no attention on the padding tokens
|
) # no attention on the padding tokens
|
||||||
token_type_ids = nn.functional.pad(token_type_ids, (0, padding_len), value=0) # pad with token_type_id = 0
|
token_type_ids = nn.functional.pad(token_type_ids, (0, padding_len), value=0) # pad with token_type_id = 0
|
||||||
|
|
||||||
|
|||||||
@@ -358,6 +358,15 @@ class FeaturesManager:
|
|||||||
"seq2seq-lm-with-past",
|
"seq2seq-lm-with-past",
|
||||||
onnx_config_cls="models.longt5.LongT5OnnxConfig",
|
onnx_config_cls="models.longt5.LongT5OnnxConfig",
|
||||||
),
|
),
|
||||||
|
"longformer": supported_features_mapping(
|
||||||
|
"default",
|
||||||
|
"masked-lm",
|
||||||
|
"multiple-choice",
|
||||||
|
"question-answering",
|
||||||
|
"sequence-classification",
|
||||||
|
"token-classification",
|
||||||
|
onnx_config_cls="models.longformer.LongformerOnnxConfig",
|
||||||
|
),
|
||||||
"marian": supported_features_mapping(
|
"marian": supported_features_mapping(
|
||||||
"default",
|
"default",
|
||||||
"default-with-past",
|
"default-with-past",
|
||||||
|
|||||||
@@ -212,6 +212,7 @@ PYTORCH_EXPORT_MODELS = {
|
|||||||
("data2vec-vision", "facebook/data2vec-vision-base"),
|
("data2vec-vision", "facebook/data2vec-vision-base"),
|
||||||
("perceiver", "deepmind/language-perceiver", ("masked-lm", "sequence-classification")),
|
("perceiver", "deepmind/language-perceiver", ("masked-lm", "sequence-classification")),
|
||||||
("perceiver", "deepmind/vision-perceiver-conv", ("image-classification",)),
|
("perceiver", "deepmind/vision-perceiver-conv", ("image-classification",)),
|
||||||
|
("longformer", "allenai/longformer-base-4096"),
|
||||||
("yolos", "hustvl/yolos-tiny"),
|
("yolos", "hustvl/yolos-tiny"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user