Revert "Gradient accumulation for TFTrainer (#9585)"

This reverts commit 3f40070c88.
This commit is contained in:
Kiyoung Kim
2021-01-15 00:16:39 +09:00
committed by Lysandre
parent 82498cbc37
commit 85788bae5c

View File

@@ -638,9 +638,7 @@ class TFTrainer:
reduced_features = { reduced_features = {
k: ft[: self.args.train_batch_size // self.args.n_replicas] for k, ft in features.items() k: ft[: self.args.train_batch_size // self.args.n_replicas] for k, ft in features.items()
} }
reduced_labels = { reduced_labels = labels[: self.args.train_batch_size // self.args.n_replicas]
k: lbl[: self.args.train_batch_size // self.args.n_replicas] for k, lbl in labels.items()
}
self.training_step(reduced_features, reduced_labels, nb_instances_in_global_batch) self.training_step(reduced_features, reduced_labels, nb_instances_in_global_batch)
@@ -652,13 +650,9 @@ class TFTrainer:
for k, ft in features.items() for k, ft in features.items()
} }
labels = { labels = tf.concat(
k: tf.concat( [labels[self.args.train_batch_size // self.args.n_replicas :], reduced_labels], axis=0
[lbl[self.args.train_batch_size // self.args.n_replicas :], reduced_labels[k]],
axis=0,
) )
for k, lbl in labels.items()
}
gradients = self.gradient_accumulator.gradients gradients = self.gradient_accumulator.gradients
gradients = [ gradients = [