Data collator for token classification pads labels column when receives pytorch tensors (#20244)
* token cls data_collator pads labels column * remove walrus operator for code quality * remove redundat space * remove comment that was fixed * PR comments fix Co-authored-by: Alexander Markov <amarkov.me@gmail.com>
This commit is contained in:
@@ -305,30 +305,38 @@ class DataCollatorForTokenClassification(DataCollatorMixin):
|
|||||||
|
|
||||||
label_name = "label" if "label" in features[0].keys() else "labels"
|
label_name = "label" if "label" in features[0].keys() else "labels"
|
||||||
labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
|
labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
|
||||||
|
|
||||||
|
no_labels_features = [{k: v for k, v in feature.items() if k != label_name} for feature in features]
|
||||||
|
|
||||||
batch = self.tokenizer.pad(
|
batch = self.tokenizer.pad(
|
||||||
features,
|
no_labels_features,
|
||||||
padding=self.padding,
|
padding=self.padding,
|
||||||
max_length=self.max_length,
|
max_length=self.max_length,
|
||||||
pad_to_multiple_of=self.pad_to_multiple_of,
|
pad_to_multiple_of=self.pad_to_multiple_of,
|
||||||
# Conversion to tensors will fail if we have labels as they are not of the same length yet.
|
return_tensors="pt",
|
||||||
return_tensors="pt" if labels is None else None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if labels is None:
|
if labels is None:
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
sequence_length = torch.tensor(batch["input_ids"]).shape[1]
|
sequence_length = batch["input_ids"].shape[1]
|
||||||
padding_side = self.tokenizer.padding_side
|
padding_side = self.tokenizer.padding_side
|
||||||
|
|
||||||
|
def to_list(tensor_or_iterable):
|
||||||
|
if isinstance(tensor_or_iterable, torch.Tensor):
|
||||||
|
return tensor_or_iterable.tolist()
|
||||||
|
return list(tensor_or_iterable)
|
||||||
|
|
||||||
if padding_side == "right":
|
if padding_side == "right":
|
||||||
batch[label_name] = [
|
batch[label_name] = [
|
||||||
list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
|
to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
batch[label_name] = [
|
batch[label_name] = [
|
||||||
[self.label_pad_token_id] * (sequence_length - len(label)) + list(label) for label in labels
|
[self.label_pad_token_id] * (sequence_length - len(label)) + to_list(label) for label in labels
|
||||||
]
|
]
|
||||||
|
|
||||||
batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()}
|
batch[label_name] = torch.tensor(batch[label_name], dtype=torch.int64)
|
||||||
return batch
|
return batch
|
||||||
|
|
||||||
def tf_call(self, features):
|
def tf_call(self, features):
|
||||||
|
|||||||
@@ -154,6 +154,51 @@ class DataCollatorIntegrationTest(unittest.TestCase):
|
|||||||
self.assertEqual(batch["labels"].shape, torch.Size([2, 6]))
|
self.assertEqual(batch["labels"].shape, torch.Size([2, 6]))
|
||||||
self.assertEqual(batch["labels"][0].tolist(), [0, 1, 2] + [-1] * 3)
|
self.assertEqual(batch["labels"][0].tolist(), [0, 1, 2] + [-1] * 3)
|
||||||
|
|
||||||
|
for feature in features:
|
||||||
|
feature.pop("labels")
|
||||||
|
|
||||||
|
batch = data_collator(features)
|
||||||
|
self.assertEqual(batch["input_ids"].shape, torch.Size([2, 6]))
|
||||||
|
self.assertEqual(batch["input_ids"][0].tolist(), [0, 1, 2] + [tokenizer.pad_token_id] * 3)
|
||||||
|
|
||||||
|
def test_data_collator_for_token_classification_works_with_pt_tensors(self):
|
||||||
|
tokenizer = BertTokenizer(self.vocab_file)
|
||||||
|
features = [
|
||||||
|
{"input_ids": torch.tensor([0, 1, 2]), "labels": torch.tensor([0, 1, 2])},
|
||||||
|
{"input_ids": torch.tensor([0, 1, 2, 3, 4, 5]), "labels": torch.tensor([0, 1, 2, 3, 4, 5])},
|
||||||
|
]
|
||||||
|
|
||||||
|
data_collator = DataCollatorForTokenClassification(tokenizer)
|
||||||
|
batch = data_collator(features)
|
||||||
|
self.assertEqual(batch["input_ids"].shape, torch.Size([2, 6]))
|
||||||
|
self.assertEqual(batch["input_ids"][0].tolist(), [0, 1, 2] + [tokenizer.pad_token_id] * 3)
|
||||||
|
self.assertEqual(batch["labels"].shape, torch.Size([2, 6]))
|
||||||
|
self.assertEqual(batch["labels"][0].tolist(), [0, 1, 2] + [-100] * 3)
|
||||||
|
|
||||||
|
data_collator = DataCollatorForTokenClassification(tokenizer, padding="max_length", max_length=10)
|
||||||
|
batch = data_collator(features)
|
||||||
|
self.assertEqual(batch["input_ids"].shape, torch.Size([2, 10]))
|
||||||
|
self.assertEqual(batch["labels"].shape, torch.Size([2, 10]))
|
||||||
|
|
||||||
|
data_collator = DataCollatorForTokenClassification(tokenizer, pad_to_multiple_of=8)
|
||||||
|
batch = data_collator(features)
|
||||||
|
self.assertEqual(batch["input_ids"].shape, torch.Size([2, 8]))
|
||||||
|
self.assertEqual(batch["labels"].shape, torch.Size([2, 8]))
|
||||||
|
|
||||||
|
data_collator = DataCollatorForTokenClassification(tokenizer, label_pad_token_id=-1)
|
||||||
|
batch = data_collator(features)
|
||||||
|
self.assertEqual(batch["input_ids"].shape, torch.Size([2, 6]))
|
||||||
|
self.assertEqual(batch["input_ids"][0].tolist(), [0, 1, 2] + [tokenizer.pad_token_id] * 3)
|
||||||
|
self.assertEqual(batch["labels"].shape, torch.Size([2, 6]))
|
||||||
|
self.assertEqual(batch["labels"][0].tolist(), [0, 1, 2] + [-1] * 3)
|
||||||
|
|
||||||
|
for feature in features:
|
||||||
|
feature.pop("labels")
|
||||||
|
|
||||||
|
batch = data_collator(features)
|
||||||
|
self.assertEqual(batch["input_ids"].shape, torch.Size([2, 6]))
|
||||||
|
self.assertEqual(batch["input_ids"][0].tolist(), [0, 1, 2] + [tokenizer.pad_token_id] * 3)
|
||||||
|
|
||||||
def _test_no_pad_and_pad(self, no_pad_features, pad_features):
|
def _test_no_pad_and_pad(self, no_pad_features, pad_features):
|
||||||
tokenizer = BertTokenizer(self.vocab_file)
|
tokenizer = BertTokenizer(self.vocab_file)
|
||||||
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
|
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
|
||||||
|
|||||||
Reference in New Issue
Block a user