From b01483faa0cfb57369cbce153c671dbe48cc0638 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Mon, 8 Feb 2021 05:03:55 -0500 Subject: [PATCH] Truncate max length if needed in all examples (#10034) --- examples/language-modeling/run_mlm.py | 34 +++++++++---------- examples/language-modeling/run_plm.py | 16 ++++----- examples/multiple-choice/run_swag.py | 18 +++++++++- examples/question-answering/run_qa.py | 11 ++++-- .../question-answering/run_qa_beam_search.py | 11 ++++-- examples/text-classification/run_glue.py | 9 ++++- 6 files changed, 68 insertions(+), 31 deletions(-) diff --git a/examples/language-modeling/run_mlm.py b/examples/language-modeling/run_mlm.py index 626a70203e..437fb356b7 100755 --- a/examples/language-modeling/run_mlm.py +++ b/examples/language-modeling/run_mlm.py @@ -303,6 +303,22 @@ def main(): column_names = datasets["validation"].column_names text_column_name = "text" if "text" in column_names else column_names[0] + if data_args.max_seq_length is None: + max_seq_length = tokenizer.model_max_length + if max_seq_length > 1024: + logger.warn( + f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). " + "Picking 1024 instead. You can change that default value by passing --max_seq_length xxx." + ) + max_seq_length = 1024 + else: + if data_args.max_seq_length > tokenizer.model_max_length: + logger.warn( + f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" + f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." + ) + max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) + if data_args.line_by_line: # When using line_by_line, we just tokenize each nonempty line. padding = "max_length" if data_args.pad_to_max_length else False @@ -314,7 +330,7 @@ def main(): examples["text"], padding=padding, truncation=True, - max_length=data_args.max_seq_length, + max_length=max_seq_length, # We use this option because DataCollatorForLanguageModeling (see below) is more efficient when it # receives the `special_tokens_mask`. return_special_tokens_mask=True, @@ -342,22 +358,6 @@ def main(): load_from_cache_file=not data_args.overwrite_cache, ) - if data_args.max_seq_length is None: - max_seq_length = tokenizer.model_max_length - if max_seq_length > 1024: - logger.warn( - f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). " - "Picking 1024 instead. You can change that default value by passing --max_seq_length xxx." - ) - max_seq_length = 1024 - else: - if data_args.max_seq_length > tokenizer.model_max_length: - logger.warn( - f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" - f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." - ) - max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) - # Main data processing function that will concatenate all texts from our dataset and generate chunks of # max_seq_length. def group_texts(examples): diff --git a/examples/language-modeling/run_plm.py b/examples/language-modeling/run_plm.py index ebbea9f456..b44748c1dd 100755 --- a/examples/language-modeling/run_plm.py +++ b/examples/language-modeling/run_plm.py @@ -300,6 +300,13 @@ def main(): column_names = datasets["validation"].column_names text_column_name = "text" if "text" in column_names else column_names[0] + if data_args.max_seq_length > tokenizer.model_max_length: + logger.warn( + f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" + f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." + ) + max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) + if data_args.line_by_line: # When using line_by_line, we just tokenize each nonempty line. padding = "max_length" if data_args.pad_to_max_length else False @@ -307,7 +314,7 @@ def main(): def tokenize_function(examples): # Remove empty lines examples["text"] = [line for line in examples["text"] if len(line) > 0 and not line.isspace()] - return tokenizer(examples["text"], padding=padding, truncation=True, max_length=data_args.max_seq_length) + return tokenizer(examples["text"], padding=padding, truncation=True, max_length=max_seq_length) tokenized_datasets = datasets.map( tokenize_function, @@ -329,13 +336,6 @@ def main(): load_from_cache_file=not data_args.overwrite_cache, ) - if data_args.max_seq_length > tokenizer.model_max_length: - logger.warn( - f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" - f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." - ) - max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) - # Main data processing function that will concatenate all texts from our dataset and generate chunks of # max_seq_length. def group_texts(examples): diff --git a/examples/multiple-choice/run_swag.py b/examples/multiple-choice/run_swag.py index 28b2605f0c..efe95247dc 100755 --- a/examples/multiple-choice/run_swag.py +++ b/examples/multiple-choice/run_swag.py @@ -286,6 +286,22 @@ def main(): context_name = "sent1" question_header_name = "sent2" + if data_args.max_seq_length is None: + max_seq_length = tokenizer.model_max_length + if max_seq_length > 1024: + logger.warn( + f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). " + "Picking 1024 instead. You can change that default value by passing --max_seq_length xxx." + ) + max_seq_length = 1024 + else: + if data_args.max_seq_length > tokenizer.model_max_length: + logger.warn( + f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" + f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." + ) + max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) + # Preprocessing the datasets. def preprocess_function(examples): first_sentences = [[context] * 4 for context in examples[context_name]] @@ -303,7 +319,7 @@ def main(): first_sentences, second_sentences, truncation=True, - max_length=data_args.max_seq_length, + max_length=max_seq_length, padding="max_length" if data_args.pad_to_max_length else False, ) # Un-flatten diff --git a/examples/question-answering/run_qa.py b/examples/question-answering/run_qa.py index 2163f7fb71..5df70a3352 100755 --- a/examples/question-answering/run_qa.py +++ b/examples/question-answering/run_qa.py @@ -277,6 +277,13 @@ def main(): # Padding side determines if we do (question|context) or (context|question). pad_on_right = tokenizer.padding_side == "right" + if data_args.max_seq_length > tokenizer.model_max_length: + logger.warn( + f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" + f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." + ) + max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) + # Training preprocessing def prepare_train_features(examples): # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results @@ -286,7 +293,7 @@ def main(): examples[question_column_name if pad_on_right else context_column_name], examples[context_column_name if pad_on_right else question_column_name], truncation="only_second" if pad_on_right else "only_first", - max_length=data_args.max_seq_length, + max_length=max_seq_length, stride=data_args.doc_stride, return_overflowing_tokens=True, return_offsets_mapping=True, @@ -368,7 +375,7 @@ def main(): examples[question_column_name if pad_on_right else context_column_name], examples[context_column_name if pad_on_right else question_column_name], truncation="only_second" if pad_on_right else "only_first", - max_length=data_args.max_seq_length, + max_length=max_seq_length, stride=data_args.doc_stride, return_overflowing_tokens=True, return_offsets_mapping=True, diff --git a/examples/question-answering/run_qa_beam_search.py b/examples/question-answering/run_qa_beam_search.py index d81bff914b..0681b23c79 100755 --- a/examples/question-answering/run_qa_beam_search.py +++ b/examples/question-answering/run_qa_beam_search.py @@ -267,6 +267,13 @@ def main(): # Padding side determines if we do (question|context) or (context|question). pad_on_right = tokenizer.padding_side == "right" + if data_args.max_seq_length > tokenizer.model_max_length: + logger.warn( + f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" + f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." + ) + max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) + # Training preprocessing def prepare_train_features(examples): # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results @@ -276,7 +283,7 @@ def main(): examples[question_column_name if pad_on_right else context_column_name], examples[context_column_name if pad_on_right else question_column_name], truncation="only_second" if pad_on_right else "only_first", - max_length=data_args.max_seq_length, + max_length=max_seq_length, stride=data_args.doc_stride, return_overflowing_tokens=True, return_offsets_mapping=True, @@ -381,7 +388,7 @@ def main(): examples[question_column_name if pad_on_right else context_column_name], examples[context_column_name if pad_on_right else question_column_name], truncation="only_second" if pad_on_right else "only_first", - max_length=data_args.max_seq_length, + max_length=max_seq_length, stride=data_args.doc_stride, return_overflowing_tokens=True, return_offsets_mapping=True, diff --git a/examples/text-classification/run_glue.py b/examples/text-classification/run_glue.py index 9fb1fd774d..a056f15df7 100755 --- a/examples/text-classification/run_glue.py +++ b/examples/text-classification/run_glue.py @@ -334,12 +334,19 @@ def main(): elif data_args.task_name is None and not is_regression: label_to_id = {v: i for i, v in enumerate(label_list)} + if data_args.max_seq_length > tokenizer.model_max_length: + logger.warn( + f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" + f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." + ) + max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) + def preprocess_function(examples): # Tokenize the texts args = ( (examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key]) ) - result = tokenizer(*args, padding=padding, max_length=data_args.max_seq_length, truncation=True) + result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True) # Map labels to IDs (not necessary for GLUE tasks) if label_to_id is not None and "label" in examples: