Truncate max length if needed in all examples (#10034)

This commit is contained in:
Sylvain Gugger
2021-02-08 05:03:55 -05:00
committed by GitHub
parent 45aaf5f7ab
commit b01483faa0
6 changed files with 68 additions and 31 deletions

View File

@@ -334,12 +334,19 @@ def main():
elif data_args.task_name is None and not is_regression:
label_to_id = {v: i for i, v in enumerate(label_list)}
if data_args.max_seq_length > tokenizer.model_max_length:
logger.warn(
f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
)
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
def preprocess_function(examples):
# Tokenize the texts
args = (
(examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
)
result = tokenizer(*args, padding=padding, max_length=data_args.max_seq_length, truncation=True)
result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True)
# Map labels to IDs (not necessary for GLUE tasks)
if label_to_id is not None and "label" in examples: