Make DataCollator a callable (#5015)

* Make DataCollator a callable

* Update src/transformers/data/data_collator.py

Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
Sylvain Gugger
2020-06-15 11:58:33 -04:00
committed by GitHub
parent f7c93b3cee
commit 1affde2f10
7 changed files with 60 additions and 83 deletions

View File

@@ -11,7 +11,7 @@ if is_torch_available():
Trainer,
LineByLineTextDataset,
AutoModelForSequenceClassification,
DefaultDataCollator,
default_data_collator,
DataCollatorForLanguageModeling,
GlueDataset,
GlueDataTrainingArguments,
@@ -31,8 +31,8 @@ class DataCollatorIntegrationTest(unittest.TestCase):
task_name="mrpc", data_dir="./tests/fixtures/tests_samples/MRPC", overwrite_cache=True
)
dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
data_collator = DefaultDataCollator()
batch = data_collator.collate_batch(dataset.features)
data_collator = default_data_collator
batch = data_collator(dataset.features)
self.assertEqual(batch["labels"].dtype, torch.long)
def test_default_regression(self):
@@ -42,8 +42,8 @@ class DataCollatorIntegrationTest(unittest.TestCase):
task_name="sts-b", data_dir="./tests/fixtures/tests_samples/STS-B", overwrite_cache=True
)
dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
data_collator = DefaultDataCollator()
batch = data_collator.collate_batch(dataset.features)
data_collator = default_data_collator
batch = data_collator(dataset.features)
self.assertEqual(batch["labels"].dtype, torch.float)
def test_lm_tokenizer_without_padding(self):
@@ -55,11 +55,11 @@ class DataCollatorIntegrationTest(unittest.TestCase):
examples = [dataset[i] for i in range(len(dataset))]
with self.assertRaises(ValueError):
# Expect error due to padding token missing on gpt2:
data_collator.collate_batch(examples)
data_collator(examples)
dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True)
examples = [dataset[i] for i in range(len(dataset))]
batch = data_collator.collate_batch(examples)
batch = data_collator(examples)
self.assertIsInstance(batch, dict)
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))
@@ -71,14 +71,14 @@ class DataCollatorIntegrationTest(unittest.TestCase):
dataset = LineByLineTextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512)
examples = [dataset[i] for i in range(len(dataset))]
batch = data_collator.collate_batch(examples)
batch = data_collator(examples)
self.assertIsInstance(batch, dict)
self.assertEqual(batch["input_ids"].shape, torch.Size((31, 107)))
self.assertEqual(batch["labels"].shape, torch.Size((31, 107)))
dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True)
examples = [dataset[i] for i in range(len(dataset))]
batch = data_collator.collate_batch(examples)
batch = data_collator(examples)
self.assertIsInstance(batch, dict)
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))