Truncate max length if needed in all examples (#10034)
This commit is contained in:
@@ -303,6 +303,22 @@ def main():
|
|||||||
column_names = datasets["validation"].column_names
|
column_names = datasets["validation"].column_names
|
||||||
text_column_name = "text" if "text" in column_names else column_names[0]
|
text_column_name = "text" if "text" in column_names else column_names[0]
|
||||||
|
|
||||||
|
if data_args.max_seq_length is None:
|
||||||
|
max_seq_length = tokenizer.model_max_length
|
||||||
|
if max_seq_length > 1024:
|
||||||
|
logger.warn(
|
||||||
|
f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
|
||||||
|
"Picking 1024 instead. You can change that default value by passing --max_seq_length xxx."
|
||||||
|
)
|
||||||
|
max_seq_length = 1024
|
||||||
|
else:
|
||||||
|
if data_args.max_seq_length > tokenizer.model_max_length:
|
||||||
|
logger.warn(
|
||||||
|
f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
|
||||||
|
f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
|
||||||
|
)
|
||||||
|
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
|
||||||
|
|
||||||
if data_args.line_by_line:
|
if data_args.line_by_line:
|
||||||
# When using line_by_line, we just tokenize each nonempty line.
|
# When using line_by_line, we just tokenize each nonempty line.
|
||||||
padding = "max_length" if data_args.pad_to_max_length else False
|
padding = "max_length" if data_args.pad_to_max_length else False
|
||||||
@@ -314,7 +330,7 @@ def main():
|
|||||||
examples["text"],
|
examples["text"],
|
||||||
padding=padding,
|
padding=padding,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
max_length=data_args.max_seq_length,
|
max_length=max_seq_length,
|
||||||
# We use this option because DataCollatorForLanguageModeling (see below) is more efficient when it
|
# We use this option because DataCollatorForLanguageModeling (see below) is more efficient when it
|
||||||
# receives the `special_tokens_mask`.
|
# receives the `special_tokens_mask`.
|
||||||
return_special_tokens_mask=True,
|
return_special_tokens_mask=True,
|
||||||
@@ -342,22 +358,6 @@ def main():
|
|||||||
load_from_cache_file=not data_args.overwrite_cache,
|
load_from_cache_file=not data_args.overwrite_cache,
|
||||||
)
|
)
|
||||||
|
|
||||||
if data_args.max_seq_length is None:
|
|
||||||
max_seq_length = tokenizer.model_max_length
|
|
||||||
if max_seq_length > 1024:
|
|
||||||
logger.warn(
|
|
||||||
f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
|
|
||||||
"Picking 1024 instead. You can change that default value by passing --max_seq_length xxx."
|
|
||||||
)
|
|
||||||
max_seq_length = 1024
|
|
||||||
else:
|
|
||||||
if data_args.max_seq_length > tokenizer.model_max_length:
|
|
||||||
logger.warn(
|
|
||||||
f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
|
|
||||||
f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
|
|
||||||
)
|
|
||||||
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
|
|
||||||
|
|
||||||
# Main data processing function that will concatenate all texts from our dataset and generate chunks of
|
# Main data processing function that will concatenate all texts from our dataset and generate chunks of
|
||||||
# max_seq_length.
|
# max_seq_length.
|
||||||
def group_texts(examples):
|
def group_texts(examples):
|
||||||
|
|||||||
@@ -300,6 +300,13 @@ def main():
|
|||||||
column_names = datasets["validation"].column_names
|
column_names = datasets["validation"].column_names
|
||||||
text_column_name = "text" if "text" in column_names else column_names[0]
|
text_column_name = "text" if "text" in column_names else column_names[0]
|
||||||
|
|
||||||
|
if data_args.max_seq_length > tokenizer.model_max_length:
|
||||||
|
logger.warn(
|
||||||
|
f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
|
||||||
|
f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
|
||||||
|
)
|
||||||
|
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
|
||||||
|
|
||||||
if data_args.line_by_line:
|
if data_args.line_by_line:
|
||||||
# When using line_by_line, we just tokenize each nonempty line.
|
# When using line_by_line, we just tokenize each nonempty line.
|
||||||
padding = "max_length" if data_args.pad_to_max_length else False
|
padding = "max_length" if data_args.pad_to_max_length else False
|
||||||
@@ -307,7 +314,7 @@ def main():
|
|||||||
def tokenize_function(examples):
|
def tokenize_function(examples):
|
||||||
# Remove empty lines
|
# Remove empty lines
|
||||||
examples["text"] = [line for line in examples["text"] if len(line) > 0 and not line.isspace()]
|
examples["text"] = [line for line in examples["text"] if len(line) > 0 and not line.isspace()]
|
||||||
return tokenizer(examples["text"], padding=padding, truncation=True, max_length=data_args.max_seq_length)
|
return tokenizer(examples["text"], padding=padding, truncation=True, max_length=max_seq_length)
|
||||||
|
|
||||||
tokenized_datasets = datasets.map(
|
tokenized_datasets = datasets.map(
|
||||||
tokenize_function,
|
tokenize_function,
|
||||||
@@ -329,13 +336,6 @@ def main():
|
|||||||
load_from_cache_file=not data_args.overwrite_cache,
|
load_from_cache_file=not data_args.overwrite_cache,
|
||||||
)
|
)
|
||||||
|
|
||||||
if data_args.max_seq_length > tokenizer.model_max_length:
|
|
||||||
logger.warn(
|
|
||||||
f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
|
|
||||||
f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
|
|
||||||
)
|
|
||||||
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
|
|
||||||
|
|
||||||
# Main data processing function that will concatenate all texts from our dataset and generate chunks of
|
# Main data processing function that will concatenate all texts from our dataset and generate chunks of
|
||||||
# max_seq_length.
|
# max_seq_length.
|
||||||
def group_texts(examples):
|
def group_texts(examples):
|
||||||
|
|||||||
@@ -286,6 +286,22 @@ def main():
|
|||||||
context_name = "sent1"
|
context_name = "sent1"
|
||||||
question_header_name = "sent2"
|
question_header_name = "sent2"
|
||||||
|
|
||||||
|
if data_args.max_seq_length is None:
|
||||||
|
max_seq_length = tokenizer.model_max_length
|
||||||
|
if max_seq_length > 1024:
|
||||||
|
logger.warn(
|
||||||
|
f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
|
||||||
|
"Picking 1024 instead. You can change that default value by passing --max_seq_length xxx."
|
||||||
|
)
|
||||||
|
max_seq_length = 1024
|
||||||
|
else:
|
||||||
|
if data_args.max_seq_length > tokenizer.model_max_length:
|
||||||
|
logger.warn(
|
||||||
|
f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
|
||||||
|
f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
|
||||||
|
)
|
||||||
|
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
|
||||||
|
|
||||||
# Preprocessing the datasets.
|
# Preprocessing the datasets.
|
||||||
def preprocess_function(examples):
|
def preprocess_function(examples):
|
||||||
first_sentences = [[context] * 4 for context in examples[context_name]]
|
first_sentences = [[context] * 4 for context in examples[context_name]]
|
||||||
@@ -303,7 +319,7 @@ def main():
|
|||||||
first_sentences,
|
first_sentences,
|
||||||
second_sentences,
|
second_sentences,
|
||||||
truncation=True,
|
truncation=True,
|
||||||
max_length=data_args.max_seq_length,
|
max_length=max_seq_length,
|
||||||
padding="max_length" if data_args.pad_to_max_length else False,
|
padding="max_length" if data_args.pad_to_max_length else False,
|
||||||
)
|
)
|
||||||
# Un-flatten
|
# Un-flatten
|
||||||
|
|||||||
@@ -277,6 +277,13 @@ def main():
|
|||||||
# Padding side determines if we do (question|context) or (context|question).
|
# Padding side determines if we do (question|context) or (context|question).
|
||||||
pad_on_right = tokenizer.padding_side == "right"
|
pad_on_right = tokenizer.padding_side == "right"
|
||||||
|
|
||||||
|
if data_args.max_seq_length > tokenizer.model_max_length:
|
||||||
|
logger.warn(
|
||||||
|
f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
|
||||||
|
f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
|
||||||
|
)
|
||||||
|
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
|
||||||
|
|
||||||
# Training preprocessing
|
# Training preprocessing
|
||||||
def prepare_train_features(examples):
|
def prepare_train_features(examples):
|
||||||
# Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
|
# Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
|
||||||
@@ -286,7 +293,7 @@ def main():
|
|||||||
examples[question_column_name if pad_on_right else context_column_name],
|
examples[question_column_name if pad_on_right else context_column_name],
|
||||||
examples[context_column_name if pad_on_right else question_column_name],
|
examples[context_column_name if pad_on_right else question_column_name],
|
||||||
truncation="only_second" if pad_on_right else "only_first",
|
truncation="only_second" if pad_on_right else "only_first",
|
||||||
max_length=data_args.max_seq_length,
|
max_length=max_seq_length,
|
||||||
stride=data_args.doc_stride,
|
stride=data_args.doc_stride,
|
||||||
return_overflowing_tokens=True,
|
return_overflowing_tokens=True,
|
||||||
return_offsets_mapping=True,
|
return_offsets_mapping=True,
|
||||||
@@ -368,7 +375,7 @@ def main():
|
|||||||
examples[question_column_name if pad_on_right else context_column_name],
|
examples[question_column_name if pad_on_right else context_column_name],
|
||||||
examples[context_column_name if pad_on_right else question_column_name],
|
examples[context_column_name if pad_on_right else question_column_name],
|
||||||
truncation="only_second" if pad_on_right else "only_first",
|
truncation="only_second" if pad_on_right else "only_first",
|
||||||
max_length=data_args.max_seq_length,
|
max_length=max_seq_length,
|
||||||
stride=data_args.doc_stride,
|
stride=data_args.doc_stride,
|
||||||
return_overflowing_tokens=True,
|
return_overflowing_tokens=True,
|
||||||
return_offsets_mapping=True,
|
return_offsets_mapping=True,
|
||||||
|
|||||||
@@ -267,6 +267,13 @@ def main():
|
|||||||
# Padding side determines if we do (question|context) or (context|question).
|
# Padding side determines if we do (question|context) or (context|question).
|
||||||
pad_on_right = tokenizer.padding_side == "right"
|
pad_on_right = tokenizer.padding_side == "right"
|
||||||
|
|
||||||
|
if data_args.max_seq_length > tokenizer.model_max_length:
|
||||||
|
logger.warn(
|
||||||
|
f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
|
||||||
|
f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
|
||||||
|
)
|
||||||
|
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
|
||||||
|
|
||||||
# Training preprocessing
|
# Training preprocessing
|
||||||
def prepare_train_features(examples):
|
def prepare_train_features(examples):
|
||||||
# Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
|
# Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
|
||||||
@@ -276,7 +283,7 @@ def main():
|
|||||||
examples[question_column_name if pad_on_right else context_column_name],
|
examples[question_column_name if pad_on_right else context_column_name],
|
||||||
examples[context_column_name if pad_on_right else question_column_name],
|
examples[context_column_name if pad_on_right else question_column_name],
|
||||||
truncation="only_second" if pad_on_right else "only_first",
|
truncation="only_second" if pad_on_right else "only_first",
|
||||||
max_length=data_args.max_seq_length,
|
max_length=max_seq_length,
|
||||||
stride=data_args.doc_stride,
|
stride=data_args.doc_stride,
|
||||||
return_overflowing_tokens=True,
|
return_overflowing_tokens=True,
|
||||||
return_offsets_mapping=True,
|
return_offsets_mapping=True,
|
||||||
@@ -381,7 +388,7 @@ def main():
|
|||||||
examples[question_column_name if pad_on_right else context_column_name],
|
examples[question_column_name if pad_on_right else context_column_name],
|
||||||
examples[context_column_name if pad_on_right else question_column_name],
|
examples[context_column_name if pad_on_right else question_column_name],
|
||||||
truncation="only_second" if pad_on_right else "only_first",
|
truncation="only_second" if pad_on_right else "only_first",
|
||||||
max_length=data_args.max_seq_length,
|
max_length=max_seq_length,
|
||||||
stride=data_args.doc_stride,
|
stride=data_args.doc_stride,
|
||||||
return_overflowing_tokens=True,
|
return_overflowing_tokens=True,
|
||||||
return_offsets_mapping=True,
|
return_offsets_mapping=True,
|
||||||
|
|||||||
@@ -334,12 +334,19 @@ def main():
|
|||||||
elif data_args.task_name is None and not is_regression:
|
elif data_args.task_name is None and not is_regression:
|
||||||
label_to_id = {v: i for i, v in enumerate(label_list)}
|
label_to_id = {v: i for i, v in enumerate(label_list)}
|
||||||
|
|
||||||
|
if data_args.max_seq_length > tokenizer.model_max_length:
|
||||||
|
logger.warn(
|
||||||
|
f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
|
||||||
|
f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
|
||||||
|
)
|
||||||
|
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
|
||||||
|
|
||||||
def preprocess_function(examples):
|
def preprocess_function(examples):
|
||||||
# Tokenize the texts
|
# Tokenize the texts
|
||||||
args = (
|
args = (
|
||||||
(examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
|
(examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
|
||||||
)
|
)
|
||||||
result = tokenizer(*args, padding=padding, max_length=data_args.max_seq_length, truncation=True)
|
result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True)
|
||||||
|
|
||||||
# Map labels to IDs (not necessary for GLUE tasks)
|
# Map labels to IDs (not necessary for GLUE tasks)
|
||||||
if label_to_id is not None and "label" in examples:
|
if label_to_id is not None and "label" in examples:
|
||||||
|
|||||||
Reference in New Issue
Block a user