From f00f22a3e290fd377b979124dcf9800b3d73eb11 Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 20 Jan 2022 14:26:51 +0000 Subject: [PATCH] Fixes tf_default_data_collator sometimes guessing the wrong dtype for labels (#15234) * Fixes tf_default_data_collator sometimes guessing the wrong dtype for labels * Add test for numpy scalar inputs --- src/transformers/data/data_collator.py | 31 +++++++++++++------------- tests/test_data_collator.py | 8 +++++++ 2 files changed, 24 insertions(+), 15 deletions(-) diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py index 0b087b483d..a2d9fa5a0c 100644 --- a/src/transformers/data/data_collator.py +++ b/src/transformers/data/data_collator.py @@ -145,26 +145,27 @@ def tf_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]: # Ensure that tensor is created with the correct type # (it should be automatically the case, but let's make sure of it.) if "label" in first and first["label"] is not None: - if isinstance(first["label"], tf.Tensor): - dtype = tf.int64 if first["label"].dtype.is_integer() else tf.float32 - elif isinstance(first["label"], np.ndarray): - dtype = tf.int64 if np.issubdtype(first["label"].dtype, np.integer) else tf.float32 - elif isinstance(first["label"], (tuple, list)): - dtype = tf.int64 if isinstance(first["label"][0], int) else tf.float32 - else: - dtype = tf.int64 if isinstance(first["label"], int) else tf.float32 - batch["labels"] = tf.convert_to_tensor([f["label"] for f in features], dtype=dtype) + label_col_name = "label" elif "label_ids" in first and first["label_ids"] is not None: - if isinstance(first["label_ids"], tf.Tensor): - batch["labels"] = tf.stack([f["label_ids"] for f in features]) + label_col_name = "label_ids" + elif "labels" in first and first["labels"] is not None: + label_col_name = "labels" + else: + label_col_name = None + if label_col_name is not None: + if isinstance(first[label_col_name], tf.Tensor): + dtype = tf.int64 if first[label_col_name].dtype.is_integer() else tf.float32 + elif isinstance(first[label_col_name], np.ndarray) or isinstance(first[label_col_name], np.generic): + dtype = tf.int64 if np.issubdtype(first[label_col_name].dtype, np.integer) else tf.float32 + elif isinstance(first[label_col_name], (tuple, list)): + dtype = tf.int64 if isinstance(first[label_col_name][0], int) else tf.float32 else: - dtype = tf.int64 if type(first["label_ids"][0]) is int else tf.float32 - batch["labels"] = tf.convert_to_tensor([f["label_ids"] for f in features], dtype=dtype) - + dtype = tf.int64 if isinstance(first[label_col_name], int) else tf.float32 + batch["labels"] = tf.convert_to_tensor([f[label_col_name] for f in features], dtype=dtype) # Handling of all other possible keys. # Again, we will use the first element to figure out which key/values are not None for this model. for k, v in first.items(): - if k not in ("label", "label_ids") and v is not None and not isinstance(v, str): + if k not in ("label", "label_ids", "labels") and v is not None and not isinstance(v, str): if isinstance(v, (tf.Tensor, np.ndarray)): batch[k] = tf.stack([f[k] for f in features]) else: diff --git a/tests/test_data_collator.py b/tests/test_data_collator.py index d9bcc08447..bd610873c1 100644 --- a/tests/test_data_collator.py +++ b/tests/test_data_collator.py @@ -353,6 +353,14 @@ class TFDataCollatorIntegrationTest(unittest.TestCase): self.assertEqual(batch["labels"].dtype, tf.int64) self.assertEqual(batch["inputs"].shape.as_list(), [8, 10]) + def test_numpy_dtype_preservation(self): + data_collator = default_data_collator + + # Confirms that numpy inputs are handled correctly even when scalars + features = [{"input_ids": np.array([0, 1, 2, 3, 4]), "label": np.int64(i)} for i in range(4)] + batch = data_collator(features, return_tensors="tf") + self.assertEqual(batch["labels"].dtype, tf.int64) + def test_default_classification_and_regression(self): data_collator = default_data_collator