From 27b819b0e32fd5967bcaf1e962fdb25cd43fe395 Mon Sep 17 00:00:00 2001 From: Russell Klopfer Date: Wed, 12 Jan 2022 08:57:00 -0500 Subject: [PATCH] use block_size instead of max_seq_length in tf run_clm example (#15036) * use block_size instead of max_seq_length * fixup * remove pad_to_block_size Co-authored-by: Russell Klopfer --- .../tensorflow/language-modeling/run_clm.py | 56 +++++++------------ 1 file changed, 19 insertions(+), 37 deletions(-) diff --git a/examples/tensorflow/language-modeling/run_clm.py b/examples/tensorflow/language-modeling/run_clm.py index 7b5e2ce08b..49a5b9ed77 100755 --- a/examples/tensorflow/language-modeling/run_clm.py +++ b/examples/tensorflow/language-modeling/run_clm.py @@ -148,11 +148,12 @@ class DataTrainingArguments: "help": "The percentage of the train set used as validation set in case there's no validation split" }, ) - max_seq_length: Optional[int] = field( + block_size: Optional[int] = field( default=None, metadata={ - "help": "The maximum total input sequence length after tokenization. Sequences longer " - "than this will be truncated." + "help": "Optional input sequence length after tokenization. " + "The training dataset will be truncated in block of this size for training. " + "Default to the model max input length for single sentence inputs (take into account special tokens)." }, ) preprocessing_num_workers: Optional[int] = field( @@ -166,13 +167,6 @@ class DataTrainingArguments: default=False, metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."}, ) - pad_to_max_length: bool = field( - default=False, - metadata={ - "help": "Whether to pad all samples to `max_seq_length`. " - "If False, will pad the samples dynamically when batching to the maximum length in the batch." - }, - ) max_train_samples: Optional[int] = field( default=None, metadata={ @@ -259,10 +253,6 @@ def main(): if training_args.output_dir is not None: training_args.output_dir = Path(training_args.output_dir) os.makedirs(training_args.output_dir, exist_ok=True) - - if isinstance(training_args.strategy, tf.distribute.TPUStrategy) and not data_args.pad_to_max_length: - logger.warning("We are training on TPU - forcing pad_to_max_length") - data_args.pad_to_max_length = True # endregion # region Checkpoints @@ -364,22 +354,6 @@ def main(): column_names = raw_datasets["train"].column_names text_column_name = "text" if "text" in column_names else column_names[0] - if data_args.max_seq_length is None: - max_seq_length = tokenizer.model_max_length - if max_seq_length > 1024: - logger.warning( - f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). " - "Picking 1024 instead. You can change that default value by passing --max_seq_length xxx." - ) - max_seq_length = 1024 - else: - if data_args.max_seq_length > tokenizer.model_max_length: - logger.warning( - f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" - f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." - ) - max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) - # First we tokenize all the texts. column_names = raw_datasets["train"].column_names text_column_name = "text" if "text" in column_names else column_names[0] @@ -396,13 +370,21 @@ def main(): desc="Running tokenizer on dataset", ) - block_size = tokenizer.model_max_length - if block_size > 1024: - logger.warning( - f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). " - "Picking 1024 instead. You can reduce that value by passing --block_size xxx." - ) - block_size = 1024 + if data_args.block_size is None: + block_size = tokenizer.model_max_length + if block_size > 1024: + logger.warning( + f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). " + "Picking 1024 instead. You can change that default value by passing --block_size xxx." + ) + block_size = 1024 + else: + if data_args.block_size > tokenizer.model_max_length: + logger.warning( + f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model" + f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}." + ) + block_size = min(data_args.block_size, tokenizer.model_max_length) # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. def group_texts(examples):