Immutability for data collators (#30603)
* Immutability fix for seq2seq, as well as immutability tests for the collators
* Ensure we don't act on `None` labels; formatting
* Remove tf/pt in the respective tests, as they are not required
* More type error fixes (tf/np)
* Remove TODO
* Apply suggestions from code review
* Formatting / style
This commit is contained in:
@@ -585,51 +585,84 @@ class DataCollatorForSeq2Seq:
|
|||||||
def __call__(self, features, return_tensors=None):
|
def __call__(self, features, return_tensors=None):
|
||||||
if return_tensors is None:
|
if return_tensors is None:
|
||||||
return_tensors = self.return_tensors
|
return_tensors = self.return_tensors
|
||||||
labels = [feature["labels"] for feature in features] if "labels" in features[0].keys() else None
|
|
||||||
# We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the
|
|
||||||
# same length to return tensors.
|
|
||||||
no_padding = self.padding is False or self.padding == PaddingStrategy.DO_NOT_PAD
|
|
||||||
if labels is not None and not no_padding:
|
|
||||||
max_padding = self.padding == PaddingStrategy.MAX_LENGTH and self.max_length is not None
|
|
||||||
max_label_length = max(len(l) for l in labels) if not max_padding else self.max_length
|
|
||||||
if self.pad_to_multiple_of is not None:
|
|
||||||
max_label_length = (
|
|
||||||
(max_label_length + self.pad_to_multiple_of - 1)
|
|
||||||
// self.pad_to_multiple_of
|
|
||||||
* self.pad_to_multiple_of
|
|
||||||
)
|
|
||||||
|
|
||||||
padding_side = self.tokenizer.padding_side
|
label_name = "label" if "label" in features[0].keys() else "labels"
|
||||||
for feature in features:
|
labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
|
||||||
remainder = [self.label_pad_token_id] * (max_label_length - len(feature["labels"]))
|
# reconvert list[None] to None if necessary
|
||||||
if isinstance(feature["labels"], list):
|
# this might occur when we pass {..., "labels": None}
|
||||||
feature["labels"] = (
|
if labels is not None and all(label is None for label in labels):
|
||||||
feature["labels"] + remainder if padding_side == "right" else remainder + feature["labels"]
|
labels = None
|
||||||
)
|
non_labels_features = [{k: v for k, v in feature.items() if k != label_name} for feature in features]
|
||||||
elif padding_side == "right":
|
|
||||||
feature["labels"] = np.concatenate([feature["labels"], remainder]).astype(np.int64)
|
|
||||||
else:
|
|
||||||
feature["labels"] = np.concatenate([remainder, feature["labels"]]).astype(np.int64)
|
|
||||||
|
|
||||||
features = pad_without_fast_tokenizer_warning(
|
# run through tokenizer without labels to ensure no side effects
|
||||||
|
batch = pad_without_fast_tokenizer_warning(
|
||||||
self.tokenizer,
|
self.tokenizer,
|
||||||
features,
|
non_labels_features,
|
||||||
padding=self.padding,
|
padding=self.padding,
|
||||||
max_length=self.max_length,
|
max_length=self.max_length,
|
||||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||||
return_tensors=return_tensors,
|
return_tensors=return_tensors,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# we have to pad the labels manually as we cannot rely on `tokenizer.pad` and we need them to be of the same length to return tensors
|
||||||
|
no_padding = self.padding is False or self.padding == PaddingStrategy.DO_NOT_PAD
|
||||||
|
if labels is not None:
|
||||||
|
if no_padding:
|
||||||
|
if isinstance(features[0][label_name], list):
|
||||||
|
batch["labels"] = list(labels)
|
||||||
|
else:
|
||||||
|
batch["labels"] = [np.concatenate([label, []]) for label in labels]
|
||||||
|
else:
|
||||||
|
max_padding = self.padding == PaddingStrategy.MAX_LENGTH and self.max_length is not None
|
||||||
|
max_label_length = max(len(l) for l in labels) if not max_padding else self.max_length
|
||||||
|
if self.pad_to_multiple_of is not None:
|
||||||
|
max_label_length = (
|
||||||
|
(max_label_length + self.pad_to_multiple_of - 1)
|
||||||
|
// self.pad_to_multiple_of
|
||||||
|
* self.pad_to_multiple_of
|
||||||
|
)
|
||||||
|
|
||||||
|
padding_side = self.tokenizer.padding_side
|
||||||
|
if isinstance(features[0][label_name], list):
|
||||||
|
batch["labels"] = [
|
||||||
|
label + [self.label_pad_token_id] * (max_label_length - len(label))
|
||||||
|
if padding_side == "right"
|
||||||
|
else [self.label_pad_token_id] * (max_label_length - len(label)) + label
|
||||||
|
for label in labels
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
batch["labels"] = [
|
||||||
|
np.concatenate([label, [self.label_pad_token_id] * (max_label_length - len(label))])
|
||||||
|
if padding_side == "right"
|
||||||
|
else np.concatenate([[self.label_pad_token_id] * (max_label_length - len(label)), label])
|
||||||
|
for label in labels
|
||||||
|
]
|
||||||
|
|
||||||
|
# reintroduce side effects via tokenizer that return respective datatypes for the `return_tensors` argument
|
||||||
|
if batch.get("labels", None) is not None:
|
||||||
|
if return_tensors == "pt":
|
||||||
|
import torch
|
||||||
|
|
||||||
|
batch["labels"] = torch.tensor(batch["labels"], dtype=torch.int64)
|
||||||
|
elif return_tensors == "tf":
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
batch["labels"] = tf.constant(batch["labels"], dtype=tf.int64)
|
||||||
|
else:
|
||||||
|
batch["labels"] = np.array(batch["labels"], dtype=np.int64)
|
||||||
|
else:
|
||||||
|
batch["labels"] = None
|
||||||
|
|
||||||
# prepare decoder_input_ids
|
# prepare decoder_input_ids
|
||||||
if (
|
if (
|
||||||
labels is not None
|
labels is not None
|
||||||
and self.model is not None
|
and self.model is not None
|
||||||
and hasattr(self.model, "prepare_decoder_input_ids_from_labels")
|
and hasattr(self.model, "prepare_decoder_input_ids_from_labels")
|
||||||
):
|
):
|
||||||
decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(labels=features["labels"])
|
decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(labels=batch["labels"])
|
||||||
features["decoder_input_ids"] = decoder_input_ids
|
batch["decoder_input_ids"] = decoder_input_ids
|
||||||
|
|
||||||
return features
|
return batch
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user