Truncate max length if needed in all examples (#10034)
This commit is contained in:
@@ -334,12 +334,19 @@ def main():
|
||||
elif data_args.task_name is None and not is_regression:
|
||||
label_to_id = {v: i for i, v in enumerate(label_list)}
|
||||
|
||||
# Clamp the requested sequence length to the tokenizer's model limit, warning
# when the requested value exceeds that limit.
if data_args.max_seq_length > tokenizer.model_max_length:
    # `logger.warning` replaces the deprecated `logger.warn` alias; the trailing
    # space after "for the" keeps the two concatenated pieces from running together.
    logger.warning(
        f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
        f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
    )
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
|
||||
|
||||
def preprocess_function(examples):
|
||||
# Tokenize the texts
|
||||
args = (
|
||||
(examples[sentence1_key],) if sentence2_key is None else (examples[sentence1_key], examples[sentence2_key])
|
||||
)
|
||||
result = tokenizer(*args, padding=padding, max_length=data_args.max_seq_length, truncation=True)
|
||||
result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True)
|
||||
|
||||
# Map labels to IDs (not necessary for GLUE tasks)
|
||||
if label_to_id is not None and "label" in examples:
|
||||
|
||||
Reference in New Issue
Block a user