Make DataCollator a callable (#5015)
* Make DataCollator a callable * Update src/transformers/data/data_collator.py Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
@@ -11,7 +11,7 @@ if is_torch_available():
|
||||
Trainer,
|
||||
LineByLineTextDataset,
|
||||
AutoModelForSequenceClassification,
|
||||
DefaultDataCollator,
|
||||
default_data_collator,
|
||||
DataCollatorForLanguageModeling,
|
||||
GlueDataset,
|
||||
GlueDataTrainingArguments,
|
||||
@@ -31,8 +31,8 @@ class DataCollatorIntegrationTest(unittest.TestCase):
|
||||
task_name="mrpc", data_dir="./tests/fixtures/tests_samples/MRPC", overwrite_cache=True
|
||||
)
|
||||
dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
|
||||
data_collator = DefaultDataCollator()
|
||||
batch = data_collator.collate_batch(dataset.features)
|
||||
data_collator = default_data_collator
|
||||
batch = data_collator(dataset.features)
|
||||
self.assertEqual(batch["labels"].dtype, torch.long)
|
||||
|
||||
def test_default_regression(self):
|
||||
@@ -42,8 +42,8 @@ class DataCollatorIntegrationTest(unittest.TestCase):
|
||||
task_name="sts-b", data_dir="./tests/fixtures/tests_samples/STS-B", overwrite_cache=True
|
||||
)
|
||||
dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev")
|
||||
data_collator = DefaultDataCollator()
|
||||
batch = data_collator.collate_batch(dataset.features)
|
||||
data_collator = default_data_collator
|
||||
batch = data_collator(dataset.features)
|
||||
self.assertEqual(batch["labels"].dtype, torch.float)
|
||||
|
||||
def test_lm_tokenizer_without_padding(self):
|
||||
@@ -55,11 +55,11 @@ class DataCollatorIntegrationTest(unittest.TestCase):
|
||||
examples = [dataset[i] for i in range(len(dataset))]
|
||||
with self.assertRaises(ValueError):
|
||||
# Expect error due to padding token missing on gpt2:
|
||||
data_collator.collate_batch(examples)
|
||||
data_collator(examples)
|
||||
|
||||
dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True)
|
||||
examples = [dataset[i] for i in range(len(dataset))]
|
||||
batch = data_collator.collate_batch(examples)
|
||||
batch = data_collator(examples)
|
||||
self.assertIsInstance(batch, dict)
|
||||
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
|
||||
self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))
|
||||
@@ -71,14 +71,14 @@ class DataCollatorIntegrationTest(unittest.TestCase):
|
||||
|
||||
dataset = LineByLineTextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512)
|
||||
examples = [dataset[i] for i in range(len(dataset))]
|
||||
batch = data_collator.collate_batch(examples)
|
||||
batch = data_collator(examples)
|
||||
self.assertIsInstance(batch, dict)
|
||||
self.assertEqual(batch["input_ids"].shape, torch.Size((31, 107)))
|
||||
self.assertEqual(batch["labels"].shape, torch.Size((31, 107)))
|
||||
|
||||
dataset = TextDataset(tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=512, overwrite_cache=True)
|
||||
examples = [dataset[i] for i in range(len(dataset))]
|
||||
batch = data_collator.collate_batch(examples)
|
||||
batch = data_collator(examples)
|
||||
self.assertIsInstance(batch, dict)
|
||||
self.assertEqual(batch["input_ids"].shape, torch.Size((2, 512)))
|
||||
self.assertEqual(batch["labels"].shape, torch.Size((2, 512)))
|
||||
|
||||
@@ -29,7 +29,7 @@ if is_torch_available():
|
||||
from torch import nn
|
||||
from torch.utils.data.dataset import Dataset
|
||||
|
||||
from transformers import DataCollator, Trainer
|
||||
from transformers import Trainer
|
||||
|
||||
class DummyDataset(Dataset):
|
||||
def __init__(self, length: int = 101):
|
||||
@@ -41,8 +41,8 @@ if is_torch_available():
|
||||
def __getitem__(self, i) -> int:
|
||||
return i
|
||||
|
||||
class DummyDataCollator(DataCollator):
|
||||
def collate_batch(self, features):
|
||||
class DummyDataCollator:
|
||||
def __call__(self, features):
|
||||
return {"input_ids": torch.tensor(features), "labels": torch.tensor(features)}
|
||||
|
||||
class DummyModel(nn.Module):
|
||||
|
||||
Reference in New Issue
Block a user