[Pretrained Model] Add resize_position_embeddings (#13559)
* finish * delete bogus file * correct some stuff * finish * finish
This commit is contained in:
committed by
GitHub
parent
c783e14887
commit
95f933ea85
@@ -99,6 +99,13 @@ class ModelArguments:
|
|||||||
"with private models)."
|
"with private models)."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
resize_position_embeddings: Optional[bool] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "Whether to automatically resize the position embeddings if `max_source_length` exceeds "
|
||||||
|
"the model's position embeddings."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -366,6 +373,25 @@ def main():
|
|||||||
if model.config.decoder_start_token_id is None:
|
if model.config.decoder_start_token_id is None:
|
||||||
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
|
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
|
||||||
|
|
||||||
|
if (
|
||||||
|
hasattr(model.config, "max_position_embeddings")
|
||||||
|
and model.config.max_position_embeddings < data_args.max_source_length
|
||||||
|
):
|
||||||
|
if model_args.resize_position_embeddings is None:
|
||||||
|
logger.warning(
|
||||||
|
f"Increasing the model's number of position embedding vectors from {model.config.max_position_embedding} "
|
||||||
|
f"to {data_args.max_source_length}."
|
||||||
|
)
|
||||||
|
model.resize_position_embeddings(data_args.max_source_length)
|
||||||
|
elif model_args.resize_position_embeddings:
|
||||||
|
model.resize_position_embeddings(data_args.max_source_length)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"`--max_source_length` is set to {data_args.max_source_length}, but the model only has {model.config.max_position_embeddings}"
|
||||||
|
f" position encodings. Consider either reducing `--max_source_length` to {model.config.max_position_embeddings} or to automatically "
|
||||||
|
"resize the model's position encodings by passing `--resize_position_embeddings`."
|
||||||
|
)
|
||||||
|
|
||||||
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
|
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
|
||||||
|
|
||||||
# Preprocessing the datasets.
|
# Preprocessing the datasets.
|
||||||
|
|||||||
@@ -887,6 +887,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
|
|
||||||
return new_lm_head
|
return new_lm_head
|
||||||
|
|
||||||
|
def resize_position_embeddings(self, new_num_position_embeddings: int):
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"`resize_position_embeddings` is not implemented for {self.__class__}`. To implement it, you should "
|
||||||
|
f"overwrite this method in the class {self.__class__} in `modeling_{self.__class__.__module__}.py`"
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_position_embeddings(self) -> Union[nn.Embedding, Tuple[nn.Embedding]]:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"`get_position_embeddings` is not implemented for {self.__class__}`. To implement it, you should "
|
||||||
|
f"overwrite this method in the class {self.__class__} in `modeling_{self.__class__.__module__}.py`"
|
||||||
|
)
|
||||||
|
|
||||||
def init_weights(self):
|
def init_weights(self):
|
||||||
"""
|
"""
|
||||||
If needed prunes and maybe initializes weights.
|
If needed prunes and maybe initializes weights.
|
||||||
|
|||||||
@@ -2833,7 +2833,7 @@ class BigBirdPegasusDecoderWrapper(BigBirdPegasusPreTrainedModel):
|
|||||||
return self.decoder(*args, **kwargs)
|
return self.decoder(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.pegasus.modeling_pegasus.PegasusForCausalLM with Pegasus->BigBirdPegasus, 'facebook/bart-large'->"google/bigbird-pegasus-large-arxiv"
|
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->BigBirdPegasus, 'facebook/bart-large'->"google/bigbird-pegasus-large-arxiv"
|
||||||
class BigBirdPegasusForCausalLM(BigBirdPegasusPreTrainedModel):
|
class BigBirdPegasusForCausalLM(BigBirdPegasusPreTrainedModel):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
|
|||||||
@@ -442,6 +442,67 @@ class DistilBertModel(DistilBertPreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
|
def get_position_embeddings(self) -> nn.Embedding:
|
||||||
|
"""
|
||||||
|
Returns the position embeddings
|
||||||
|
"""
|
||||||
|
return self.embeddings.position_embeddings
|
||||||
|
|
||||||
|
def resize_position_embeddings(self, new_num_position_embeddings: int):
|
||||||
|
"""
|
||||||
|
Resizes position embeddings of the model if :obj:`new_num_position_embeddings !=
|
||||||
|
config.max_position_embeddings`.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
new_num_position_embeddings (:obj:`int`):
|
||||||
|
The number of new position embedding matrix. If position embeddings are learned, increasing the size
|
||||||
|
will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the
|
||||||
|
end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the
|
||||||
|
size will add correct vectors at the end following the position encoding algorithm, whereas reducing
|
||||||
|
the size will remove vectors from the end.
|
||||||
|
"""
|
||||||
|
num_position_embeds_diff = new_num_position_embeddings - self.config.max_position_embeddings
|
||||||
|
|
||||||
|
# no resizing needs to be done if the length stays the same
|
||||||
|
if num_position_embeds_diff == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(f"Setting `config.max_position_embeddings={new_num_position_embeddings}`...")
|
||||||
|
self.config.max_position_embeddings = new_num_position_embeddings
|
||||||
|
|
||||||
|
old_position_embeddings_weight = self.embeddings.position_embeddings.weight.clone()
|
||||||
|
|
||||||
|
self.embeddings.position_embeddings = nn.Embedding(self.config.max_position_embeddings, self.config.dim)
|
||||||
|
|
||||||
|
if self.config.sinusoidal_pos_embds:
|
||||||
|
|
||||||
|
if is_deepspeed_zero3_enabled():
|
||||||
|
import deepspeed
|
||||||
|
|
||||||
|
with deepspeed.zero.GatheredParameters(self.embeddings.position_embeddings.weight, modifier_rank=0):
|
||||||
|
if torch.distributed.get_rank() == 0:
|
||||||
|
create_sinusoidal_embeddings(
|
||||||
|
n_pos=self.config.max_position_embeddings,
|
||||||
|
dim=self.config.dim,
|
||||||
|
out=self.embeddings.position_embeddings.weight,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
create_sinusoidal_embeddings(
|
||||||
|
n_pos=self.config.max_position_embeddings,
|
||||||
|
dim=self.config.dim,
|
||||||
|
out=self.embeddings.position_embeddings.weight,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
with torch.no_grad():
|
||||||
|
if num_position_embeds_diff > 0:
|
||||||
|
self.embeddings.position_embeddings.weight[:-num_position_embeds_diff] = nn.Parameter(
|
||||||
|
old_position_embeddings_weight
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.embeddings.position_embeddings.weight = nn.Parameter(
|
||||||
|
old_position_embeddings_weight[:num_position_embeds_diff]
|
||||||
|
)
|
||||||
|
|
||||||
def get_input_embeddings(self):
|
def get_input_embeddings(self):
|
||||||
return self.embeddings.word_embeddings
|
return self.embeddings.word_embeddings
|
||||||
|
|
||||||
@@ -525,6 +586,27 @@ class DistilBertForMaskedLM(DistilBertPreTrainedModel):
|
|||||||
|
|
||||||
self.mlm_loss_fct = nn.CrossEntropyLoss()
|
self.mlm_loss_fct = nn.CrossEntropyLoss()
|
||||||
|
|
||||||
|
def get_position_embeddings(self) -> nn.Embedding:
|
||||||
|
"""
|
||||||
|
Returns the position embeddings
|
||||||
|
"""
|
||||||
|
return self.distilbert.get_position_embeddings()
|
||||||
|
|
||||||
|
def resize_position_embeddings(self, new_num_position_embeddings: int):
|
||||||
|
"""
|
||||||
|
Resizes position embeddings of the model if :obj:`new_num_position_embeddings !=
|
||||||
|
config.max_position_embeddings`.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
new_num_position_embeddings (:obj:`int`):
|
||||||
|
The number of new position embedding matrix. If position embeddings are learned, increasing the size
|
||||||
|
will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the
|
||||||
|
end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the
|
||||||
|
size will add correct vectors at the end following the position encoding algorithm, whereas reducing
|
||||||
|
the size will remove vectors from the end.
|
||||||
|
"""
|
||||||
|
self.distilbert.resize_position_embeddings(new_num_position_embeddings)
|
||||||
|
|
||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.vocab_projector
|
return self.vocab_projector
|
||||||
|
|
||||||
@@ -608,6 +690,27 @@ class DistilBertForSequenceClassification(DistilBertPreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
|
def get_position_embeddings(self) -> nn.Embedding:
|
||||||
|
"""
|
||||||
|
Returns the position embeddings
|
||||||
|
"""
|
||||||
|
return self.distilbert.get_position_embeddings()
|
||||||
|
|
||||||
|
def resize_position_embeddings(self, new_num_position_embeddings: int):
|
||||||
|
"""
|
||||||
|
Resizes position embeddings of the model if :obj:`new_num_position_embeddings !=
|
||||||
|
config.max_position_embeddings`.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
new_num_position_embeddings (:obj:`int`):
|
||||||
|
The number of new position embedding matrix. If position embeddings are learned, increasing the size
|
||||||
|
will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the
|
||||||
|
end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the
|
||||||
|
size will add correct vectors at the end following the position encoding algorithm, whereas reducing
|
||||||
|
the size will remove vectors from the end.
|
||||||
|
"""
|
||||||
|
self.distilbert.resize_position_embeddings(new_num_position_embeddings)
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
@add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||||
@@ -703,6 +806,27 @@ class DistilBertForQuestionAnswering(DistilBertPreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
|
def get_position_embeddings(self) -> nn.Embedding:
|
||||||
|
"""
|
||||||
|
Returns the position embeddings
|
||||||
|
"""
|
||||||
|
return self.distilbert.get_position_embeddings()
|
||||||
|
|
||||||
|
def resize_position_embeddings(self, new_num_position_embeddings: int):
|
||||||
|
"""
|
||||||
|
Resizes position embeddings of the model if :obj:`new_num_position_embeddings !=
|
||||||
|
config.max_position_embeddings`.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
new_num_position_embeddings (:obj:`int`):
|
||||||
|
The number of new position embedding matrix. If position embeddings are learned, increasing the size
|
||||||
|
will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the
|
||||||
|
end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the
|
||||||
|
size will add correct vectors at the end following the position encoding algorithm, whereas reducing
|
||||||
|
the size will remove vectors from the end.
|
||||||
|
"""
|
||||||
|
self.distilbert.resize_position_embeddings(new_num_position_embeddings)
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, num_choices"))
|
@add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, num_choices"))
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||||
@@ -799,6 +923,27 @@ class DistilBertForTokenClassification(DistilBertPreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
|
def get_position_embeddings(self) -> nn.Embedding:
|
||||||
|
"""
|
||||||
|
Returns the position embeddings
|
||||||
|
"""
|
||||||
|
return self.distilbert.get_position_embeddings()
|
||||||
|
|
||||||
|
def resize_position_embeddings(self, new_num_position_embeddings: int):
|
||||||
|
"""
|
||||||
|
Resizes position embeddings of the model if :obj:`new_num_position_embeddings !=
|
||||||
|
config.max_position_embeddings`.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
new_num_position_embeddings (:obj:`int`):
|
||||||
|
The number of new position embedding matrix. If position embeddings are learned, increasing the size
|
||||||
|
will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the
|
||||||
|
end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the
|
||||||
|
size will add correct vectors at the end following the position encoding algorithm, whereas reducing
|
||||||
|
the size will remove vectors from the end.
|
||||||
|
"""
|
||||||
|
self.distilbert.resize_position_embeddings(new_num_position_embeddings)
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING)
|
||||||
@add_code_sample_docstrings(
|
@add_code_sample_docstrings(
|
||||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||||
@@ -883,6 +1028,27 @@ class DistilBertForMultipleChoice(DistilBertPreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
|
def get_position_embeddings(self) -> nn.Embedding:
|
||||||
|
"""
|
||||||
|
Returns the position embeddings
|
||||||
|
"""
|
||||||
|
return self.distilbert.get_position_embeddings()
|
||||||
|
|
||||||
|
def resize_position_embeddings(self, new_num_position_embeddings: int):
|
||||||
|
"""
|
||||||
|
Resizes position embeddings of the model if :obj:`new_num_position_embeddings !=
|
||||||
|
config.max_position_embeddings`.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
new_num_position_embeddings (:obj:`int`)
|
||||||
|
The number of new position embeddings. If position embeddings are learned, increasing the size will add
|
||||||
|
newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
|
||||||
|
position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
|
||||||
|
add correct vectors at the end following the position encoding algorithm, whereas reducing the size
|
||||||
|
will remove vectors from the end.
|
||||||
|
"""
|
||||||
|
self.distilbert.resize_position_embeddings(new_num_position_embeddings)
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(
|
@add_start_docstrings_to_model_forward(
|
||||||
DISTILBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
|
DISTILBERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -480,17 +480,6 @@ class PegasusPreTrainedModel(PreTrainedModel):
|
|||||||
if module.padding_idx is not None:
|
if module.padding_idx is not None:
|
||||||
module.weight.data[module.padding_idx].zero_()
|
module.weight.data[module.padding_idx].zero_()
|
||||||
|
|
||||||
@property
|
|
||||||
def dummy_inputs(self):
|
|
||||||
pad_token = self.config.pad_token_id
|
|
||||||
input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
|
|
||||||
dummy_inputs = {
|
|
||||||
"attention_mask": input_ids.ne(pad_token),
|
|
||||||
"input_ids": input_ids,
|
|
||||||
"decoder_input_ids": input_ids,
|
|
||||||
}
|
|
||||||
return dummy_inputs
|
|
||||||
|
|
||||||
|
|
||||||
PEGASUS_START_DOCSTRING = r"""
|
PEGASUS_START_DOCSTRING = r"""
|
||||||
This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
|
This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
|
||||||
@@ -658,6 +647,34 @@ class PegasusEncoder(PegasusPreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
|
def resize_position_embeddings(self, new_num_position_embeddings: int):
|
||||||
|
"""
|
||||||
|
Resizes position embeddings matrix of the model if :obj:`new_num_position_embeddings !=
|
||||||
|
config.max_position_embeddings`.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
new_num_position_embeddings (:obj:`int`):
|
||||||
|
The number of new position embeddings. If position embeddings are learned, increasing the size will add
|
||||||
|
newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
|
||||||
|
position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
|
||||||
|
add correct vectors at the end following the position encoding algorithm, whereas reducing the size
|
||||||
|
will remove vectors from the end.
|
||||||
|
"""
|
||||||
|
logger.info(f"Setting `config.max_position_embeddings={new_num_position_embeddings}`...")
|
||||||
|
self.config.max_position_embeddings = new_num_position_embeddings
|
||||||
|
|
||||||
|
self.embed_positions = PegasusSinusoidalPositionalEmbedding(
|
||||||
|
self.config.max_position_embeddings,
|
||||||
|
self.config.d_model,
|
||||||
|
self.padding_idx,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_position_embeddings(self) -> nn.Embedding:
|
||||||
|
"""
|
||||||
|
Returns the position embeddings matrix
|
||||||
|
"""
|
||||||
|
return self.embed_positions
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids=None,
|
||||||
@@ -848,6 +865,34 @@ class PegasusDecoder(PegasusPreTrainedModel):
|
|||||||
|
|
||||||
return combined_attention_mask
|
return combined_attention_mask
|
||||||
|
|
||||||
|
def resize_position_embeddings(self, new_num_position_embeddings: int):
|
||||||
|
"""
|
||||||
|
Resizes position embeddings matrix of the model if :obj:`new_num_position_embeddings !=
|
||||||
|
config.max_position_embeddings`.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
new_num_position_embeddings (:obj:`int`):
|
||||||
|
The number of new position embeddings. If position embeddings are learned, increasing the size will add
|
||||||
|
newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
|
||||||
|
position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
|
||||||
|
add correct vectors at the end following the position encoding algorithm, whereas reducing the size
|
||||||
|
will remove vectors from the end.
|
||||||
|
"""
|
||||||
|
logger.info(f"Setting `config.max_position_embeddings={new_num_position_embeddings}`...")
|
||||||
|
self.config.max_position_embeddings = new_num_position_embeddings
|
||||||
|
|
||||||
|
self.embed_positions = PegasusSinusoidalPositionalEmbedding(
|
||||||
|
self.config.max_position_embeddings,
|
||||||
|
self.config.d_model,
|
||||||
|
self.padding_idx,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_position_embeddings(self) -> nn.Embedding:
|
||||||
|
"""
|
||||||
|
Returns the position embeddings matrix
|
||||||
|
"""
|
||||||
|
return self.embed_positions
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids=None,
|
||||||
@@ -1097,6 +1142,29 @@ class PegasusModel(PegasusPreTrainedModel):
|
|||||||
def get_decoder(self):
|
def get_decoder(self):
|
||||||
return self.decoder
|
return self.decoder
|
||||||
|
|
||||||
|
def resize_position_embeddings(self, new_num_position_embeddings: int):
|
||||||
|
"""
|
||||||
|
Resizes position embeddings matrix of the model if :obj:`new_num_position_embeddings !=
|
||||||
|
config.max_position_embeddings`.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
new_num_position_embeddings (:obj:`int`):
|
||||||
|
The number of new position embeddings. If position embeddings are learned, increasing the size will add
|
||||||
|
newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
|
||||||
|
position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
|
||||||
|
add correct vectors at the end following the position encoding algorithm, whereas reducing the size
|
||||||
|
will remove vectors from the end.
|
||||||
|
"""
|
||||||
|
self.config.max_position_embeddings = new_num_position_embeddings
|
||||||
|
self.encoder.resize_position_embeddings(new_num_position_embeddings)
|
||||||
|
self.decoder.resize_position_embeddings(new_num_position_embeddings)
|
||||||
|
|
||||||
|
def get_position_embeddings(self) -> Tuple[nn.Embedding]:
|
||||||
|
"""
|
||||||
|
Returns the position embeddings matrix
|
||||||
|
"""
|
||||||
|
return (self.encoder.get_position_embeddings(), self.decoder.get_position_embeddings())
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
|
||||||
def forward(
|
def forward(
|
||||||
@@ -1237,6 +1305,29 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
|
|||||||
def set_output_embeddings(self, new_embeddings):
|
def set_output_embeddings(self, new_embeddings):
|
||||||
self.lm_head = new_embeddings
|
self.lm_head = new_embeddings
|
||||||
|
|
||||||
|
def resize_position_embeddings(self, new_num_position_embeddings: int):
|
||||||
|
"""
|
||||||
|
Resizes position embeddings matrix of the model if :obj:`new_num_position_embeddings !=
|
||||||
|
config.max_position_embeddings`.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
new_num_position_embeddings (:obj:`int`):
|
||||||
|
The number of new position embeddings. If position embeddings are learned, increasing the size will add
|
||||||
|
newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
|
||||||
|
position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
|
||||||
|
add correct vectors at the end following the position encoding algorithm, whereas reducing the size
|
||||||
|
will remove vectors from the end.
|
||||||
|
"""
|
||||||
|
self.config.max_position_embeddings = new_num_position_embeddings
|
||||||
|
self.model.encoder.resize_position_embeddings(new_num_position_embeddings)
|
||||||
|
self.model.decoder.resize_position_embeddings(new_num_position_embeddings)
|
||||||
|
|
||||||
|
def get_position_embeddings(self) -> Tuple[nn.Embedding]:
|
||||||
|
"""
|
||||||
|
Returns the position embeddings matrix
|
||||||
|
"""
|
||||||
|
return (self.model.encoder.get_position_embeddings(), self.model.decoder.get_position_embeddings())
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(PEGASUS_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
|
||||||
@add_end_docstrings(PEGASUS_GENERATION_EXAMPLE)
|
@add_end_docstrings(PEGASUS_GENERATION_EXAMPLE)
|
||||||
@@ -1373,7 +1464,6 @@ class PegasusDecoderWrapper(PegasusPreTrainedModel):
|
|||||||
return self.decoder(*args, **kwargs)
|
return self.decoder(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Pegasus
|
|
||||||
class PegasusForCausalLM(PegasusPreTrainedModel):
|
class PegasusForCausalLM(PegasusPreTrainedModel):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
@@ -1404,7 +1494,30 @@ class PegasusForCausalLM(PegasusPreTrainedModel):
|
|||||||
def get_decoder(self):
|
def get_decoder(self):
|
||||||
return self.model.decoder
|
return self.model.decoder
|
||||||
|
|
||||||
|
def get_position_embeddings(self) -> nn.Embedding:
|
||||||
|
"""
|
||||||
|
Returns the position embeddings matrix
|
||||||
|
"""
|
||||||
|
return self.model.decoder.get_position_embeddings()
|
||||||
|
|
||||||
|
def resize_position_embeddings(self, new_num_position_embeddings: int):
|
||||||
|
"""
|
||||||
|
Resizes position embeddings matrix of the model if :obj:`new_num_position_embeddings !=
|
||||||
|
config.max_position_embeddings`.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
new_num_position_embeddings (:obj:`int`):
|
||||||
|
The number of new position embeddings. If position embeddings are learned, increasing the size will add
|
||||||
|
newly initialized vectors at the end, whereas reducing the size will remove vectors from the end. If
|
||||||
|
position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the size will
|
||||||
|
add correct vectors at the end following the position encoding algorithm, whereas reducing the size
|
||||||
|
will remove vectors from the end.
|
||||||
|
"""
|
||||||
|
self.config.max_position_embeddings = new_num_position_embeddings
|
||||||
|
self.model.decoder.resize_position_embeddings(new_num_position_embeddings)
|
||||||
|
|
||||||
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
||||||
|
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM.forward with Bart->Pegasus
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids=None,
|
||||||
|
|||||||
@@ -94,6 +94,7 @@ class ModelTesterMixin:
|
|||||||
test_torchscript = True
|
test_torchscript = True
|
||||||
test_pruning = True
|
test_pruning = True
|
||||||
test_resize_embeddings = True
|
test_resize_embeddings = True
|
||||||
|
test_resize_position_embeddings = False
|
||||||
test_head_masking = True
|
test_head_masking = True
|
||||||
test_missing_keys = True
|
test_missing_keys = True
|
||||||
test_model_parallel = False
|
test_model_parallel = False
|
||||||
@@ -1067,6 +1068,85 @@ class ModelTesterMixin:
|
|||||||
hidden_states_with_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0]
|
hidden_states_with_chunk = model(**self._prepare_for_class(inputs_dict, model_class))[0]
|
||||||
self.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3))
|
self.assertTrue(torch.allclose(hidden_states_no_chunk, hidden_states_with_chunk, atol=1e-3))
|
||||||
|
|
||||||
|
def test_resize_position_vector_embeddings(self):
|
||||||
|
if not self.test_resize_position_embeddings:
|
||||||
|
return
|
||||||
|
|
||||||
|
(
|
||||||
|
original_config,
|
||||||
|
inputs_dict,
|
||||||
|
) = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
config = copy.deepcopy(original_config)
|
||||||
|
model = model_class(config)
|
||||||
|
model.to(torch_device)
|
||||||
|
|
||||||
|
if self.model_tester.is_training is False:
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
max_position_embeddings = config.max_position_embeddings
|
||||||
|
|
||||||
|
# Retrieve the embeddings and clone theme
|
||||||
|
if model.config.is_encoder_decoder:
|
||||||
|
encoder_model_embed, decoder_model_embed = model.get_position_embeddings()
|
||||||
|
encoder_cloned_embeddings = encoder_model_embed.weight.clone()
|
||||||
|
decoder_cloned_embeddings = decoder_model_embed.weight.clone()
|
||||||
|
else:
|
||||||
|
model_embed = model.get_position_embeddings()
|
||||||
|
cloned_embeddings = model_embed.weight.clone()
|
||||||
|
|
||||||
|
# Check that resizing the position embeddings with a larger max_position_embeddings increases
|
||||||
|
# the model's postion embeddings size
|
||||||
|
model.resize_position_embeddings(max_position_embeddings + 10)
|
||||||
|
self.assertEqual(model.config.max_position_embeddings, max_position_embeddings + 10)
|
||||||
|
|
||||||
|
# Check that it actually resizes the embeddings matrix
|
||||||
|
if model.config.is_encoder_decoder:
|
||||||
|
encoder_model_embed, decoder_model_embed = model.get_position_embeddings()
|
||||||
|
self.assertEqual(encoder_model_embed.weight.shape[0], encoder_cloned_embeddings.shape[0] + 10)
|
||||||
|
self.assertEqual(decoder_model_embed.weight.shape[0], decoder_cloned_embeddings.shape[0] + 10)
|
||||||
|
else:
|
||||||
|
model_embed = model.get_position_embeddings()
|
||||||
|
self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10)
|
||||||
|
|
||||||
|
# Check that the model can still do a forward pass successfully (every parameter should be resized)
|
||||||
|
model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
|
||||||
|
# Check that resizing the position embeddings with a smaller max_position_embeddings decreases
|
||||||
|
# the model's max_position_embeddings
|
||||||
|
model.resize_position_embeddings(max_position_embeddings - 5)
|
||||||
|
self.assertEqual(model.config.max_position_embeddings, max_position_embeddings - 5)
|
||||||
|
|
||||||
|
# Check that it actually resizes the embeddings matrix
|
||||||
|
if model.config.is_encoder_decoder:
|
||||||
|
encoder_model_embed, decoder_model_embed = model.get_position_embeddings()
|
||||||
|
self.assertEqual(encoder_model_embed.weight.shape[0], encoder_cloned_embeddings.shape[0] - 5)
|
||||||
|
self.assertEqual(decoder_model_embed.weight.shape[0], decoder_cloned_embeddings.shape[0] - 5)
|
||||||
|
else:
|
||||||
|
model_embed = model.get_position_embeddings()
|
||||||
|
self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - 5)
|
||||||
|
|
||||||
|
# Check that the model can still do a forward pass successfully (every parameter should be resized)
|
||||||
|
model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
|
||||||
|
# Check that adding and removing tokens has not modified the first part of the embedding matrix.
|
||||||
|
models_equal = True
|
||||||
|
|
||||||
|
if model.config.is_encoder_decoder:
|
||||||
|
for p1, p2 in zip(encoder_cloned_embeddings, encoder_model_embed.weight):
|
||||||
|
if p1.data.ne(p2.data).sum() > 0:
|
||||||
|
models_equal = False
|
||||||
|
for p1, p2 in zip(decoder_cloned_embeddings, decoder_model_embed.weight):
|
||||||
|
if p1.data.ne(p2.data).sum() > 0:
|
||||||
|
models_equal = False
|
||||||
|
else:
|
||||||
|
for p1, p2 in zip(cloned_embeddings, model_embed.weight):
|
||||||
|
if p1.data.ne(p2.data).sum() > 0:
|
||||||
|
models_equal = False
|
||||||
|
|
||||||
|
self.assertTrue(models_equal)
|
||||||
|
|
||||||
def test_resize_tokens_embeddings(self):
|
def test_resize_tokens_embeddings(self):
|
||||||
(
|
(
|
||||||
original_config,
|
original_config,
|
||||||
|
|||||||
@@ -214,6 +214,7 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
test_torchscript = True
|
test_torchscript = True
|
||||||
test_resize_embeddings = True
|
test_resize_embeddings = True
|
||||||
test_sequence_classification_problem_types = True
|
test_sequence_classification_problem_types = True
|
||||||
|
test_resize_position_embeddings = True
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.model_tester = DistilBertModelTester(self)
|
self.model_tester = DistilBertModelTester(self)
|
||||||
|
|||||||
@@ -229,6 +229,7 @@ class PegasusModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCas
|
|||||||
all_model_classes = (PegasusModel, PegasusForConditionalGeneration) if is_torch_available() else ()
|
all_model_classes = (PegasusModel, PegasusForConditionalGeneration) if is_torch_available() else ()
|
||||||
all_generative_model_classes = (PegasusForConditionalGeneration,) if is_torch_available() else ()
|
all_generative_model_classes = (PegasusForConditionalGeneration,) if is_torch_available() else ()
|
||||||
is_encoder_decoder = True
|
is_encoder_decoder = True
|
||||||
|
test_resize_position_embeddings = True
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_missing_keys = False
|
test_missing_keys = False
|
||||||
|
|
||||||
@@ -526,6 +527,7 @@ class PegasusStandaloneDecoderModelTester:
|
|||||||
class PegasusStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
class PegasusStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (PegasusDecoder, PegasusForCausalLM) if is_torch_available() else ()
|
all_model_classes = (PegasusDecoder, PegasusForCausalLM) if is_torch_available() else ()
|
||||||
all_generative_model_classes = (PegasusForCausalLM,) if is_torch_available() else ()
|
all_generative_model_classes = (PegasusForCausalLM,) if is_torch_available() else ()
|
||||||
|
test_resize_position_embeddings = True
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
is_encoder_decoder = False
|
is_encoder_decoder = False
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user