[Wav2Vec2] Flax - Adapt wav2vec2 script (#12520)
* fix_torch_device_generate_test * remove @ * adapt flax pretrain script
This commit is contained in:
committed by
GitHub
parent
4605b2b8ec
commit
7d6285a921
@@ -64,6 +64,12 @@ class ModelArguments:
|
|||||||
gumbel_temperature_decay: Optional[float] = field(
|
gumbel_temperature_decay: Optional[float] = field(
|
||||||
default=0.999995, metadata={"help": "Decay of gumbel temperature during training."}
|
default=0.999995, metadata={"help": "Decay of gumbel temperature during training."}
|
||||||
)
|
)
|
||||||
|
dtype: Optional[str] = field(
|
||||||
|
default="float32",
|
||||||
|
metadata={
|
||||||
|
"help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@flax.struct.dataclass
|
@flax.struct.dataclass
|
||||||
@@ -197,7 +203,7 @@ def configure_logger(model_args: ModelArguments, training_args: TrainingArgument
|
|||||||
logger.setLevel(logging_level)
|
logger.setLevel(logging_level)
|
||||||
|
|
||||||
|
|
||||||
def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
|
def write_train_metric(summary_writer, train_metrics, train_time, step):
|
||||||
summary_writer.scalar("train_time", train_time, step)
|
summary_writer.scalar("train_time", train_time, step)
|
||||||
|
|
||||||
train_metrics = get_metrics(train_metrics)
|
train_metrics = get_metrics(train_metrics)
|
||||||
@@ -206,6 +212,8 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
|
|||||||
for i, val in enumerate(vals):
|
for i, val in enumerate(vals):
|
||||||
summary_writer.scalar(tag, val, step - len(vals) + i + 1)
|
summary_writer.scalar(tag, val, step - len(vals) + i + 1)
|
||||||
|
|
||||||
|
|
||||||
|
def write_eval_metric(summary_writer, eval_metrics, step):
|
||||||
for metric_name, value in eval_metrics.items():
|
for metric_name, value in eval_metrics.items():
|
||||||
summary_writer.scalar(f"eval_{metric_name}", value, step)
|
summary_writer.scalar(f"eval_{metric_name}", value, step)
|
||||||
|
|
||||||
@@ -342,9 +350,7 @@ def main():
|
|||||||
"PreTraining is only supported for ``config.do_stable_layer_norm=True`` and ``config.feat_extract_norm='layer'"
|
"PreTraining is only supported for ``config.do_stable_layer_norm=True`` and ``config.feat_extract_norm='layer'"
|
||||||
)
|
)
|
||||||
|
|
||||||
model = FlaxWav2Vec2ForPreTraining(
|
model = FlaxWav2Vec2ForPreTraining(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
|
||||||
config, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
|
|
||||||
)
|
|
||||||
|
|
||||||
data_collator = FlaxDataCollatorForWav2Vec2Pretraining(
|
data_collator = FlaxDataCollatorForWav2Vec2Pretraining(
|
||||||
model=model, feature_extractor=feature_extractor, pad_to_multiple_of=data_args.pad_to_multiple_of
|
model=model, feature_extractor=feature_extractor, pad_to_multiple_of=data_args.pad_to_multiple_of
|
||||||
@@ -501,11 +507,11 @@ def main():
|
|||||||
state = jax_utils.replicate(state)
|
state = jax_utils.replicate(state)
|
||||||
|
|
||||||
train_time = 0
|
train_time = 0
|
||||||
|
train_metrics = []
|
||||||
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|
||||||
for epoch in epochs:
|
for epoch in epochs:
|
||||||
# ======================== Training ================================
|
# ======================== Training ================================
|
||||||
train_start = time.time()
|
train_start = time.time()
|
||||||
train_metrics = []
|
|
||||||
|
|
||||||
# Create sampling rng
|
# Create sampling rng
|
||||||
rng, input_rng = jax.random.split(rng)
|
rng, input_rng = jax.random.split(rng)
|
||||||
@@ -516,7 +522,7 @@ def main():
|
|||||||
train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
|
train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
|
||||||
|
|
||||||
# Gather the indexes for creating the batch and do a training step
|
# Gather the indexes for creating the batch and do a training step
|
||||||
for i, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
|
for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
|
||||||
samples = [vectorized_datasets["train"][int(idx)] for idx in batch_idx]
|
samples = [vectorized_datasets["train"][int(idx)] for idx in batch_idx]
|
||||||
model_inputs = data_collator(samples)
|
model_inputs = data_collator(samples)
|
||||||
model_inputs = shard(model_inputs.data)
|
model_inputs = shard(model_inputs.data)
|
||||||
@@ -527,12 +533,21 @@ def main():
|
|||||||
)
|
)
|
||||||
train_metrics.append(train_metric)
|
train_metrics.append(train_metric)
|
||||||
|
|
||||||
|
cur_step = epoch * (num_train_samples // train_batch_size) + step
|
||||||
|
|
||||||
|
if cur_step % training_args.logging_steps == 0 and cur_step > 0:
|
||||||
|
# Save metrics
|
||||||
|
train_metric = jax_utils.unreplicate(train_metric)
|
||||||
train_time += time.time() - train_start
|
train_time += time.time() - train_start
|
||||||
|
if has_tensorboard and jax.process_index() == 0:
|
||||||
|
write_train_metric(summary_writer, train_metrics, train_time, cur_step)
|
||||||
|
|
||||||
epochs.write(
|
epochs.write(
|
||||||
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
|
f"Step... ({cur_step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
train_metrics = []
|
||||||
|
|
||||||
# ======================== Evaluating ==============================
|
# ======================== Evaluating ==============================
|
||||||
num_eval_samples = len(vectorized_datasets["validation"])
|
num_eval_samples = len(vectorized_datasets["validation"])
|
||||||
eval_samples_idx = jnp.arange(num_eval_samples)
|
eval_samples_idx = jnp.arange(num_eval_samples)
|
||||||
@@ -560,7 +575,7 @@ def main():
|
|||||||
# Save metrics
|
# Save metrics
|
||||||
if has_tensorboard and jax.process_index() == 0:
|
if has_tensorboard and jax.process_index() == 0:
|
||||||
cur_step = epoch * (len(vectorized_datasets["train"]) // train_batch_size)
|
cur_step = epoch * (len(vectorized_datasets["train"]) // train_batch_size)
|
||||||
write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
|
write_eval_metric(summary_writer, eval_metrics, cur_step)
|
||||||
|
|
||||||
# save checkpoint after each epoch and push checkpoint to the hub
|
# save checkpoint after each epoch and push checkpoint to the hub
|
||||||
if jax.process_index() == 0:
|
if jax.process_index() == 0:
|
||||||
|
|||||||
Reference in New Issue
Block a user