[bart] add config.extra_pos_embeddings to facilitate reuse (#5190)
This commit is contained in:
@@ -41,6 +41,7 @@ class BartConfig(PretrainedConfig):
|
||||
def __init__(
|
||||
self,
|
||||
activation_dropout=0.0,
|
||||
extra_pos_embeddings=2,
|
||||
activation_function="gelu",
|
||||
vocab_size=50265,
|
||||
d_model=1024,
|
||||
@@ -118,6 +119,9 @@ class BartConfig(PretrainedConfig):
|
||||
# Classifier stuff
|
||||
self.classif_dropout = classifier_dropout
|
||||
|
||||
# pos embedding offset
|
||||
self.extra_pos_embeddings = self.pad_token_id + 1
|
||||
|
||||
@property
|
||||
def num_attention_heads(self) -> int:
|
||||
return self.encoder_attention_heads
|
||||
|
||||
@@ -28,7 +28,7 @@ from torch.nn import CrossEntropyLoss
|
||||
from .activations import ACT2FN
|
||||
from .configuration_bart import BartConfig
|
||||
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
||||
from .modeling_utils import PreTrainedModel, create_position_ids_from_input_ids
|
||||
from .modeling_utils import PreTrainedModel
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -96,6 +96,7 @@ BART_INPUTS_DOCSTRING = r"""
|
||||
|
||||
|
||||
def invert_mask(attention_mask):
|
||||
"""Turns 1->0, 0->1, False->True, True-> False"""
|
||||
assert attention_mask.dim() == 2
|
||||
return attention_mask.eq(0)
|
||||
|
||||
@@ -261,7 +262,7 @@ class BartEncoder(nn.Module):
|
||||
)
|
||||
else:
|
||||
self.embed_positions = LearnedPositionalEmbedding(
|
||||
config.max_position_embeddings, embed_dim, self.padding_idx,
|
||||
config.max_position_embeddings, embed_dim, self.padding_idx, config.extra_pos_embeddings,
|
||||
)
|
||||
self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.encoder_layers)])
|
||||
self.layernorm_embedding = LayerNorm(embed_dim) if config.normalize_embedding else nn.Identity()
|
||||
@@ -435,7 +436,7 @@ class BartDecoder(nn.Module):
|
||||
)
|
||||
else:
|
||||
self.embed_positions = LearnedPositionalEmbedding(
|
||||
config.max_position_embeddings, config.d_model, self.padding_idx,
|
||||
config.max_position_embeddings, config.d_model, self.padding_idx, config.extra_pos_embeddings,
|
||||
)
|
||||
self.layers = nn.ModuleList(
|
||||
[DecoderLayer(config) for _ in range(config.decoder_layers)]
|
||||
@@ -745,23 +746,23 @@ class LearnedPositionalEmbedding(nn.Embedding):
|
||||
position ids are passed to the forward function.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, num_embeddings: int, embedding_dim: int, padding_idx: int,
|
||||
):
|
||||
# if padding_idx is specified then offset the embedding ids by
|
||||
# this index and adjust num_embeddings appropriately
|
||||
def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, offset):
|
||||
# Bart is set up so that if padding_idx is specified then offset the embedding ids by 2
|
||||
# and adjust num_embeddings appropriately. Other models dont have this hack
|
||||
self.offset = offset
|
||||
assert padding_idx is not None
|
||||
num_embeddings += padding_idx + 1 # WHY?
|
||||
num_embeddings += offset
|
||||
super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)
|
||||
|
||||
def forward(self, input, use_cache=False):
|
||||
def forward(self, input_ids, use_cache=False):
|
||||
"""Input is expected to be of size [bsz x seqlen]."""
|
||||
if use_cache: # the position is our current step in the decoded sequence
|
||||
pos = int(self.padding_idx + input.size(1))
|
||||
positions = input.data.new(1, 1).fill_(pos)
|
||||
bsz, seq_len = input_ids.shape[:2]
|
||||
if use_cache:
|
||||
positions = input_ids.data.new(1, 1).fill_(seq_len - 1) # called before slicing
|
||||
else:
|
||||
positions = create_position_ids_from_input_ids(input, self.padding_idx)
|
||||
return super().forward(positions)
|
||||
# starts at 0, ends at 1-seq_len
|
||||
positions = torch.arange(seq_len, dtype=torch.long, device=self.weight.device)
|
||||
return super().forward(positions + self.offset)
|
||||
|
||||
|
||||
def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True):
|
||||
|
||||
@@ -26,7 +26,6 @@ from torch.nn import CrossEntropyLoss, MSELoss
|
||||
from .configuration_roberta import RobertaConfig
|
||||
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
||||
from .modeling_bert import BertEmbeddings, BertLayerNorm, BertModel, BertPreTrainedModel, gelu
|
||||
from .modeling_utils import create_position_ids_from_input_ids
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -733,3 +732,17 @@ class RobertaForQuestionAnswering(BertPreTrainedModel):
|
||||
outputs = (total_loss,) + outputs
|
||||
|
||||
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
|
||||
|
||||
|
||||
def create_position_ids_from_input_ids(input_ids, padding_idx):
|
||||
""" Replace non-padding symbols with their position numbers. Position numbers begin at
|
||||
padding_idx+1. Padding symbols are ignored. This is modified from fairseq's
|
||||
`utils.make_positions`.
|
||||
|
||||
:param torch.Tensor x:
|
||||
:return torch.Tensor:
|
||||
"""
|
||||
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
|
||||
mask = input_ids.ne(padding_idx).int()
|
||||
incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
|
||||
return incremental_indices.long() + padding_idx
|
||||
|
||||
@@ -2090,20 +2090,6 @@ class SequenceSummary(nn.Module):
|
||||
return output
|
||||
|
||||
|
||||
def create_position_ids_from_input_ids(input_ids, padding_idx):
|
||||
""" Replace non-padding symbols with their position numbers. Position numbers begin at
|
||||
padding_idx+1. Padding symbols are ignored. This is modified from fairseq's
|
||||
`utils.make_positions`.
|
||||
|
||||
:param torch.Tensor x:
|
||||
:return torch.Tensor:
|
||||
"""
|
||||
# The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
|
||||
mask = input_ids.ne(padding_idx).int()
|
||||
incremental_indices = torch.cumsum(mask, dim=1).type_as(mask) * mask
|
||||
return incremental_indices.long() + padding_idx
|
||||
|
||||
|
||||
def prune_linear_layer(layer, index, dim=0):
|
||||
""" Prune a linear layer (a model parameters) to keep only entries in index.
|
||||
Return the pruned layer as a new layer with requires_grad=True.
|
||||
|
||||
@@ -34,9 +34,8 @@ if is_torch_available():
|
||||
RobertaForSequenceClassification,
|
||||
RobertaForTokenClassification,
|
||||
)
|
||||
from transformers.modeling_roberta import RobertaEmbeddings
|
||||
from transformers.modeling_roberta import RobertaEmbeddings, create_position_ids_from_input_ids
|
||||
from transformers.modeling_roberta import ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
from transformers.modeling_utils import create_position_ids_from_input_ids
|
||||
|
||||
|
||||
class RobertaModelTester:
|
||||
|
||||
Reference in New Issue
Block a user