Fixes tf_default_data_collator sometimes guessing the wrong dtype for labels (#15234)
* Fixes tf_default_data_collator sometimes guessing the wrong dtype for labels
* Add test for numpy scalar inputs
This commit is contained in:
@@ -353,6 +353,14 @@ class TFDataCollatorIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(batch["labels"].dtype, tf.int64)
|
||||
self.assertEqual(batch["inputs"].shape.as_list(), [8, 10])
|
||||
|
||||
def test_numpy_dtype_preservation(self):
    """Check that labels given as numpy int64 scalars come out as tf.int64."""
    collate = default_data_collator

    # Confirms that numpy inputs are handled correctly even when the label
    # is a numpy scalar rather than an array.
    examples = []
    for label in range(4):
        examples.append({"input_ids": np.array([0, 1, 2, 3, 4]), "label": np.int64(label)})

    collated = collate(examples, return_tensors="tf")
    self.assertEqual(collated["labels"].dtype, tf.int64)
|
||||
|
||||
def test_default_classification_and_regression(self):
|
||||
data_collator = default_data_collator
|
||||
|
||||
|
||||
Reference in New Issue
Block a user