Fix max length in run_plm script (#8738)

This commit is contained in:
Sylvain Gugger
2020-11-23 16:02:31 -05:00
committed by GitHub
parent e84786aaa6
commit 367f497dec

View File

@@ -93,11 +93,11 @@ class DataTrainingArguments:
overwrite_cache: bool = field( overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
) )
max_seq_length: Optional[int] = field( max_seq_length: int = field(
default=None, default=512,
metadata={ metadata={
"help": "The maximum total input sequence length after tokenization. Sequences longer " "help": "The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated. Default to the max input length of the model." "than this will be truncated."
}, },
) )
preprocessing_num_workers: Optional[int] = field( preprocessing_num_workers: Optional[int] = field(
@@ -286,15 +286,12 @@ def main():
load_from_cache_file=not data_args.overwrite_cache, load_from_cache_file=not data_args.overwrite_cache,
) )
if data_args.max_seq_length is None: if data_args.max_seq_length > tokenizer.model_max_length:
max_seq_length = tokenizer.model_max_length logger.warn(
else: f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
if data_args.max_seq_length > tokenizer.model_max_length: f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
logger.warn( )
f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
)
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
# Main data processing function that will concatenate all texts from our dataset and generate chunks of # Main data processing function that will concatenate all texts from our dataset and generate chunks of
# max_seq_length. # max_seq_length.