From 85788bae5cbdcc5035e128c3a7879a082ee4af5c Mon Sep 17 00:00:00 2001 From: Kiyoung Kim Date: Fri, 15 Jan 2021 00:16:39 +0900 Subject: [PATCH] Revert "Gradient accumulation for TFTrainer (#9585)" This reverts commit 3f40070c88de07169fe18b0b4c4003ef2858a284. --- src/transformers/trainer_tf.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/src/transformers/trainer_tf.py b/src/transformers/trainer_tf.py index e566082ec9..ac75eb6223 100644 --- a/src/transformers/trainer_tf.py +++ b/src/transformers/trainer_tf.py @@ -638,9 +638,7 @@ class TFTrainer: reduced_features = { k: ft[: self.args.train_batch_size // self.args.n_replicas] for k, ft in features.items() } - reduced_labels = { - k: lbl[: self.args.train_batch_size // self.args.n_replicas] for k, lbl in labels.items() - } + reduced_labels = labels[: self.args.train_batch_size // self.args.n_replicas] self.training_step(reduced_features, reduced_labels, nb_instances_in_global_batch) @@ -652,13 +650,9 @@ class TFTrainer: for k, ft in features.items() } - labels = { - k: tf.concat( - [lbl[self.args.train_batch_size // self.args.n_replicas :], reduced_labels[k]], - axis=0, - ) - for k, lbl in labels.items() - } + labels = tf.concat( + [labels[self.args.train_batch_size // self.args.n_replicas :], reduced_labels], axis=0 + ) gradients = self.gradient_accumulator.gradients gradients = [