Trainer support for IterableDataset (#5834)
* Don't pass sampler for iterable dataset
* Added check for test and eval dataloaders.
* Formatting
* Don't pass sampler for iterable dataset
* Added check for test and eval dataloaders.
* Formatting
* Cleaner if nesting.
* Added test for trainer and iterable dataset
* Formatting for test
* Fixed import when torch is available only.
* Added require torch decorator to helper class
* Moved dataset class inside unittest
* Removed nested if and changed model in test
* Checking torch availability for IterableDataset
This commit is contained in:
31
tests/test_trainer.py
Normal file → Executable file
31
tests/test_trainer.py
Normal file → Executable file
@@ -6,16 +6,18 @@ from transformers.testing_utils import require_torch
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from torch.utils.data import IterableDataset
|
||||
|
||||
from transformers import (
|
||||
Trainer,
|
||||
LineByLineTextDataset,
|
||||
AutoModelForSequenceClassification,
|
||||
default_data_collator,
|
||||
DataCollatorForLanguageModeling,
|
||||
DataCollatorForPermutationLanguageModeling,
|
||||
GlueDataset,
|
||||
GlueDataTrainingArguments,
|
||||
LineByLineTextDataset,
|
||||
TextDataset,
|
||||
Trainer,
|
||||
default_data_collator,
|
||||
)
|
||||
|
||||
|
||||
@@ -153,6 +155,20 @@ class DataCollatorIntegrationTest(unittest.TestCase):
|
||||
data_collator(example)
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
|
||||
class SampleIterableDataset(IterableDataset):
    """Iterable dataset that yields the lines of a text file.

    The file is re-read from disk every time the dataset is iterated.
    """

    def __init__(self, file_path):
        """Remember the path of the text file to read.

        Args:
            file_path: Path of the file whose lines are yielded.
        """
        self.file_path = file_path

    def parse_file(self):
        """Return every line of the file as a list of strings.

        Line terminators are kept, exactly as ``readlines`` returns them.
        """
        # The context manager closes the handle as soon as the lines are
        # read; a bare open() would leak one file descriptor per iteration.
        with open(self.file_path, "r") as f:
            return f.readlines()

    def __iter__(self):
        """Return an iterator over the lines of the file."""
        return iter(self.parse_file())
|
||||
|
||||
|
||||
@require_torch
|
||||
class TrainerIntegrationTest(unittest.TestCase):
|
||||
def test_trainer_eval_mrpc(self):
|
||||
@@ -176,3 +192,12 @@ class TrainerIntegrationTest(unittest.TestCase):
|
||||
tokenizer=tokenizer, file_path=PATH_SAMPLE_TEXT, block_size=tokenizer.max_len_single_sentence,
|
||||
)
|
||||
self.assertEqual(len(dataset), 31)
|
||||
|
||||
def test_trainer_iterable_dataset(self):
    # Check that Trainer hands back a torch DataLoader when the training
    # set is an IterableDataset instead of a map-style dataset.
    streaming_dataset = SampleIterableDataset(PATH_SAMPLE_TEXT)
    args = TrainingArguments(output_dir="./examples", no_cuda=True)
    tiny_model = AutoModelForSequenceClassification.from_pretrained("sshleifer/tiny-distilbert-base-cased")
    trainer = Trainer(model=tiny_model, args=args, train_dataset=streaming_dataset)
    # Only the type of the returned loader is verified; no batch is drawn.
    self.assertIsInstance(trainer.get_train_dataloader(), torch.utils.data.DataLoader)
|
||||
|
||||
Reference in New Issue
Block a user