[Pretrained Model] Add resize_position_embeddings (#13559)

* finish

* delete bogus file

* correct some stuff

* finish

* finish
This commit is contained in:
Patrick von Platen
2021-09-15 19:03:56 +02:00
committed by GitHub
parent c783e14887
commit 95f933ea85
8 changed files with 413 additions and 13 deletions

View File

@@ -99,6 +99,13 @@ class ModelArguments:
"with private models)."
},
)
resize_position_embeddings: Optional[bool] = field(
default=None,
metadata={
"help": "Whether to automatically resize the position embeddings if `max_source_length` exceeds "
"the model's position embeddings."
},
)
@dataclass
@@ -366,6 +373,25 @@ def main():
if model.config.decoder_start_token_id is None:
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
if (
hasattr(model.config, "max_position_embeddings")
and model.config.max_position_embeddings < data_args.max_source_length
):
if model_args.resize_position_embeddings is None:
logger.warning(
f"Increasing the model's number of position embedding vectors from {model.config.max_position_embedding} "
f"to {data_args.max_source_length}."
)
model.resize_position_embeddings(data_args.max_source_length)
elif model_args.resize_position_embeddings:
model.resize_position_embeddings(data_args.max_source_length)
else:
raise ValueError(
f"`--max_source_length` is set to {data_args.max_source_length}, but the model only has {model.config.max_position_embeddings}"
f" position encodings. Consider either reducing `--max_source_length` to {model.config.max_position_embeddings} or to automatically "
"resize the model's position encodings by passing `--resize_position_embeddings`."
)
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
# Preprocessing the datasets.