Fixes tf_default_data_collator sometimes guessing the wrong dtype for labels (#15234)
* Fixes tf_default_data_collator sometimes guessing the wrong dtype for labels * Add test for numpy scalar inputs
This commit is contained in:
@@ -145,26 +145,27 @@ def tf_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
|
|||||||
# Ensure that tensor is created with the correct type
|
# Ensure that tensor is created with the correct type
|
||||||
# (it should be automatically the case, but let's make sure of it.)
|
# (it should be automatically the case, but let's make sure of it.)
|
||||||
if "label" in first and first["label"] is not None:
|
if "label" in first and first["label"] is not None:
|
||||||
if isinstance(first["label"], tf.Tensor):
|
label_col_name = "label"
|
||||||
dtype = tf.int64 if first["label"].dtype.is_integer() else tf.float32
|
|
||||||
elif isinstance(first["label"], np.ndarray):
|
|
||||||
dtype = tf.int64 if np.issubdtype(first["label"].dtype, np.integer) else tf.float32
|
|
||||||
elif isinstance(first["label"], (tuple, list)):
|
|
||||||
dtype = tf.int64 if isinstance(first["label"][0], int) else tf.float32
|
|
||||||
else:
|
|
||||||
dtype = tf.int64 if isinstance(first["label"], int) else tf.float32
|
|
||||||
batch["labels"] = tf.convert_to_tensor([f["label"] for f in features], dtype=dtype)
|
|
||||||
elif "label_ids" in first and first["label_ids"] is not None:
|
elif "label_ids" in first and first["label_ids"] is not None:
|
||||||
if isinstance(first["label_ids"], tf.Tensor):
|
label_col_name = "label_ids"
|
||||||
batch["labels"] = tf.stack([f["label_ids"] for f in features])
|
elif "labels" in first and first["labels"] is not None:
|
||||||
|
label_col_name = "labels"
|
||||||
else:
|
else:
|
||||||
dtype = tf.int64 if type(first["label_ids"][0]) is int else tf.float32
|
label_col_name = None
|
||||||
batch["labels"] = tf.convert_to_tensor([f["label_ids"] for f in features], dtype=dtype)
|
if label_col_name is not None:
|
||||||
|
if isinstance(first[label_col_name], tf.Tensor):
|
||||||
|
dtype = tf.int64 if first[label_col_name].dtype.is_integer() else tf.float32
|
||||||
|
elif isinstance(first[label_col_name], np.ndarray) or isinstance(first[label_col_name], np.generic):
|
||||||
|
dtype = tf.int64 if np.issubdtype(first[label_col_name].dtype, np.integer) else tf.float32
|
||||||
|
elif isinstance(first[label_col_name], (tuple, list)):
|
||||||
|
dtype = tf.int64 if isinstance(first[label_col_name][0], int) else tf.float32
|
||||||
|
else:
|
||||||
|
dtype = tf.int64 if isinstance(first[label_col_name], int) else tf.float32
|
||||||
|
batch["labels"] = tf.convert_to_tensor([f[label_col_name] for f in features], dtype=dtype)
|
||||||
# Handling of all other possible keys.
|
# Handling of all other possible keys.
|
||||||
# Again, we will use the first element to figure out which key/values are not None for this model.
|
# Again, we will use the first element to figure out which key/values are not None for this model.
|
||||||
for k, v in first.items():
|
for k, v in first.items():
|
||||||
if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
|
if k not in ("label", "label_ids", "labels") and v is not None and not isinstance(v, str):
|
||||||
if isinstance(v, (tf.Tensor, np.ndarray)):
|
if isinstance(v, (tf.Tensor, np.ndarray)):
|
||||||
batch[k] = tf.stack([f[k] for f in features])
|
batch[k] = tf.stack([f[k] for f in features])
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -353,6 +353,14 @@ class TFDataCollatorIntegrationTest(unittest.TestCase):
|
|||||||
self.assertEqual(batch["labels"].dtype, tf.int64)
|
self.assertEqual(batch["labels"].dtype, tf.int64)
|
||||||
self.assertEqual(batch["inputs"].shape.as_list(), [8, 10])
|
self.assertEqual(batch["inputs"].shape.as_list(), [8, 10])
|
||||||
|
|
||||||
|
def test_numpy_dtype_preservation(self):
|
||||||
|
data_collator = default_data_collator
|
||||||
|
|
||||||
|
# Confirms that numpy inputs are handled correctly even when scalars
|
||||||
|
features = [{"input_ids": np.array([0, 1, 2, 3, 4]), "label": np.int64(i)} for i in range(4)]
|
||||||
|
batch = data_collator(features, return_tensors="tf")
|
||||||
|
self.assertEqual(batch["labels"].dtype, tf.int64)
|
||||||
|
|
||||||
def test_default_classification_and_regression(self):
|
def test_default_classification_and_regression(self):
|
||||||
data_collator = default_data_collator
|
data_collator = default_data_collator
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user