Truncate max length if needed in all examples (#10034)

This commit is contained in:
Sylvain Gugger
2021-02-08 05:03:55 -05:00
committed by GitHub
parent 45aaf5f7ab
commit b01483faa0
6 changed files with 68 additions and 31 deletions

View File

@@ -286,6 +286,22 @@ def main():
context_name = "sent1"
question_header_name = "sent2"
# Resolve the sequence length that is later handed to the tokenizer as `max_length`.
if data_args.max_seq_length is None:
    # No explicit value was given: start from the tokenizer's own limit, but
    # cap it at 1024 when that limit is larger than 1024.
    max_seq_length = tokenizer.model_max_length
    if max_seq_length > 1024:
        logger.warning(
            f"The tokenizer picked seems to have a very large `model_max_length` ({tokenizer.model_max_length}). "
            "Picking 1024 instead. You can change that default value by passing --max_seq_length xxx."
        )
        max_seq_length = 1024
else:
    # An explicit value was given: warn if it exceeds the tokenizer's limit,
    # then clamp it to that limit.
    if data_args.max_seq_length > tokenizer.model_max_length:
        logger.warning(
            f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
            f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
        )
    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
# Preprocessing the datasets.
def preprocess_function(examples):
first_sentences = [[context] * 4 for context in examples[context_name]]
@@ -303,7 +319,7 @@ def main():
first_sentences,
second_sentences,
truncation=True,
max_length=data_args.max_seq_length,
max_length=max_seq_length,
padding="max_length" if data_args.pad_to_max_length else False,
)
# Un-flatten