Make DataCollator a callable (#5015)

* Make DataCollator a callable

* Update src/transformers/data/data_collator.py

Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
Sylvain Gugger
2020-06-15 11:58:33 -04:00
committed by GitHub
parent f7c93b3cee
commit 1affde2f10
7 changed files with 60 additions and 83 deletions

View File

@@ -29,7 +29,7 @@ if is_torch_available():
from torch import nn
from torch.utils.data.dataset import Dataset
from transformers import DataCollator, Trainer
from transformers import Trainer
class DummyDataset(Dataset):
def __init__(self, length: int = 101):
@@ -41,8 +41,8 @@ if is_torch_available():
def __getitem__(self, i) -> int:
return i
class DummyDataCollator(DataCollator):
def collate_batch(self, features):
class DummyDataCollator:
def __call__(self, features):
return {"input_ids": torch.tensor(features), "labels": torch.tensor(features)}
class DummyModel(nn.Module):