Fix max length in run_plm script (#8738)
This commit is contained in:
@@ -93,11 +93,11 @@ class DataTrainingArguments:
|
|||||||
overwrite_cache: bool = field(
|
overwrite_cache: bool = field(
|
||||||
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
||||||
)
|
)
|
||||||
max_seq_length: Optional[int] = field(
|
max_seq_length: int = field(
|
||||||
default=None,
|
default=512,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
||||||
"than this will be truncated. Default to the max input length of the model."
|
"than this will be truncated."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
preprocessing_num_workers: Optional[int] = field(
|
preprocessing_num_workers: Optional[int] = field(
|
||||||
@@ -286,15 +286,12 @@ def main():
|
|||||||
load_from_cache_file=not data_args.overwrite_cache,
|
load_from_cache_file=not data_args.overwrite_cache,
|
||||||
)
|
)
|
||||||
|
|
||||||
if data_args.max_seq_length is None:
|
if data_args.max_seq_length > tokenizer.model_max_length:
|
||||||
max_seq_length = tokenizer.model_max_length
|
logger.warn(
|
||||||
else:
|
f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
|
||||||
if data_args.max_seq_length > tokenizer.model_max_length:
|
f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
|
||||||
logger.warn(
|
)
|
||||||
f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
|
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
|
||||||
f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
|
|
||||||
)
|
|
||||||
max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
|
|
||||||
|
|
||||||
# Main data processing function that will concatenate all texts from our dataset and generate chunks of
|
# Main data processing function that will concatenate all texts from our dataset and generate chunks of
|
||||||
# max_seq_length.
|
# max_seq_length.
|
||||||
|
|||||||
Reference in New Issue
Block a user