Make DataCollator a callable (#5015)
* Make DataCollator a callable * Update src/transformers/data/data_collator.py Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
@@ -29,7 +29,7 @@ if is_torch_available():
|
||||
from torch import nn
|
||||
from torch.utils.data.dataset import Dataset
|
||||
|
||||
from transformers import DataCollator, Trainer
|
||||
from transformers import Trainer
|
||||
|
||||
class DummyDataset(Dataset):
|
||||
def __init__(self, length: int = 101):
|
||||
@@ -41,8 +41,8 @@ if is_torch_available():
|
||||
def __getitem__(self, i) -> int:
|
||||
return i
|
||||
|
||||
class DummyDataCollator(DataCollator):
|
||||
def collate_batch(self, features):
|
||||
class DummyDataCollator:
|
||||
def __call__(self, features):
|
||||
return {"input_ids": torch.tensor(features), "labels": torch.tensor(features)}
|
||||
|
||||
class DummyModel(nn.Module):
|
||||
|
||||
Reference in New Issue
Block a user