[bart] add config.extra_pos_embeddings to facilitate reuse (#5190)
This commit is contained in:
@@ -41,6 +41,7 @@ class BartConfig(PretrainedConfig):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
activation_dropout=0.0,
|
activation_dropout=0.0,
|
||||||
|
extra_pos_embeddings=2,
|
||||||
activation_function="gelu",
|
activation_function="gelu",
|
||||||
vocab_size=50265,
|
vocab_size=50265,
|
||||||
d_model=1024,
|
d_model=1024,
|
||||||
@@ -118,6 +119,9 @@ class BartConfig(PretrainedConfig):
|
|||||||
# Classifier stuff
|
# Classifier stuff
|
||||||
self.classif_dropout = classifier_dropout
|
self.classif_dropout = classifier_dropout
|
||||||
|
|
||||||
|
# pos embedding offset
|
||||||
|
self.extra_pos_embeddings = self.pad_token_id + 1
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def num_attention_heads(self) -> int:
|
def num_attention_heads(self) -> int:
|
||||||
return self.encoder_attention_heads
|
return self.encoder_attention_heads
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ from torch.nn import CrossEntropyLoss
|
|||||||
from .activations import ACT2FN
|
from .activations import ACT2FN
|
||||||
from .configuration_bart import BartConfig
|
from .configuration_bart import BartConfig
|
||||||
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
||||||
from .modeling_utils import PreTrainedModel, create_position_ids_from_input_ids
|
from .modeling_utils import PreTrainedModel
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -96,6 +96,7 @@ BART_INPUTS_DOCSTRING = r"""
|
|||||||
|
|
||||||
|
|
||||||
def invert_mask(attention_mask):
|
def invert_mask(attention_mask):
|
||||||
|
"""Turns 1->0, 0->1, False->True, True-> False"""
|
||||||
assert attention_mask.dim() == 2
|
assert attention_mask.dim() == 2
|
||||||
return attention_mask.eq(0)
|
return attention_mask.eq(0)
|
||||||
|
|
||||||
@@ -261,7 +262,7 @@ class BartEncoder(nn.Module):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.embed_positions = LearnedPositionalEmbedding(
|
self.embed_positions = LearnedPositionalEmbedding(
|
||||||
config.max_position_embeddings, embed_dim, self.padding_idx,
|
config.max_position_embeddings, embed_dim, self.padding_idx, config.extra_pos_embeddings,
|
||||||
)
|
)
|
||||||
self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)])
|
self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)])
|
||||||
self.layernorm_embedding = LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity()
|
self.layernorm_embedding = LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity()
|
||||||
@@ -435,7 +436,7 @@ class BartDecoder(nn.Module):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.embed_positions = LearnedPositionalEmbedding(
|
self.embed_positions = LearnedPositionalEmbedding(
|
||||||
config.max_position_embeddings, config.d_model, self.padding_idx,
|
config.max_position_embeddings, config.d_model, self.padding_idx, config.extra_pos_embeddings,
|
||||||
)
|
)
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
[DecoderLayer(config) for _ in range(config.decoder_layers)]
|
[DecoderLayer(config) for _ in range(config.decoder_layers)]
|
||||||
@@ -745,23 +746,23 @@ class LearnedPositionalEmbedding(nn.Embedding):
|
|||||||
position ids are passed to the forward function.
|
position ids are passed to the forward function.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, offset):
|
||||||
self, num_embeddings: int, embedding_dim: int, padding_idx: int,
|
# Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
|
||||||
):
|
# and adjust num_embeddings appropriately. Other models dont have this hack
|
||||||
# if padding_idx is specified then offset the embedding ids by
|
self.offset = offset
|
||||||
# this index and adjust num_embeddings appropriately
|
|
||||||
assert padding_idx is not None
|
assert padding_idx is not None
|
||||||
num_embeddings += padding_idx + 1 # WHY?
|
num_embeddings += offset
|
||||||
super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)
|
super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)
|
||||||
|
|
||||||
def forward(self, input, use_cache=False):
|
def forward(self, input_ids, use_cache=False):
|
||||||
"""Input is expected to be of size [bsz x seqlen]."""
|
"""Input is expected to be of size [bsz x seqlen]."""
|
||||||
if use_cache: # the position is our current step in the decoded sequence
|
bsz, seq_len = input_ids.shape[:2]
|
||||||
pos = int(self.padding_idx + input.size(1))
|
if use_cache:
|
||||||
positions = input.data.new(1, 1).fill_(pos)
|
positions = input_ids.data.new(1, 1).fill_(seq_len - 1) # called before slicing
|
||||||
else:
|
else:
|
||||||
positions = create_position_ids_from_input_ids(input, self.padding_idx)
|
# starts at 0, ends at 1-seq_len
|
||||||
return super().forward(positions)
|
positions = torch.arange(seq_len, dtype=torch.long, device=self.weight.device)
|
||||||
|
return super().forward(positions + self.offset)
|
||||||
|
|
||||||
|
|
||||||
def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
|
def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
|
||||||
|
|||||||
@@ -26,7 +26,6 @@ from torch.nn import CrossEntropyLoss, MSELoss
|
|||||||
from .configuration_roberta import RobertaConfig
|
from .configuration_roberta import RobertaConfig
|
||||||
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
||||||
from .modeling_bert import BertEmbeddings, BertLayerNorm, BertModel, BertPreTrainedModel, gelu
|
from .modeling_bert import BertEmbeddings, BertLayerNorm, BertModel, BertPreTrainedModel, gelu
|
||||||
from .modeling_utils import create_position_ids_from_input_ids
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -733,3 +732,17 @@ class RobertaForQuestionAnswering(BertPreTrainedModel):
|
|||||||
outputs = (total_loss,) + outputs
|
outputs = (total_loss,) + outputs
|
||||||
|
|
||||||
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
|
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
|
||||||
|
|
||||||
|
|
||||||
|
def create_position_ids_from_input_ids(input_ids, padding_idx):
|
||||||
|
""" Replace non-padding symbols with their position numbers. Position numbers begin at
|
||||||
|
padding_idx+1. Padding symbols are ignored. This is modified from fairseq's
|
||||||
|
`utils.make_positions`.
|
||||||
|
|
||||||
|
:param torch.Tensor x:
|
||||||
|
:return torch.Tensor:
|
||||||
|
"""
|
||||||
|
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
|
||||||
|
mask = input_ids.ne(padding_idx).int()
|
||||||
|
incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
|
||||||
|
return incremental_indices.long() + padding_idx
|
||||||
|
|||||||
@@ -2090,20 +2090,6 @@ class SequenceSummary(nn.Module):
|
|||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
def create_position_ids_from_input_ids(input_ids, padding_idx):
|
|
||||||
""" Replace non-padding symbols with their position numbers. Position numbers begin at
|
|
||||||
padding_idx+1. Padding symbols are ignored. This is modified from fairseq's
|
|
||||||
`utils.make_positions`.
|
|
||||||
|
|
||||||
:param torch.Tensor x:
|
|
||||||
:return torch.Tensor:
|
|
||||||
"""
|
|
||||||
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
|
|
||||||
mask = input_ids.ne(padding_idx).int()
|
|
||||||
incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
|
|
||||||
return incremental_indices.long() + padding_idx
|
|
||||||
|
|
||||||
|
|
||||||
def prune_linear_layer(layer, index, dim=0):
|
def prune_linear_layer(layer, index, dim=0):
|
||||||
""" Prune a linear layer (a model parameters) to keep only entries in index.
|
""" Prune a linear layer (a model parameters) to keep only entries in index.
|
||||||
Return the pruned layer as a new layer with requires_grad=True.
|
Return the pruned layer as a new layer with requires_grad=True.
|
||||||
|
|||||||
@@ -34,9 +34,8 @@ if is_torch_available():
|
|||||||
RobertaForSequenceClassification,
|
RobertaForSequenceClassification,
|
||||||
RobertaForTokenClassification,
|
RobertaForTokenClassification,
|
||||||
)
|
)
|
||||||
from transformers.modeling_roberta import RobertaEmbeddings
|
from transformers.modeling_roberta import RobertaEmbeddings, create_position_ids_from_input_ids
|
||||||
from transformers.modeling_roberta import ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST
|
from transformers.modeling_roberta import ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||||
from transformers.modeling_utils import create_position_ids_from_input_ids
|
|
||||||
|
|
||||||
|
|
||||||
class RobertaModelTester:
|
class RobertaModelTester:
|
||||||
|
|||||||
Reference in New Issue
Block a user