Black preview (#17217)

* Black preview

* Fixup too!

* Fix check copies

* Use the same version as the CI

* Bump black
This commit is contained in:
Sylvain Gugger
2022-05-12 16:25:55 -04:00
committed by GitHub
parent 9bd67ac7bb
commit afe5d42d8d
578 changed files with 8274 additions and 3296 deletions

View File

@@ -175,14 +175,19 @@ class ModelArguments:
dtype: Optional[str] = field(
default="float32",
metadata={
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
"help": (
"Floating-point format in which the model weights should be initialized and trained. Choose one of"
" `[float32, float16, bfloat16]`."
)
},
)
use_auth_token: bool = field(
default=False,
metadata={
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
"with private models)."
"help": (
"Will use the token generated when running `transformers-cli login` (necessary to use this script "
"with private models)."
)
},
)
@@ -222,38 +227,48 @@ class DataTrainingArguments:
max_target_length: Optional[int] = field(
default=128,
metadata={
"help": "The maximum total sequence length for target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
"help": (
"The maximum total sequence length for target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
)
},
)
val_max_target_length: Optional[int] = field(
default=None,
metadata={
"help": "The maximum total sequence length for validation target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
"This argument is also used to override the `max_length` param of `model.generate`, which is used "
"during evaluation."
"help": (
"The maximum total sequence length for validation target text after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
"This argument is also used to override the `max_length` param of `model.generate`, which is used "
"during evaluation."
)
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
"help": (
"For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
)
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
"help": (
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
)
},
)
max_predict_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
"value if set."
"help": (
"For debugging purposes or quicker training, truncate the number of prediction examples to this "
"value if set."
)
},
)
preprocessing_num_workers: Optional[int] = field(
@@ -266,8 +281,10 @@ class DataTrainingArguments:
num_beams: Optional[int] = field(
default=None,
metadata={
"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`, "
"which is used during evaluation."
"help": (
"Number of beams to use for evaluation. This argument will be passed to `model.generate`, "
"which is used during evaluation."
)
},
)
overwrite_cache: bool = field(
@@ -623,7 +640,7 @@ def main():
eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
if training_args.block_size % train_batch_size > 0 or training_args.block_size % eval_batch_size > 0:
raise ValueError(
f"`training_args.block_size` needs to be a multiple of the global train/eval batch size."
"`training_args.block_size` needs to be a multiple of the global train/eval batch size."
f"Got {training_args.block_size}, {train_batch_size} and {eval_batch_size} respectively instead."
)
@@ -1136,7 +1153,7 @@ def main():
)
# train
for (batch_idx, _) in enumerate(tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False)):
for batch_idx, _ in enumerate(tqdm(range(steps_per_epoch), desc="Training...", position=1, leave=False)):
cur_step += 1
batch = next(train_batches)
@@ -1150,7 +1167,10 @@ def main():
if training_args.logging_steps > 0 and cur_step % training_args.logging_steps == 0:
_train_metric = unreplicate(train_metric)
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Step: {cur_step} | Loss: {_train_metric['loss']} | Learning Rate: {_train_metric['learning_rate']} | Time per step: {time_per_step})"
desc = (
f"Epoch... ({epoch + 1}/{num_epochs} | Step: {cur_step} | Loss: {_train_metric['loss']} |"
f" Learning Rate: {_train_metric['learning_rate']} | Time per step: {time_per_step})"
)
epochs.desc = desc
epochs.write(desc)