From 12f0d7e8e0f8f92ff0585e23e1bbd5960644e4c6 Mon Sep 17 00:00:00 2001 From: Julien Plu Date: Wed, 20 Jan 2021 12:08:00 +0100 Subject: [PATCH] Fix label datatype in TF Trainer (#9616) * Fix label datatype * Apply style --- src/transformers/trainer_tf.py | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/src/transformers/trainer_tf.py b/src/transformers/trainer_tf.py index ac75eb6223..d6c92ebc05 100644 --- a/src/transformers/trainer_tf.py +++ b/src/transformers/trainer_tf.py @@ -638,7 +638,15 @@ class TFTrainer: reduced_features = { k: ft[: self.args.train_batch_size // self.args.n_replicas] for k, ft in features.items() } - reduced_labels = labels[: self.args.train_batch_size // self.args.n_replicas] + + if tf.is_tensor(labels): + reduced_labels = labels[: self.args.train_batch_size // self.args.n_replicas] + elif isinstance(labels, dict): + reduced_labels = { + k: lbl[: self.args.train_batch_size // self.args.n_replicas] for k, lbl in labels.items() + } + else: + raise ValueError("The labels must be either a tf.Tensor or a dict.") self.training_step(reduced_features, reduced_labels, nb_instances_in_global_batch) @@ -650,9 +658,20 @@ class TFTrainer: for k, ft in features.items() } - labels = tf.concat( - [labels[self.args.train_batch_size // self.args.n_replicas :], reduced_labels], axis=0 - ) + if tf.is_tensor(labels): + labels = tf.concat( + [labels[self.args.train_batch_size // self.args.n_replicas :], reduced_labels], axis=0 + ) + elif isinstance(labels, dict): + labels = { + k: tf.concat( + [lbl[self.args.train_batch_size // self.args.n_replicas :], reduced_labels[k]], + axis=0, + ) + for k, lbl in labels.items() + } + else: + raise ValueError("The labels must be either a tf.Tensor or a dict.") gradients = self.gradient_accumulator.gradients gradients = [