Trainer support for iterabledataset (#5834)

* Don't pass sampler for iterable dataset

* Added check for test and eval dataloaders.

* Formatting

* Don't pass sampler for iterable dataset

* Added check for test and eval dataloaders.

* Formatting

* Cleaner if nesting.

* Added test for trainer and iterable dataset

* Formatting for test

* Fixed import when torch is available only.

* Added require torch decorator to helper class

* Moved dataset class inside unittest

* Removed nested if and changed model in test

* Checking torch availability for IterableDataset
This commit is contained in:
Pradhy729
2020-07-20 06:07:37 -07:00
committed by GitHub
parent 82dd96cae7
commit 290b6e18ac
2 changed files with 38 additions and 9 deletions

31
tests/test_trainer.py Normal file → Executable file
View File

@@ -6,16 +6,18 @@ from transformers.testing_utils import require_torch
if is_torch_available():
import torch
from torch.utils.data import IterableDataset
from transformers import (
Trainer,
LineByLineTextDataset,
AutoModelForSequenceClassification,
default_data_collator,
DataCollatorForLanguageModeling,
DataCollatorForPermutationLanguageModeling,
GlueDataset,
GlueDataTrainingArguments,
LineByLineTextDataset,
TextDataset,
Trainer,
default_data_collator,
)
@@ -153,6 +155,20 @@ class DataCollatorIntegrationTest(unittest.TestCase):
data_collator(example)
if is_torch_available():
class SampleIterableDataset(IterableDataset):
def __init__(self, file_path):
self.file_path = file_path
def parse_file(self):
f = open(self.file_path, "r")
return f.readlines()
def __iter__(self):
return iter(self.parse_file())
@require_torch
class TrainerIntegrationTest(unittest.TestCase):
def test_trainer_eval_mrpc(self):
@@ -176,3 +192,12 @@ class TrainerIntegrationTest(unittest.TestCase):
tokenizer=tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=tokenizer.max_len_single_sentence,
)
self.assertEqual(len(dataset), 31)
def test_trainer_iterable_dataset(self):
MODEL_ID = "sshleifer/tiny-distilbert-base-cased"
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
train_dataset = SampleIterableDataset(PATH_SAMPLE_TEXT)
training_args = TrainingArguments(output_dir="./examples", no_cuda=True)
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)
loader = trainer.get_train_dataloader()
self.assertIsInstance(loader, torch.utils.data.DataLoader)