Fixes tf_default_data_collator sometimes guessing the wrong dtype for labels (#15234)

* Fixes tf_default_data_collator sometimes guessing the wrong dtype for labels

* Add test for numpy scalar inputs
This commit is contained in:
Matt
2022-01-20 14:26:51 +00:00
committed by GitHub
parent 4a6a35bc65
commit f00f22a3e2
2 changed files with 24 additions and 15 deletions

View File

@@ -353,6 +353,14 @@ class TFDataCollatorIntegrationTest(unittest.TestCase):
self.assertEqual(batch["labels"].dtype, tf.int64)
self.assertEqual(batch["inputs"].shape.as_list(), [8, 10])
def test_numpy_dtype_preservation(self):
data_collator = default_data_collator
# Confirms that numpy inputs are handled correctly even when scalars
features = [{"input_ids": np.array([0, 1, 2, 3, 4]), "label": np.int64(i)} for i in range(4)]
batch = data_collator(features, return_tensors="tf")
self.assertEqual(batch["labels"].dtype, tf.int64)
def test_default_classification_and_regression(self):
data_collator = default_data_collator