diff --git a/examples/language-modeling/run_plm.py b/examples/language-modeling/run_plm.py index 0e264115d8..cb64716d26 100644 --- a/examples/language-modeling/run_plm.py +++ b/examples/language-modeling/run_plm.py @@ -93,11 +93,11 @@ class DataTrainingArguments: overwrite_cache: bool = field( default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} ) - max_seq_length: Optional[int] = field( - default=None, + max_seq_length: int = field( + default=512, metadata={ "help": "The maximum total input sequence length after tokenization. Sequences longer " - "than this will be truncated. Default to the max input length of the model." + "than this will be truncated." }, ) preprocessing_num_workers: Optional[int] = field( @@ -286,15 +286,12 @@ def main(): load_from_cache_file=not data_args.overwrite_cache, ) - if data_args.max_seq_length is None: - max_seq_length = tokenizer.model_max_length - else: - if data_args.max_seq_length > tokenizer.model_max_length: - logger.warn( - f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" - f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." - ) - max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) + if data_args.max_seq_length > tokenizer.model_max_length: + logger.warn( + f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" + f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." + ) + max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) # Main data processing function that will concatenate all texts from our dataset and generate chunks of # max_seq_length.