Black preview (#17217)
* Black preview * Fixup too! * Fix check copies * Use the same version as the CI * Bump black
This commit is contained in:
@@ -75,8 +75,9 @@ class ModelArguments:
|
||||
model_name_or_path: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The model checkpoint for weights initialization."
|
||||
"Don't set if you want to train a model from scratch."
|
||||
"help": (
|
||||
"The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
|
||||
)
|
||||
},
|
||||
)
|
||||
model_type: Optional[str] = field(
|
||||
@@ -99,7 +100,10 @@ class ModelArguments:
|
||||
dtype: Optional[str] = field(
|
||||
default="float32",
|
||||
metadata={
|
||||
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
|
||||
"help": (
|
||||
"Floating-point format in which the model weights should be initialized and trained. Choose one of"
|
||||
" `[float32, float16, bfloat16]`."
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@@ -141,8 +145,10 @@ class DataTrainingArguments:
|
||||
max_seq_length: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated. Default to the max input length of the model."
|
||||
"help": (
|
||||
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated. Default to the max input length of the model."
|
||||
)
|
||||
},
|
||||
)
|
||||
preprocessing_num_workers: Optional[int] = field(
|
||||
@@ -155,8 +161,10 @@ class DataTrainingArguments:
|
||||
pad_to_max_length: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether to pad all samples to `max_seq_length`. "
|
||||
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
|
||||
"help": (
|
||||
"Whether to pad all samples to `max_seq_length`. "
|
||||
"If False, will pad the samples dynamically when batching to the maximum length in the batch."
|
||||
)
|
||||
},
|
||||
)
|
||||
line_by_line: bool = field(
|
||||
@@ -575,7 +583,8 @@ if __name__ == "__main__":
|
||||
|
||||
if step % training_args.logging_steps == 0 and step > 0:
|
||||
steps.write(
|
||||
f"Step... ({step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
|
||||
f"Step... ({step} | Loss: {train_metric['loss'].mean()}, Learning Rate:"
|
||||
f" {train_metric['learning_rate'].mean()})"
|
||||
)
|
||||
train_time += time.time() - train_start
|
||||
if has_tensorboard and jax.process_index() == 0:
|
||||
@@ -604,7 +613,10 @@ if __name__ == "__main__":
|
||||
eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
|
||||
|
||||
# Update progress bar
|
||||
steps.desc = f"Step... ({step + 1}/{num_train_steps} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
|
||||
steps.desc = (
|
||||
f"Step... ({step + 1}/{num_train_steps} | Loss: {eval_metrics['loss']}, Acc:"
|
||||
f" {eval_metrics['accuracy']})"
|
||||
)
|
||||
|
||||
if has_tensorboard and jax.process_index() == 0:
|
||||
write_eval_metric(summary_writer, eval_metrics, step)
|
||||
|
||||
@@ -77,14 +77,18 @@ class ModelArguments:
|
||||
|
||||
text_model_name_or_path: str = field(
|
||||
metadata={
|
||||
"help": "The text model checkpoint for weights initialization."
|
||||
"Don't set if you want to train a model from scratch."
|
||||
"help": (
|
||||
"The text model checkpoint for weights initialization."
|
||||
"Don't set if you want to train a model from scratch."
|
||||
)
|
||||
},
|
||||
)
|
||||
vision_model_name_or_path: str = field(
|
||||
metadata={
|
||||
"help": "The vision model checkpoint for weights initialization."
|
||||
"Don't set if you want to train a model from scratch."
|
||||
"help": (
|
||||
"The vision model checkpoint for weights initialization."
|
||||
"Don't set if you want to train a model from scratch."
|
||||
)
|
||||
},
|
||||
)
|
||||
from_pt: bool = field(
|
||||
@@ -107,7 +111,10 @@ class ModelArguments:
|
||||
dtype: Optional[str] = field(
|
||||
default="float32",
|
||||
metadata={
|
||||
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
|
||||
"help": (
|
||||
"Floating-point format in which the model weights should be initialized and trained. Choose one of"
|
||||
" `[float32, float16, bfloat16]`."
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@@ -129,22 +136,28 @@ class DataTrainingArguments:
|
||||
max_seq_length: Optional[int] = field(
|
||||
default=72,
|
||||
metadata={
|
||||
"help": "The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
"help": (
|
||||
"The maximum total input sequence length after tokenization. Sequences longer "
|
||||
"than this will be truncated, sequences shorter will be padded."
|
||||
)
|
||||
},
|
||||
)
|
||||
max_train_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||
"value if set."
|
||||
"help": (
|
||||
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||
"value if set."
|
||||
)
|
||||
},
|
||||
)
|
||||
max_eval_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||
"value if set."
|
||||
"help": (
|
||||
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||
"value if set."
|
||||
)
|
||||
},
|
||||
)
|
||||
overwrite_cache: bool = field(
|
||||
@@ -519,7 +532,8 @@ def main():
|
||||
|
||||
train_step_progress_bar.close()
|
||||
epochs.write(
|
||||
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
|
||||
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate:"
|
||||
f" {train_metric['learning_rate']})"
|
||||
)
|
||||
|
||||
# ======================== Evaluating ==============================
|
||||
|
||||
@@ -69,8 +69,9 @@ class ModelArguments:
|
||||
model_name_or_path: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "The model checkpoint for weights initialization."
|
||||
"Don't set if you want to train a model from scratch."
|
||||
"help": (
|
||||
"The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
|
||||
)
|
||||
},
|
||||
)
|
||||
model_type: Optional[str] = field(
|
||||
@@ -93,7 +94,10 @@ class ModelArguments:
|
||||
dtype: Optional[str] = field(
|
||||
default="float32",
|
||||
metadata={
|
||||
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
|
||||
"help": (
|
||||
"Floating-point format in which the model weights should be initialized and trained. Choose one of"
|
||||
" `[float32, float16, bfloat16]`."
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@@ -118,15 +122,19 @@ class DataTrainingArguments:
|
||||
max_train_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||
"value if set."
|
||||
"help": (
|
||||
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||
"value if set."
|
||||
)
|
||||
},
|
||||
)
|
||||
max_eval_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||
"value if set."
|
||||
"help": (
|
||||
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
||||
"value if set."
|
||||
)
|
||||
},
|
||||
)
|
||||
overwrite_cache: bool = field(
|
||||
@@ -141,9 +149,11 @@ class DataTrainingArguments:
|
||||
block_size: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Optional input sequence length after tokenization. "
|
||||
"The training dataset will be truncated in block of this size for training. "
|
||||
"Default to the model max input length for single sentence inputs (take into account special tokens)."
|
||||
"help": (
|
||||
"Optional input sequence length after tokenization. "
|
||||
"The training dataset will be truncated in block of this size for training. "
|
||||
"Default to the model max input length for single sentence inputs (take into account special tokens)."
|
||||
)
|
||||
},
|
||||
)
|
||||
overwrite_cache: bool = field(
|
||||
@@ -334,7 +344,8 @@ def main():
|
||||
# clm input could be much much longer than block_size
|
||||
if "Token indices sequence length is longer than the" in cl.out:
|
||||
tok_logger.warning(
|
||||
"^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model."
|
||||
"^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
|
||||
" before being passed to the model."
|
||||
)
|
||||
return output
|
||||
|
||||
@@ -606,7 +617,8 @@ def main():
|
||||
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
|
||||
|
||||
epochs.write(
|
||||
f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
|
||||
f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate:"
|
||||
f" {train_metric['learning_rate']})"
|
||||
)
|
||||
|
||||
train_metrics = []
|
||||
@@ -632,7 +644,8 @@ def main():
|
||||
eval_metrics["perplexity"] = float("inf")
|
||||
|
||||
logger.info(
|
||||
f"Step... ({cur_step} | Eval loss: {eval_metrics['loss']} | Eval Perplexity: {eval_metrics['perplexity']}"
|
||||
f"Step... ({cur_step} | Eval loss: {eval_metrics['loss']} | Eval Perplexity:"
|
||||
f" {eval_metrics['perplexity']}"
|
||||
)
|
||||
|
||||
if cur_step % training_args.save_steps == 0 and cur_step > 0:
|
||||
|
||||
@@ -64,7 +64,10 @@ class ModelArguments:
|
||||
dtype: Optional[str] = field(
|
||||
default="float32",
|
||||
metadata={
|
||||
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
|
||||
"help": (
|
||||
"Floating-point format in which the model weights should be initialized and trained. Choose one of"
|
||||
" `[float32, float16, bfloat16]`."
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@@ -94,7 +97,9 @@ class DataTrainingArguments:
|
||||
validation_split_name: Optional[str] = field(
|
||||
default="validation",
|
||||
metadata={
|
||||
"help": "The name of the validation data set split to use (via the datasets library). Defaults to 'validation'"
|
||||
"help": (
|
||||
"The name of the validation data set split to use (via the datasets library). Defaults to 'validation'"
|
||||
)
|
||||
},
|
||||
)
|
||||
speech_file_column: Optional[str] = field(
|
||||
@@ -120,7 +125,10 @@ class DataTrainingArguments:
|
||||
pad_to_multiple_of: Optional[int] = field(
|
||||
default=1024,
|
||||
metadata={
|
||||
"help": "If set will pad the sequence to a multiple of the provided value. This is important to avoid triggering recompilations on TPU"
|
||||
"help": (
|
||||
"If set will pad the sequence to a multiple of the provided value. This is important to avoid"
|
||||
" triggering recompilations on TPU"
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@@ -357,7 +365,8 @@ def main():
|
||||
|
||||
if not config.do_stable_layer_norm or config.feat_extract_norm != "layer":
|
||||
raise ValueError(
|
||||
"PreTraining is only supported for ``config.do_stable_layer_norm=True`` and ``config.feat_extract_norm='layer'"
|
||||
"PreTraining is only supported for ``config.do_stable_layer_norm=True`` and"
|
||||
" ``config.feat_extract_norm='layer'"
|
||||
)
|
||||
|
||||
model = FlaxWav2Vec2ForPreTraining(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
|
||||
@@ -557,7 +566,8 @@ def main():
|
||||
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
|
||||
|
||||
epochs.write(
|
||||
f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
|
||||
f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate:"
|
||||
f" {train_metric['learning_rate'].mean()})"
|
||||
)
|
||||
|
||||
train_metrics = []
|
||||
@@ -583,7 +593,8 @@ def main():
|
||||
|
||||
# Update progress bar
|
||||
epochs.write(
|
||||
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {eval_metrics['loss']}, Perplexity: {eval_metrics['codevector_perplexity']})"
|
||||
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {eval_metrics['loss']}, Perplexity:"
|
||||
f" {eval_metrics['codevector_perplexity']})"
|
||||
)
|
||||
|
||||
# Save metrics
|
||||
|
||||
Reference in New Issue
Block a user