From 610acc5ae947eda71933de642a8b86196252bc43 Mon Sep 17 00:00:00 2001 From: Alexander Markov Date: Wed, 16 Nov 2022 21:18:46 +0400 Subject: [PATCH] Data collator for token classification pads labels column when receives pytorch tensors (#20244) * token cls data_collator pads labels column * remove walrus operator for code quality * remove redundat space * remove comment that was fixed * PR comments fix Co-authored-by: Alexander Markov --- src/transformers/data/data_collator.py | 22 +++++++++---- tests/trainer/test_data_collator.py | 45 ++++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 7 deletions(-) diff --git a/src/transformers/data/data_collator.py b/src/transformers/data/data_collator.py index c06c569692..60522344d4 100644 --- a/src/transformers/data/data_collator.py +++ b/src/transformers/data/data_collator.py @@ -305,30 +305,38 @@ class DataCollatorForTokenClassification(DataCollatorMixin): label_name = "label" if "label" in features[0].keys() else "labels" labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None + + no_labels_features = [{k: v for k, v in feature.items() if k != label_name} for feature in features] + batch = self.tokenizer.pad( - features, + no_labels_features, padding=self.padding, max_length=self.max_length, pad_to_multiple_of=self.pad_to_multiple_of, - # Conversion to tensors will fail if we have labels as they are not of the same length yet. - return_tensors="pt" if labels is None else None, + return_tensors="pt", ) if labels is None: return batch - sequence_length = torch.tensor(batch["input_ids"]).shape[1] + sequence_length = batch["input_ids"].shape[1] padding_side = self.tokenizer.padding_side + + def to_list(tensor_or_iterable): + if isinstance(tensor_or_iterable, torch.Tensor): + return tensor_or_iterable.tolist() + return list(tensor_or_iterable) + if padding_side == "right": batch[label_name] = [ - list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels + to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels ] else: batch[label_name] = [ - [self.label_pad_token_id] * (sequence_length - len(label)) + list(label) for label in labels + [self.label_pad_token_id] * (sequence_length - len(label)) + to_list(label) for label in labels ] - batch = {k: torch.tensor(v, dtype=torch.int64) for k, v in batch.items()} + batch[label_name] = torch.tensor(batch[label_name], dtype=torch.int64) return batch def tf_call(self, features): diff --git a/tests/trainer/test_data_collator.py b/tests/trainer/test_data_collator.py index bd610873c1..39277ca8cc 100644 --- a/tests/trainer/test_data_collator.py +++ b/tests/trainer/test_data_collator.py @@ -154,6 +154,51 @@ class DataCollatorIntegrationTest(unittest.TestCase): self.assertEqual(batch["labels"].shape, torch.Size([2, 6])) self.assertEqual(batch["labels"][0].tolist(), [0, 1, 2] + [-1] * 3) + for feature in features: + feature.pop("labels") + + batch = data_collator(features) + self.assertEqual(batch["input_ids"].shape, torch.Size([2, 6])) + self.assertEqual(batch["input_ids"][0].tolist(), [0, 1, 2] + [tokenizer.pad_token_id] * 3) + + def test_data_collator_for_token_classification_works_with_pt_tensors(self): + tokenizer = BertTokenizer(self.vocab_file) + features = [ + {"input_ids": torch.tensor([0, 1, 2]), "labels": torch.tensor([0, 1, 2])}, + {"input_ids": torch.tensor([0, 1, 2, 3, 4, 5]), "labels": torch.tensor([0, 1, 2, 3, 4, 5])}, + ] + + data_collator = DataCollatorForTokenClassification(tokenizer) + batch = data_collator(features) + self.assertEqual(batch["input_ids"].shape, torch.Size([2, 6])) + self.assertEqual(batch["input_ids"][0].tolist(), [0, 1, 2] + [tokenizer.pad_token_id] * 3) + self.assertEqual(batch["labels"].shape, torch.Size([2, 6])) + self.assertEqual(batch["labels"][0].tolist(), [0, 1, 2] + [-100] * 3) + + data_collator = DataCollatorForTokenClassification(tokenizer, padding="max_length", max_length=10) + batch = data_collator(features) + self.assertEqual(batch["input_ids"].shape, torch.Size([2, 10])) + self.assertEqual(batch["labels"].shape, torch.Size([2, 10])) + + data_collator = DataCollatorForTokenClassification(tokenizer, pad_to_multiple_of=8) + batch = data_collator(features) + self.assertEqual(batch["input_ids"].shape, torch.Size([2, 8])) + self.assertEqual(batch["labels"].shape, torch.Size([2, 8])) + + data_collator = DataCollatorForTokenClassification(tokenizer, label_pad_token_id=-1) + batch = data_collator(features) + self.assertEqual(batch["input_ids"].shape, torch.Size([2, 6])) + self.assertEqual(batch["input_ids"][0].tolist(), [0, 1, 2] + [tokenizer.pad_token_id] * 3) + self.assertEqual(batch["labels"].shape, torch.Size([2, 6])) + self.assertEqual(batch["labels"][0].tolist(), [0, 1, 2] + [-1] * 3) + + for feature in features: + feature.pop("labels") + + batch = data_collator(features) + self.assertEqual(batch["input_ids"].shape, torch.Size([2, 6])) + self.assertEqual(batch["input_ids"][0].tolist(), [0, 1, 2] + [tokenizer.pad_token_id] * 3) + def _test_no_pad_and_pad(self, no_pad_features, pad_features): tokenizer = BertTokenizer(self.vocab_file) data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)