From 3f40070c88de07169fe18b0b4c4003ef2858a284 Mon Sep 17 00:00:00 2001 From: Kiyoung Kim Date: Fri, 15 Jan 2021 00:16:39 +0900 Subject: [PATCH] Gradient accumulation for TFTrainer (#9585) * gradient accumulation for tftrainer * label naming Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * label naming Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/trainer_tf.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/transformers/trainer_tf.py b/src/transformers/trainer_tf.py index ac75eb6223..e566082ec9 100644 --- a/src/transformers/trainer_tf.py +++ b/src/transformers/trainer_tf.py @@ -638,7 +638,9 @@ class TFTrainer: reduced_features = { k: ft[: self.args.train_batch_size // self.args.n_replicas] for k, ft in features.items() } - reduced_labels = labels[: self.args.train_batch_size // self.args.n_replicas] + reduced_labels = { + k: lbl[: self.args.train_batch_size // self.args.n_replicas] for k, lbl in labels.items() + } self.training_step(reduced_features, reduced_labels, nb_instances_in_global_batch) @@ -650,9 +652,13 @@ class TFTrainer: for k, ft in features.items() } - labels = tf.concat( - [labels[self.args.train_batch_size // self.args.n_replicas :], reduced_labels], axis=0 - ) + labels = { + k: tf.concat( + [lbl[self.args.train_batch_size // self.args.n_replicas :], reduced_labels[k]], + axis=0, + ) + for k, lbl in labels.items() + } gradients = self.gradient_accumulator.gradients gradients = [