[Pretrained Model] Add resize_position_embeddings (#13559)
* finish * delete bogus file * correct some stuff * finish * finish
This commit is contained in:
committed by
GitHub
parent
c783e14887
commit
95f933ea85
@@ -99,6 +99,13 @@ class ModelArguments:
|
||||
"with private models)."
|
||||
},
|
||||
)
|
||||
resize_position_embeddings: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Whether to automatically resize the position embeddings if `max_source_length` exceeds "
|
||||
"the model's position embeddings."
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -366,6 +373,25 @@ def main():
|
||||
if model.config.decoder_start_token_id is None:
|
||||
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
|
||||
|
||||
if (
|
||||
hasattr(model.config, "max_position_embeddings")
|
||||
and model.config.max_position_embeddings < data_args.max_source_length
|
||||
):
|
||||
if model_args.resize_position_embeddings is None:
|
||||
logger.warning(
|
||||
f"Increasing the model's number of position embedding vectors from {model.config.max_position_embeddings} "
|
||||
f"to {data_args.max_source_length}."
|
||||
)
|
||||
model.resize_position_embeddings(data_args.max_source_length)
|
||||
elif model_args.resize_position_embeddings:
|
||||
model.resize_position_embeddings(data_args.max_source_length)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`--max_source_length` is set to {data_args.max_source_length}, but the model only has {model.config.max_position_embeddings}"
|
||||
f" position encodings. Consider either reducing `--max_source_length` to {model.config.max_position_embeddings} or to automatically "
|
||||
"resize the model's position encodings by passing `--resize_position_embeddings`."
|
||||
)
|
||||
|
||||
prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
|
||||
|
||||
# Preprocessing the datasets.
|
||||
|
||||
Reference in New Issue
Block a user