Truncate max length if needed in all examples (#10034)
This commit is contained in:
@@ -303,6 +303,22 @@ def main():
|
||||
column_names = datasets["validation"].column_names
|
||||
text_column_name = "text" if "text" in column_names else column_names[0]
|
||||
|
||||
if data_args.max_seq_length is None:
|
||||
max_seq_length = tokenizer.model_max_length
|
||||
if max_seq_length > 1024:
|
||||
logger.warn(
|
||||
f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
|
||||
"Picking 1024 instead. You can change that default value by passing --max_seq_length xxx."
|
||||
)
|
||||
max_seq_length = 1024
|
||||
else:
|
||||
if data_args.max_seq_length > tokenizer.model_max_length:
|
||||
logger.warn(
|
||||
f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
|
||||
f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
|
||||
)
|
||||
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
|
||||
|
||||
if data_args.line_by_line:
|
||||
# When using line_by_line, we just tokenize each nonempty line.
|
||||
padding = "max_length" if data_args.pad_to_max_length else False
|
||||
@@ -314,7 +330,7 @@ def main():
|
||||
examples["text"],
|
||||
padding=padding,
|
||||
truncation=True,
|
||||
max_length=data_args.max_seq_length,
|
||||
max_length=max_seq_length,
|
||||
# We use this option because DataCollatorForLanguageModeling (see below) is more efficient when it
|
||||
# receives the `special_tokens_mask`.
|
||||
return_special_tokens_mask=True,
|
||||
@@ -342,22 +358,6 @@ def main():
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
)
|
||||
|
||||
if data_args.max_seq_length is None:
|
||||
max_seq_length = tokenizer.model_max_length
|
||||
if max_seq_length > 1024:
|
||||
logger.warn(
|
||||
f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
|
||||
"Picking 1024 instead. You can change that default value by passing --max_seq_length xxx."
|
||||
)
|
||||
max_seq_length = 1024
|
||||
else:
|
||||
if data_args.max_seq_length > tokenizer.model_max_length:
|
||||
logger.warn(
|
||||
f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
|
||||
f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
|
||||
)
|
||||
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
|
||||
|
||||
# Main data processing function that will concatenate all texts from our dataset and generate chunks of
|
||||
# max_seq_length.
|
||||
def group_texts(examples):
|
||||
|
||||
@@ -300,6 +300,13 @@ def main():
|
||||
column_names = datasets["validation"].column_names
|
||||
text_column_name = "text" if "text" in column_names else column_names[0]
|
||||
|
||||
if data_args.max_seq_length > tokenizer.model_max_length:
|
||||
logger.warn(
|
||||
f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
|
||||
f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
|
||||
)
|
||||
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
|
||||
|
||||
if data_args.line_by_line:
|
||||
# When using line_by_line, we just tokenize each nonempty line.
|
||||
padding = "max_length" if data_args.pad_to_max_length else False
|
||||
@@ -307,7 +314,7 @@ def main():
|
||||
def tokenize_function(examples):
|
||||
# Remove empty lines
|
||||
examples["text"] = [line for line in examples["text"] if len(line) > 0 and not line.isspace()]
|
||||
return tokenizer(examples["text"], padding=padding, truncation=True, max_length=data_args.max_seq_length)
|
||||
return tokenizer(examples["text"], padding=padding, truncation=True, max_length=max_seq_length)
|
||||
|
||||
tokenized_datasets = datasets.map(
|
||||
tokenize_function,
|
||||
@@ -329,13 +336,6 @@ def main():
|
||||
load_from_cache_file=not data_args.overwrite_cache,
|
||||
)
|
||||
|
||||
if data_args.max_seq_length > tokenizer.model_max_length:
|
||||
logger.warn(
|
||||
f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
|
||||
f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
|
||||
)
|
||||
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
|
||||
|
||||
# Main data processing function that will concatenate all texts from our dataset and generate chunks of
|
||||
# max_seq_length.
|
||||
def group_texts(examples):
|
||||
|
||||
@@ -286,6 +286,22 @@ def main():
|
||||
context_name = "sent1"
|
||||
question_header_name = "sent2"
|
||||
|
||||
if data_args.max_seq_length is None:
|
||||
max_seq_length = tokenizer.model_max_length
|
||||
if max_seq_length > 1024:
|
||||
logger.warn(
|
||||
f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
|
||||
"Picking 1024 instead. You can change that default value by passing --max_seq_length xxx."
|
||||
)
|
||||
max_seq_length = 1024
|
||||
else:
|
||||
if data_args.max_seq_length > tokenizer.model_max_length:
|
||||
logger.warn(
|
||||
f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
|
||||
f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
|
||||
)
|
||||
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
|
||||
|
||||
# Preprocessing the datasets.
|
||||
def preprocess_function(examples):
|
||||
first_sentences = [[context] * 4 for context in examples[context_name]]
|
||||
@@ -303,7 +319,7 @@ def main():
|
||||
first_sentences,
|
||||
second_sentences,
|
||||
truncation=True,
|
||||
max_length=data_args.max_seq_length,
|
||||
max_length=max_seq_length,
|
||||
padding="max_length" if data_args.pad_to_max_length else False,
|
||||
)
|
||||
# Un-flatten
|
||||
|
||||
@@ -277,6 +277,13 @@ def main():
|
||||
# Padding side determines if we do (question|context) or (context|question).
|
||||
pad_on_right = tokenizer.padding_side == "right"
|
||||
|
||||
if data_args.max_seq_length > tokenizer.model_max_length:
|
||||
logger.warn(
|
||||
f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
|
||||
f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
|
||||
)
|
||||
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
|
||||
|
||||
# Training preprocessing
|
||||
def prepare_train_features(examples):
|
||||
# Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
|
||||
@@ -286,7 +293,7 @@ def main():
|
||||
examples[question_column_name if pad_on_right else context_column_name],
|
||||
examples[context_column_name if pad_on_right else question_column_name],
|
||||
truncation="only_second" if pad_on_right else "only_first",
|
||||
max_length=data_args.max_seq_length,
|
||||
max_length=max_seq_length,
|
||||
stride=data_args.doc_stride,
|
||||
return_overflowing_tokens=True,
|
||||
return_offsets_mapping=True,
|
||||
@@ -368,7 +375,7 @@ def main():
|
||||
examples[question_column_name if pad_on_right else context_column_name],
|
||||
examples[context_column_name if pad_on_right else question_column_name],
|
||||
truncation="only_second" if pad_on_right else "only_first",
|
||||
max_length=data_args.max_seq_length,
|
||||
max_length=max_seq_length,
|
||||
stride=data_args.doc_stride,
|
||||
return_overflowing_tokens=True,
|
||||
return_offsets_mapping=True,
|
||||
|
||||
@@ -267,6 +267,13 @@ def main():
|
||||
# Padding side determines if we do (question|context) or (context|question).
|
||||
pad_on_right = tokenizer.padding_side == "right"
|
||||
|
||||
if data_args.max_seq_length > tokenizer.model_max_length:
|
||||
logger.warn(
|
||||
f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
|
||||
f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
|
||||
)
|
||||
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
|
||||
|
||||
# Training preprocessing
|
||||
def prepare_train_features(examples):
|
||||
# Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
|
||||
@@ -276,7 +283,7 @@ def main():
|
||||
examples[question_column_name if pad_on_right else context_column_name],
|
||||
examples[context_column_name if pad_on_right else question_column_name],
|
||||
truncation="only_second" if pad_on_right else "only_first",
|
||||
max_length=data_args.max_seq_length,
|
||||
max_length=max_seq_length,
|
||||
stride=data_args.doc_stride,
|
||||
return_overflowing_tokens=True,
|
||||
return_offsets_mapping=True,
|
||||
@@ -381,7 +388,7 @@ def main():
|
||||
examples[question_column_name if pad_on_right else context_column_name],
|
||||
examples[context_column_name if pad_on_right else question_column_name],
|
||||
truncation="only_second" if pad_on_right else "only_first",
|
||||
max_length=data_args.max_seq_length,
|
||||
max_length=max_seq_length,
|
||||
stride=data_args.doc_stride,
|
||||
return_overflowing_tokens=True,
|
||||
return_offsets_mapping=True,
|
||||
|
||||
@@ -334,12 +334,19 @@ def main():
|
||||
elif data_args.task_name is None and not is_regression:
|
||||
label_to_id = {v: i for i, v in enumerate(label_list)}
|
||||
|
||||
if data_args.max_seq_length > tokenizer.model_max_length:
|
||||
logger.warn(
|
||||
f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
|
||||
f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
|
||||
)
|
||||
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
|
||||
|
||||
def preprocess_function(examples):
|
||||
# Tokenize the texts
|
||||
args = (
|
||||
(examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
|
||||
)
|
||||
result = tokenizer(*args, padding=padding, max_length=data_args.max_seq_length, truncation=True)
|
||||
result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True)
|
||||
|
||||
# Map labels to IDs (not necessary for GLUE tasks)
|
||||
if label_to_id is not None and "label" in examples:
|
||||
|
||||
Reference in New Issue
Block a user