From 367f497dec8905ac37c2947be998f4469fdade6b Mon Sep 17 00:00:00 2001
From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Date: Mon, 23 Nov 2020 16:02:31 -0500
Subject: [PATCH] Fix max length in run_plm script (#8738)

---
Reviewer note (ignored by `git am`, not part of the commit message):
in the added lines the warning is emitted through logger.warning() rather
than the deprecated logger.warn() alias (assuming `logger` is a standard
logging.Logger; its definition is outside this patch), and the first
f-string fragment now ends with a space so the message reads
"for the model" instead of "for themodel". Removed and context lines are
left exactly as they appear in the target file so the hunks still apply.

 examples/language-modeling/run_plm.py | 21 +++++++++------------
 1 file changed, 9 insertions(+), 12 deletions(-)

diff --git a/examples/language-modeling/run_plm.py b/examples/language-modeling/run_plm.py
index 0e264115d8..cb64716d26 100644
--- a/examples/language-modeling/run_plm.py
+++ b/examples/language-modeling/run_plm.py
@@ -93,11 +93,11 @@ class DataTrainingArguments:
     overwrite_cache: bool = field(
         default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
     )
-    max_seq_length: Optional[int] = field(
-        default=None,
+    max_seq_length: int = field(
+        default=512,
         metadata={
             "help": "The maximum total input sequence length after tokenization. Sequences longer "
-            "than this will be truncated. Default to the max input length of the model."
+            "than this will be truncated."
         },
     )
     preprocessing_num_workers: Optional[int] = field(
@@ -286,15 +286,12 @@ def main():
         load_from_cache_file=not data_args.overwrite_cache,
     )
 
-    if data_args.max_seq_length is None:
-        max_seq_length = tokenizer.model_max_length
-    else:
-        if data_args.max_seq_length > tokenizer.model_max_length:
-            logger.warn(
-                f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
-                f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
-            )
-        max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
+    if data_args.max_seq_length > tokenizer.model_max_length:
+        logger.warning(
+            f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
+            f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
+        )
+    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
 
     # Main data processing function that will concatenate all texts from our dataset and generate chunks of
     # max_seq_length.