Gradient accumulation for TFTrainer (#9585)
* gradient accumulation for TFTrainer

* label naming

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* label naming

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -638,7 +638,9 @@ class TFTrainer:
|
|||||||
reduced_features = {
|
reduced_features = {
|
||||||
k: ft[: self.args.train_batch_size // self.args.n_replicas] for k, ft in features.items()
|
k: ft[: self.args.train_batch_size // self.args.n_replicas] for k, ft in features.items()
|
||||||
}
|
}
|
||||||
reduced_labels = labels[: self.args.train_batch_size // self.args.n_replicas]
|
reduced_labels = {
|
||||||
|
k: lbl[: self.args.train_batch_size // self.args.n_replicas] for k, lbl in labels.items()
|
||||||
|
}
|
||||||
|
|
||||||
self.training_step(reduced_features, reduced_labels, nb_instances_in_global_batch)
|
self.training_step(reduced_features, reduced_labels, nb_instances_in_global_batch)
|
||||||
|
|
||||||
@@ -650,9 +652,13 @@ class TFTrainer:
|
|||||||
for k, ft in features.items()
|
for k, ft in features.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
labels = tf.concat(
|
labels = {
|
||||||
[labels[self.args.train_batch_size // self.args.n_replicas :], reduced_labels], axis=0
|
k: tf.concat(
|
||||||
)
|
[lbl[self.args.train_batch_size // self.args.n_replicas :], reduced_labels[k]],
|
||||||
|
axis=0,
|
||||||
|
)
|
||||||
|
for k, lbl in labels.items()
|
||||||
|
}
|
||||||
|
|
||||||
gradients = self.gradient_accumulator.gradients
|
gradients = self.gradient_accumulator.gradients
|
||||||
gradients = [
|
gradients = [
|
||||||
|
|||||||
Reference in New Issue
Block a user