Data collator for token classification pads labels column when receives pytorch tensors (#20244)
* token cls data_collator pads labels column * remove walrus operator for code quality * remove redundat space * remove comment that was fixed * PR comments fix Co-authored-by: Alexander Markov <amarkov.me@gmail.com>
This commit is contained in:
@@ -154,6 +154,51 @@ class DataCollatorIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(batch["labels"].shape, torch.Size([2, 6]))
|
||||
self.assertEqual(batch["labels"][0].tolist(), [0, 1, 2] + [-1] * 3)
|
||||
|
||||
for feature in features:
|
||||
feature.pop("labels")
|
||||
|
||||
batch = data_collator(features)
|
||||
self.assertEqual(batch["input_ids"].shape, torch.Size([2, 6]))
|
||||
self.assertEqual(batch["input_ids"][0].tolist(), [0, 1, 2] + [tokenizer.pad_token_id] * 3)
|
||||
|
||||
def test_data_collator_for_token_classification_works_with_pt_tensors(self):
|
||||
tokenizer = BertTokenizer(self.vocab_file)
|
||||
features = [
|
||||
{"input_ids": torch.tensor([0, 1, 2]), "labels": torch.tensor([0, 1, 2])},
|
||||
{"input_ids": torch.tensor([0, 1, 2, 3, 4, 5]), "labels": torch.tensor([0, 1, 2, 3, 4, 5])},
|
||||
]
|
||||
|
||||
data_collator = DataCollatorForTokenClassification(tokenizer)
|
||||
batch = data_collator(features)
|
||||
self.assertEqual(batch["input_ids"].shape, torch.Size([2, 6]))
|
||||
self.assertEqual(batch["input_ids"][0].tolist(), [0, 1, 2] + [tokenizer.pad_token_id] * 3)
|
||||
self.assertEqual(batch["labels"].shape, torch.Size([2, 6]))
|
||||
self.assertEqual(batch["labels"][0].tolist(), [0, 1, 2] + [-100] * 3)
|
||||
|
||||
data_collator = DataCollatorForTokenClassification(tokenizer, padding="max_length", max_length=10)
|
||||
batch = data_collator(features)
|
||||
self.assertEqual(batch["input_ids"].shape, torch.Size([2, 10]))
|
||||
self.assertEqual(batch["labels"].shape, torch.Size([2, 10]))
|
||||
|
||||
data_collator = DataCollatorForTokenClassification(tokenizer, pad_to_multiple_of=8)
|
||||
batch = data_collator(features)
|
||||
self.assertEqual(batch["input_ids"].shape, torch.Size([2, 8]))
|
||||
self.assertEqual(batch["labels"].shape, torch.Size([2, 8]))
|
||||
|
||||
data_collator = DataCollatorForTokenClassification(tokenizer, label_pad_token_id=-1)
|
||||
batch = data_collator(features)
|
||||
self.assertEqual(batch["input_ids"].shape, torch.Size([2, 6]))
|
||||
self.assertEqual(batch["input_ids"][0].tolist(), [0, 1, 2] + [tokenizer.pad_token_id] * 3)
|
||||
self.assertEqual(batch["labels"].shape, torch.Size([2, 6]))
|
||||
self.assertEqual(batch["labels"][0].tolist(), [0, 1, 2] + [-1] * 3)
|
||||
|
||||
for feature in features:
|
||||
feature.pop("labels")
|
||||
|
||||
batch = data_collator(features)
|
||||
self.assertEqual(batch["input_ids"].shape, torch.Size([2, 6]))
|
||||
self.assertEqual(batch["input_ids"][0].tolist(), [0, 1, 2] + [tokenizer.pad_token_id] * 3)
|
||||
|
||||
def _test_no_pad_and_pad(self, no_pad_features, pad_features):
|
||||
tokenizer = BertTokenizer(self.vocab_file)
|
||||
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
|
||||
|
||||
Reference in New Issue
Block a user