Trainer support for IterableDataset for evaluation and predict (#11286)
* Bulk of the work * Polish and tests * Update QA Trainer * Avoid breaking the predict method * Deprecation warnings * Store real eval dataloder * Get eval dataset reference before wrap
This commit is contained in:
@@ -819,35 +819,64 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
)
|
||||
self.assertEqual(len(dataset), 31)
|
||||
|
||||
def test_trainer_iterable_dataset(self):
|
||||
def test_training_iterable_dataset(self):
|
||||
config = RegressionModelConfig()
|
||||
model = RegressionPreTrainedModel(config)
|
||||
train_dataset = SampleIterableDataset()
|
||||
|
||||
args = RegressionTrainingArguments(output_dir="./examples", max_steps=2)
|
||||
args = RegressionTrainingArguments(output_dir="./examples", max_steps=4)
|
||||
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
|
||||
trainer.train()
|
||||
self.assertEqual(trainer.state.global_step, 4)
|
||||
|
||||
loader = trainer.get_train_dataloader()
|
||||
self.assertIsInstance(loader, torch.utils.data.DataLoader)
|
||||
self.assertIsInstance(loader.sampler, torch.utils.data.dataloader._InfiniteConstantSampler)
|
||||
|
||||
# Exception if giving iterable dataset and no max_steps
|
||||
with self.assertRaises(ValueError):
|
||||
args1 = RegressionTrainingArguments(output_dir="./examples")
|
||||
_ = Trainer(model=model, args=args1, train_dataset=train_dataset)
|
||||
def test_evaluation_iterable_dataset(self):
|
||||
config = RegressionModelConfig(a=1.5, b=2.5)
|
||||
model = RegressionPreTrainedModel(config)
|
||||
eval_dataset = SampleIterableDataset()
|
||||
|
||||
# Exception if eval_dataset is iterable in __init__
|
||||
with self.assertRaises(ValueError):
|
||||
_ = Trainer(model=model, args=args, train_dataset=train_dataset, eval_dataset=train_dataset)
|
||||
args = RegressionTrainingArguments(output_dir="./examples")
|
||||
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset, compute_metrics=AlmostAccuracy())
|
||||
results = trainer.evaluate()
|
||||
|
||||
# Exception if predicting with iterable dataset
|
||||
with self.assertRaises(ValueError):
|
||||
trainer.predict(train_dataset)
|
||||
x, y = trainer.eval_dataset.dataset.x, trainer.eval_dataset.dataset.ys[0]
|
||||
pred = 1.5 * x + 2.5
|
||||
expected_loss = ((pred - y) ** 2).mean()
|
||||
self.assertAlmostEqual(results["eval_loss"], expected_loss)
|
||||
expected_acc = AlmostAccuracy()((pred, y))["accuracy"]
|
||||
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
|
||||
|
||||
# Exception if evaluating with iterable dataset
|
||||
with self.assertRaises(ValueError):
|
||||
trainer.evaluate(train_dataset)
|
||||
# With a number of elements not a round multiple of the batch size
|
||||
eval_dataset = SampleIterableDataset(length=66)
|
||||
results = trainer.evaluate(eval_dataset)
|
||||
|
||||
x, y = eval_dataset.dataset.x, eval_dataset.dataset.ys[0]
|
||||
pred = 1.5 * x + 2.5
|
||||
expected_loss = ((pred - y) ** 2).mean()
|
||||
self.assertAlmostEqual(results["eval_loss"], expected_loss)
|
||||
expected_acc = AlmostAccuracy()((pred, y))["accuracy"]
|
||||
self.assertAlmostEqual(results["eval_accuracy"], expected_acc)
|
||||
|
||||
def test_predict_iterable_dataset(self):
|
||||
config = RegressionModelConfig(a=1.5, b=2.5)
|
||||
model = RegressionPreTrainedModel(config)
|
||||
eval_dataset = SampleIterableDataset()
|
||||
|
||||
args = RegressionTrainingArguments(output_dir="./examples")
|
||||
trainer = Trainer(model=model, args=args, eval_dataset=eval_dataset, compute_metrics=AlmostAccuracy())
|
||||
|
||||
preds = trainer.predict(trainer.eval_dataset).predictions
|
||||
x = eval_dataset.dataset.x
|
||||
self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
|
||||
|
||||
# With a number of elements not a round multiple of the batch size
|
||||
test_dataset = SampleIterableDataset(length=66)
|
||||
preds = trainer.predict(test_dataset).predictions
|
||||
x = test_dataset.dataset.x
|
||||
self.assertTrue(np.allclose(preds, 1.5 * x + 2.5))
|
||||
|
||||
def test_num_train_epochs_in_training(self):
|
||||
# len(train_dl) < gradient_accumulation_steps shouldn't give ``ZeroDivisionError`` when ``max_steps`` is given.
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
@@ -34,6 +35,7 @@ if is_torch_available():
|
||||
LabelSmoother,
|
||||
LengthGroupedSampler,
|
||||
SequentialDistributedSampler,
|
||||
ShardSampler,
|
||||
get_parameter_names,
|
||||
)
|
||||
|
||||
@@ -283,6 +285,10 @@ class TrainerUtilsTest(unittest.TestCase):
|
||||
# All shards have the same number of samples
|
||||
self.assertEqual(len(shard), len(shard_lists[0]))
|
||||
|
||||
for shard in shards:
|
||||
# All shards know the total number of samples
|
||||
self.assertEqual(shard.num_examples, len(reference))
|
||||
|
||||
observed = []
|
||||
for idx in range(0, len(shard_lists[0]), batch_size):
|
||||
for shard in shard_lists:
|
||||
@@ -295,11 +301,62 @@ class TrainerUtilsTest(unittest.TestCase):
|
||||
reference += reference
|
||||
self.assertListEqual(observed, reference[: len(observed)])
|
||||
|
||||
# Check equivalence between IterableDataset and ShardSampler
|
||||
dataset.generator.manual_seed(epoch)
|
||||
reference = list(dataset)
|
||||
|
||||
sampler_shards = [
|
||||
ShardSampler(
|
||||
reference, batch_size=batch_size, drop_last=drop_last, num_processes=num_processes, process_index=i
|
||||
)
|
||||
for i in range(num_processes)
|
||||
]
|
||||
for shard, sampler_shard in zip(shard_lists, sampler_shards):
|
||||
self.assertListEqual(shard, list(sampler_shard))
|
||||
|
||||
def test_iterable_dataset_shard(self):
|
||||
dataset = RandomIterableDataset()
|
||||
|
||||
self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=2, epoch=0)
|
||||
self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=2, epoch=0)
|
||||
self.check_iterable_dataset_shard(dataset, 4, drop_last=False, num_processes=2, epoch=0)
|
||||
|
||||
self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=3, epoch=42)
|
||||
self.check_iterable_dataset_shard(dataset, 4, drop_last=True, num_processes=3, epoch=42)
|
||||
self.check_iterable_dataset_shard(dataset, 4, drop_last=False, num_processes=3, epoch=42)
|
||||
|
||||
def check_shard_sampler(self, dataset, batch_size, drop_last, num_processes=2):
|
||||
shards = [
|
||||
ShardSampler(
|
||||
dataset, batch_size=batch_size, drop_last=drop_last, num_processes=num_processes, process_index=i
|
||||
)
|
||||
for i in range(num_processes)
|
||||
]
|
||||
shard_lists = [list(shard) for shard in shards]
|
||||
|
||||
for shard in shard_lists:
|
||||
# All shards have a number of samples that is a round multiple of batch size
|
||||
self.assertTrue(len(shard) % batch_size == 0)
|
||||
# All shards have the same number of samples
|
||||
self.assertEqual(len(shard), len(shard_lists[0]))
|
||||
|
||||
observed = []
|
||||
for idx in range(0, len(shard_lists[0]), batch_size):
|
||||
for shard in shard_lists:
|
||||
observed += shard[idx : idx + batch_size]
|
||||
|
||||
# If drop_last is False we loop through samples at the beginning to have a size that is a round multiple of
|
||||
# batch_size
|
||||
reference = copy.copy(dataset)
|
||||
if not drop_last:
|
||||
while len(reference) < len(observed):
|
||||
reference += reference
|
||||
self.assertListEqual(observed, reference[: len(observed)])
|
||||
|
||||
def test_shard_sampler(self):
|
||||
for n_elements in [64, 123]:
|
||||
dataset = list(range(n_elements))
|
||||
|
||||
self.check_shard_sampler(dataset, 4, drop_last=True, num_processes=2)
|
||||
self.check_shard_sampler(dataset, 4, drop_last=False, num_processes=2)
|
||||
|
||||
self.check_shard_sampler(dataset, 4, drop_last=True, num_processes=3)
|
||||
self.check_shard_sampler(dataset, 4, drop_last=False, num_processes=3)
|
||||
|
||||
Reference in New Issue
Block a user