Black preview (#17217)

* Black preview

* Fixup too!

* Fix check copies

* Use the same version as the CI

* Bump black
This commit is contained in:
Sylvain Gugger
2022-05-12 16:25:55 -04:00
committed by GitHub
parent 9bd67ac7bb
commit afe5d42d8d
578 changed files with 8274 additions and 3296 deletions

View File

@@ -75,8 +75,9 @@ class ModelArguments:
model_name_or_path: Optional[str] = field(
default=None,
metadata={
"help": "The model checkpoint for weights initialization."
"Don't set if you want to train a model from scratch."
"help": (
"The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
)
},
)
model_type: Optional[str] = field(
@@ -99,7 +100,10 @@ class ModelArguments:
dtype: Optional[str] = field(
default="float32",
metadata={
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
"help": (
"Floating-point format in which the model weights should be initialized and trained. Choose one of"
" `[float32, float16, bfloat16]`."
)
},
)
@@ -141,8 +145,10 @@ class DataTrainingArguments:
max_seq_length: Optional[int] = field(
default=None,
metadata={
"help": "The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated. Default to the max input length of the model."
"help": (
"The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated. Default to the max input length of the model."
)
},
)
preprocessing_num_workers: Optional[int] = field(
@@ -155,8 +161,10 @@ class DataTrainingArguments:
pad_to_max_length: bool = field(
default=False,
metadata={
"help": "Whether to pad all samples to `max_seq_length`. "
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
"help": (
"Whether to pad all samples to `max_seq_length`. "
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
)
},
)
line_by_line: bool = field(
@@ -575,7 +583,8 @@ if __name__ == "__main__":
if step % training_args.logging_steps == 0 and step > 0:
steps.write(
f"Step... ({step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
f"Step... ({step} | Loss: {train_metric['loss'].mean()}, Learning Rate:"
f" {train_metric['learning_rate'].mean()})"
)
train_time += time.time() - train_start
if has_tensorboard and jax.process_index() == 0:
@@ -604,7 +613,10 @@ if __name__ == "__main__":
eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
# Update progress bar
steps.desc = f"Step... ({step + 1}/{num_train_steps} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
steps.desc = (
f"Step... ({step + 1}/{num_train_steps} | Loss: {eval_metrics['loss']}, Acc:"
f" {eval_metrics['accuracy']})"
)
if has_tensorboard and jax.process_index() == 0:
write_eval_metric(summary_writer, eval_metrics, step)

View File

@@ -77,14 +77,18 @@ class ModelArguments:
text_model_name_or_path: str = field(
metadata={
"help": "The text model checkpoint for weights initialization."
"Don't set if you want to train a model from scratch."
"help": (
"The text model checkpoint for weights initialization."
"Don't set if you want to train a model from scratch."
)
},
)
vision_model_name_or_path: str = field(
metadata={
"help": "The vision model checkpoint for weights initialization."
"Don't set if you want to train a model from scratch."
"help": (
"The vision model checkpoint for weights initialization."
"Don't set if you want to train a model from scratch."
)
},
)
from_pt: bool = field(
@@ -107,7 +111,10 @@ class ModelArguments:
dtype: Optional[str] = field(
default="float32",
metadata={
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
"help": (
"Floating-point format in which the model weights should be initialized and trained. Choose one of"
" `[float32, float16, bfloat16]`."
)
},
)
@@ -129,22 +136,28 @@ class DataTrainingArguments:
max_seq_length: Optional[int] = field(
default=72,
metadata={
"help": "The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
"help": (
"The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
)
},
)
max_train_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
"help": (
"For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
)
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
"help": (
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
)
},
)
overwrite_cache: bool = field(
@@ -519,7 +532,8 @@ def main():
train_step_progress_bar.close()
epochs.write(
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate:"
f" {train_metric['learning_rate']})"
)
# ======================== Evaluating ==============================

View File

@@ -69,8 +69,9 @@ class ModelArguments:
model_name_or_path: Optional[str] = field(
default=None,
metadata={
"help": "The model checkpoint for weights initialization."
"Don't set if you want to train a model from scratch."
"help": (
"The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
)
},
)
model_type: Optional[str] = field(
@@ -93,7 +94,10 @@ class ModelArguments:
dtype: Optional[str] = field(
default="float32",
metadata={
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
"help": (
"Floating-point format in which the model weights should be initialized and trained. Choose one of"
" `[float32, float16, bfloat16]`."
)
},
)
@@ -118,15 +122,19 @@ class DataTrainingArguments:
max_train_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
"help": (
"For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
)
},
)
max_eval_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
"help": (
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
"value if set."
)
},
)
overwrite_cache: bool = field(
@@ -141,9 +149,11 @@ class DataTrainingArguments:
block_size: Optional[int] = field(
default=None,
metadata={
"help": "Optional input sequence length after tokenization. "
"The training dataset will be truncated in block of this size for training. "
"Default to the model max input length for single sentence inputs (take into account special tokens)."
"help": (
"Optional input sequence length after tokenization. "
"The training dataset will be truncated in block of this size for training. "
"Default to the model max input length for single sentence inputs (take into account special tokens)."
)
},
)
overwrite_cache: bool = field(
@@ -334,7 +344,8 @@ def main():
# clm input could be much much longer than block_size
if "Token indices sequence length is longer than the" in cl.out:
tok_logger.warning(
"^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model."
"^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
" before being passed to the model."
)
return output
@@ -606,7 +617,8 @@ def main():
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
epochs.write(
f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate:"
f" {train_metric['learning_rate']})"
)
train_metrics = []
@@ -632,7 +644,8 @@ def main():
eval_metrics["perplexity"] = float("inf")
logger.info(
f"Step... ({cur_step} | Eval loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']}"
f"Step... ({cur_step} | Eval loss: {eval_metrics['loss']} | Eval Perplexity:"
f" {eval_metrics['perplexity']}"
)
if cur_step % training_args.save_steps == 0 and cur_step > 0:

View File

@@ -64,7 +64,10 @@ class ModelArguments:
dtype: Optional[str] = field(
default="float32",
metadata={
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
"help": (
"Floating-point format in which the model weights should be initialized and trained. Choose one of"
" `[float32, float16, bfloat16]`."
)
},
)
@@ -94,7 +97,9 @@ class DataTrainingArguments:
validation_split_name: Optional[str] = field(
default="validation",
metadata={
"help": "The name of the validation data set split to use (via the datasets library). Defaults to 'validation'"
"help": (
"The name of the validation data set split to use (via the datasets library). Defaults to 'validation'"
)
},
)
speech_file_column: Optional[str] = field(
@@ -120,7 +125,10 @@ class DataTrainingArguments:
pad_to_multiple_of: Optional[int] = field(
default=1024,
metadata={
"help": "If set will pad the sequence to a multiple of the provided value. This is important to avoid triggering recompilations on TPU"
"help": (
"If set will pad the sequence to a multiple of the provided value. This is important to avoid"
" triggering recompilations on TPU"
)
},
)
@@ -357,7 +365,8 @@ def main():
if not config.do_stable_layer_norm or config.feat_extract_norm != "layer":
raise ValueError(
"PreTraining is only supported for ``config.do_stable_layer_norm=True`` and ``config.feat_extract_norm='layer'"
"PreTraining is only supported for ``config.do_stable_layer_norm=True`` and"
" ``config.feat_extract_norm='layer'"
)
model = FlaxWav2Vec2ForPreTraining(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
@@ -557,7 +566,8 @@ def main():
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
epochs.write(
f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate:"
f" {train_metric['learning_rate'].mean()})"
)
train_metrics = []
@@ -583,7 +593,8 @@ def main():
# Update progress bar
epochs.write(
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {eval_metrics['loss']}, Perplexity: {eval_metrics['codevector_perplexity']})"
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {eval_metrics['loss']}, Perplexity:"
f" {eval_metrics['codevector_perplexity']})"
)
# Save metrics