Black preview (#17217)
* Black preview
* Fixup too!
* Fix check copies
* Use the same version as the CI
* Bump black
This commit is contained in:
@@ -91,8 +91,10 @@ class ModelArguments:
|
||||
use_auth_token: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||
"with private models)."
|
||||
"help": (
|
||||
"Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
||||
"with private models)."
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@@ -116,15 +118,12 @@ class DataTrainingArguments:
|
||||
validation_file: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "An optional input evaluation data file to evaluate the metrics (sacreblue) on "
|
||||
"a jsonlines file."
|
||||
"help": "An optional input evaluation data file to evaluate the metrics (sacrebleu) on a jsonlines file."
|
||||
},
|
||||
)
|
||||
test_file: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "An optional input test data file to evaluate the metrics (sacreblue) on " "a jsonlines file."
|
||||
},
|
||||
metadata={"help": "An optional input test data file to evaluate the metrics (sacrebleu) on a jsonlines file."},
|
||||
)
|
||||
overwrite_cache: bool = field(
|
||||
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
||||
@@ -136,60 +135,76 @@ class DataTrainingArguments:
|
||||
max_source_length: Optional[int] = field(
|
||||
default=1024,
|
||||
metadata={
|
||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
"help": (
|
||||
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
)
|
||||
},
|
||||
)
|
||||
max_target_length: Optional[int] = field(
|
||||
default=128,
|
||||
metadata={
|
||||
"help": "The maximum total sequence length for target text after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
"help": (
|
||||
"The maximum total sequence length for target text after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
)
|
||||
},
|
||||
)
|
||||
val_max_target_length: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
|
||||
"This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
|
||||
"during ``evaluate`` and ``predict``."
|
||||
"help": (
|
||||
"The maximum total sequence length for validation target text after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`. "
|
||||
"This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
|
||||
"during ``evaluate`` and ``predict``."
|
||||
)
|
||||
},
|
||||
)
|
||||
pad_to_max_length: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether to pad all samples to model maximum sentence length. "
|
||||
"If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
|
||||
"efficient on GPU but very bad for TPU."
|
||||
"help": (
|
||||
"Whether to pad all samples to model maximum sentence length. "
|
||||
"If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
|
||||
"efficient on GPU but very bad for TPU."
|
||||
)
|
||||
},
|
||||
)
|
||||
max_train_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||
"value if set."
|
||||
"help": (
|
||||
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||
"value if set."
|
||||
)
|
||||
},
|
||||
)
|
||||
max_eval_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||
"value if set."
|
||||
"help": (
|
||||
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||
"value if set."
|
||||
)
|
||||
},
|
||||
)
|
||||
max_predict_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||
"value if set."
|
||||
"help": (
|
||||
"For debugging purposes or quicker training, truncate the number of prediction examples to this "
|
||||
"value if set."
|
||||
)
|
||||
},
|
||||
)
|
||||
num_beams: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
|
||||
"which is used during ``evaluate`` and ``predict``."
|
||||
"help": (
|
||||
"Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
|
||||
"which is used during ``evaluate`` and ``predict``."
|
||||
)
|
||||
},
|
||||
)
|
||||
ignore_pad_token_for_loss: bool = field(
|
||||
@@ -204,9 +219,11 @@ class DataTrainingArguments:
|
||||
forced_bos_token: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The token to force as the first generated token after the :obj:`decoder_start_token_id`."
|
||||
"Useful for multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token "
|
||||
"needs to be the target language token.(Usually it is the target language token)"
|
||||
"help": (
|
||||
"The token to force as the first generated token after the :obj:`decoder_start_token_id`. Useful for"
|
||||
" multilingual models like :doc:`mBART <../model_doc/mbart>` where the first generated token needs to"
|
||||
" be the target language token. (Usually it is the target language token)"
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -95,41 +95,51 @@ def parse_args():
|
||||
"--num_beams",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Number of beams to use for evaluation. This argument will be "
|
||||
"passed to ``model.generate``, which is used during ``evaluate`` and ``predict``.",
|
||||
help=(
|
||||
"Number of beams to use for evaluation. This argument will be "
|
||||
"passed to ``model.generate``, which is used during ``evaluate`` and ``predict``."
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max_source_length",
|
||||
type=int,
|
||||
default=1024,
|
||||
help="The maximum total input sequence length after "
|
||||
"tokenization.Sequences longer than this will be truncated, sequences shorter will be padded.",
|
||||
help=(
|
||||
"The maximum total input sequence length after "
|
||||
"tokenization. Sequences longer than this will be truncated, sequences shorter will be padded."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_target_length",
|
||||
type=int,
|
||||
default=128,
|
||||
help="The maximum total sequence length for target text after "
|
||||
"tokenization. Sequences longer than this will be truncated, sequences shorter will be padded."
|
||||
"during ``evaluate`` and ``predict``.",
|
||||
help=(
|
||||
"The maximum total sequence length for target text after "
|
||||
"tokenization. Sequences longer than this will be truncated, sequences shorter will be padded."
|
||||
"during ``evaluate`` and ``predict``."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--val_max_target_length",
|
||||
type=int,
|
||||
default=None,
|
||||
help="The maximum total sequence length for validation "
|
||||
"target text after tokenization.Sequences longer than this will be truncated, sequences shorter will be "
|
||||
"padded. Will default to `max_target_length`.This argument is also used to override the ``max_length`` "
|
||||
"param of ``model.generate``, which is used during ``evaluate`` and ``predict``.",
|
||||
help=(
|
||||
"The maximum total sequence length for validation "
|
||||
"target text after tokenization. Sequences longer than this will be truncated, sequences shorter will be "
|
||||
"padded. Will default to `max_target_length`. This argument is also used to override the ``max_length`` "
|
||||
"param of ``model.generate``, which is used during ``evaluate`` and ``predict``."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pad_to_max_length",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="Whether to pad all samples to model maximum sentence "
|
||||
"length. If False, will pad the samples dynamically when batching to the maximum length in the batch. More"
|
||||
"efficient on GPU but very bad for TPU.",
|
||||
help=(
|
||||
"Whether to pad all samples to model maximum sentence "
|
||||
"length. If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
|
||||
"efficient on GPU but very bad for TPU."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_file", type=str, default=None, help="A csv or a json file containing the validation data."
|
||||
@@ -138,7 +148,7 @@ def parse_args():
|
||||
"--ignore_pad_token_for_loss",
|
||||
type=bool,
|
||||
default=True,
|
||||
help="Whether to ignore the tokens corresponding to " "padded labels in the loss computation or not.",
|
||||
help="Whether to ignore the tokens corresponding to padded labels in the loss computation or not.",
|
||||
)
|
||||
parser.add_argument("--source_lang", type=str, default=None, help="Source language id for translation.")
|
||||
parser.add_argument("--target_lang", type=str, default=None, help="Target language id for translation.")
|
||||
@@ -146,7 +156,7 @@ def parse_args():
|
||||
"--source_prefix",
|
||||
type=str,
|
||||
default=None,
|
||||
help="A prefix to add before every source text " "(useful for T5 models).",
|
||||
help="A prefix to add before every source text (useful for T5 models).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--preprocessing_num_workers",
|
||||
|
||||
Reference in New Issue
Block a user