Make Trainer evaluation handle dynamic seq_length (#8336)
* Make Trainer evaluation handle dynamic seq_length * Document behavior. * Fix test * Better fix * Fixes for realsies this time * Address review comments * Without forgetting to save...
This commit is contained in:
@@ -1333,6 +1333,12 @@ class Trainer:
|
|||||||
Dataset to run the predictions on. If it is an :obj:`datasets.Dataset`, columns not accepted by the
|
Dataset to run the predictions on. If it is an :obj:`datasets.Dataset`, columns not accepted by the
|
||||||
``model.forward()`` method are automatically removed. Has to implement the method :obj:`__len__`
|
``model.forward()`` method are automatically removed. Has to implement the method :obj:`__len__`
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
If your predictions or labels have different sequence length (for instance because you're doing dynamic
|
||||||
|
padding in a token classification task) the predictions will be padded (on the right) to allow for
|
||||||
|
concatenation into one array. The padding index is -100.
|
||||||
|
|
||||||
Returns: `NamedTuple` A namedtuple with the following keys:
|
Returns: `NamedTuple` A namedtuple with the following keys:
|
||||||
|
|
||||||
- predictions (:obj:`np.ndarray`): The predictions on :obj:`test_dataset`.
|
- predictions (:obj:`np.ndarray`): The predictions on :obj:`test_dataset`.
|
||||||
@@ -1412,9 +1418,9 @@ class Trainer:
|
|||||||
losses = loss.repeat(batch_size)
|
losses = loss.repeat(batch_size)
|
||||||
losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
|
losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
|
||||||
if logits is not None:
|
if logits is not None:
|
||||||
preds_host = logits if preds_host is None else nested_concat(preds_host, logits, dim=0)
|
preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
labels_host = labels if labels_host is None else nested_concat(labels_host, labels, dim=0)
|
labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
|
||||||
self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control)
|
self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control)
|
||||||
|
|
||||||
# Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
|
# Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
|
||||||
|
|||||||
@@ -42,17 +42,50 @@ else:
|
|||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def nested_concat(tensors, new_tensors, dim=0):
|
def torch_pad_and_concatenate(tensor1, tensor2, padding_index=-100):
|
||||||
"Concat the `new_tensors` to `tensors` on `dim`. Works for tensors or nested list/tuples of tensors."
|
"""Concatenates `tensor1` and `tensor2` on first axis, applying padding on the second if necessary."""
|
||||||
|
if len(tensor1.shape) == 1 or tensor1.shape[1] == tensor2.shape[1]:
|
||||||
|
return torch.cat((tensor1, tensor2), dim=0)
|
||||||
|
|
||||||
|
# Let's figure out the new shape
|
||||||
|
new_shape = (tensor1.shape[0] + tensor2.shape[0], max(tensor1.shape[1], tensor2.shape[1])) + tensor1.shape[2:]
|
||||||
|
|
||||||
|
# Now let's fill the result tensor
|
||||||
|
result = tensor1.new_full(new_shape, padding_index)
|
||||||
|
result[: tensor1.shape[0], : tensor1.shape[1]] = tensor1
|
||||||
|
result[tensor1.shape[0] :, : tensor2.shape[1]] = tensor2
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def numpy_pad_and_concatenate(array1, array2, padding_index=-100):
    """
    Concatenates `array1` and `array2` on first axis, applying padding on the second if necessary.

    Args:
        array1 (:obj:`np.ndarray`): The array that fills the leading rows of the result.
        array2 (:obj:`np.ndarray`): The array appended after `array1`.
        padding_index (:obj:`int`, `optional`, defaults to -100):
            The value written where the narrower of the two arrays has no data on the second axis.

    Returns:
        :obj:`np.ndarray`: An array with `array1.shape[0] + array2.shape[0]` rows. When both inputs are 2-D or more
        and their second dimensions differ, the second dimension is the larger of the two.
    """
    if len(array1.shape) == 1 or array1.shape[1] == array2.shape[1]:
        # NumPy spells this keyword `axis`; `dim` is the torch name and raises a TypeError here.
        return np.concatenate((array1, array2), axis=0)

    # Let's figure out the new shape
    new_shape = (array1.shape[0] + array2.shape[0], max(array1.shape[1], array2.shape[1])) + array1.shape[2:]

    # Now let's fill the result array: start from padding everywhere, then copy each input into its rows.
    result = np.full_like(array1, padding_index, shape=new_shape)
    result[: array1.shape[0], : array1.shape[1]] = array1
    result[array1.shape[0] :, : array2.shape[1]] = array2
    return result
|
||||||
|
|
||||||
|
|
||||||
|
def nested_concat(tensors, new_tensors, padding_index=-100):
|
||||||
|
"""
|
||||||
|
Concat the `new_tensors` to `tensors` on the first dim and pad them on the second if needed. Works for tensors or
|
||||||
|
nested list/tuples of tensors.
|
||||||
|
"""
|
||||||
assert type(tensors) == type(
|
assert type(tensors) == type(
|
||||||
new_tensors
|
new_tensors
|
||||||
), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."
|
), f"Expected `tensors` and `new_tensors` to have the same type but found {type(tensors)} and {type(new_tensors)}."
|
||||||
if isinstance(tensors, (list, tuple)):
|
if isinstance(tensors, (list, tuple)):
|
||||||
return type(tensors)(nested_concat(t, n, dim) for t, n in zip(tensors, new_tensors))
|
return type(tensors)(nested_concat(t, n, padding_index=padding_index) for t, n in zip(tensors, new_tensors))
|
||||||
elif isinstance(tensors, torch.Tensor):
|
elif isinstance(tensors, torch.Tensor):
|
||||||
return torch.cat((tensors, new_tensors), dim=dim)
|
return torch_pad_and_concatenate(tensors, new_tensors, padding_index=padding_index)
|
||||||
elif isinstance(tensors, np.ndarray):
|
elif isinstance(tensors, np.ndarray):
|
||||||
return np.concatenate((tensors, new_tensors), axis=dim)
|
return numpy_pad_and_concatenate(tensors, new_tensors, padding_index=padding_index)
|
||||||
else:
|
else:
|
||||||
raise TypeError(f"Unsupported type for concatenation: got {type(tensors)}")
|
raise TypeError(f"Unsupported type for concatenation: got {type(tensors)}")
|
||||||
|
|
||||||
@@ -190,11 +223,21 @@ def get_tpu_sampler(dataset: torch.utils.data.dataset.Dataset):
|
|||||||
return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
|
return DistributedSampler(dataset, num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
|
||||||
|
|
||||||
|
|
||||||
def nested_new_like(arrays, num_samples, padding_index=-100):
    """
    Create the same nested structure as `arrays` with a first dimension always at `num_samples`.

    Args:
        arrays: A NumPy array, or a (possibly nested) list/tuple of NumPy arrays, used as a template.
        num_samples (:obj:`int`): The size of the first dimension of every array that is created.
        padding_index (:obj:`int`, `optional`, defaults to -100): The value every new array is filled with.

    Returns:
        The same list/tuple nesting as `arrays`, where each array keeps the dtype and trailing dimensions of its
        template and contains only `padding_index`.
    """
    if isinstance(arrays, (list, tuple)):
        # Forward `padding_index` explicitly, otherwise nested arrays silently fall back to the default of -100.
        return type(arrays)(nested_new_like(x, num_samples, padding_index=padding_index) for x in arrays)
    return np.full_like(arrays, padding_index, shape=(num_samples, *arrays.shape[1:]))
|
||||||
|
|
||||||
|
|
||||||
|
def nested_expand_like(arrays, new_seq_length, padding_index=-100):
    """
    Expand the `arrays` so that the second dimension grows to `new_seq_length`. Uses `padding_index` for padding.

    Args:
        arrays: A NumPy array, or a (possibly nested) list/tuple of NumPy arrays.
        new_seq_length (:obj:`int`): The target size of the second dimension.
        padding_index (:obj:`int`, `optional`, defaults to -100): The value written in the newly added positions.

    Returns:
        The same list/tuple nesting as `arrays`. Arrays with at least two dimensions and a second dimension no larger
        than `new_seq_length` are copied into a new, padded array; the others are returned unchanged.
    """
    if isinstance(arrays, (list, tuple)):
        return type(arrays)(nested_expand_like(x, new_seq_length, padding_index=padding_index) for x in arrays)

    # A nested structure can mix arrays of different ranks and widths. An array with no second dimension, or
    # one that is already wider than the target, has nothing to grow, so leave it alone instead of failing.
    if len(arrays.shape) < 2 or arrays.shape[1] > new_seq_length:
        return arrays

    result = np.full_like(arrays, padding_index, shape=(arrays.shape[0], new_seq_length) + arrays.shape[2:])
    result[:, : arrays.shape[1]] = arrays
    return result
|
||||||
|
|
||||||
|
|
||||||
def nested_truncate(tensors, limit):
|
def nested_truncate(tensors, limit):
|
||||||
@@ -204,6 +247,13 @@ def nested_truncate(tensors, limit):
|
|||||||
return tensors[:limit]
|
return tensors[:limit]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_first_shape(arrays):
|
||||||
|
"""Return the shape of the first array found in the nested struct `arrays`."""
|
||||||
|
if isinstance(arrays, (list, tuple)):
|
||||||
|
return _get_first_shape(arrays[0])
|
||||||
|
return arrays.shape
|
||||||
|
|
||||||
|
|
||||||
class DistributedTensorGatherer:
|
class DistributedTensorGatherer:
|
||||||
"""
|
"""
|
||||||
A class responsible for properly gathering tensors (or nested list/tuple of tensors) on the CPU by chunks.
|
A class responsible for properly gathering tensors (or nested list/tuple of tensors) on the CPU by chunks.
|
||||||
@@ -247,9 +297,11 @@ class DistributedTensorGatherer:
|
|||||||
make_multiple_of (:obj:`int`, `optional`):
|
make_multiple_of (:obj:`int`, `optional`):
|
||||||
If passed, the class assumes the datasets passed to each process are made to be a multiple of this argument
|
If passed, the class assumes the datasets passed to each process are made to be a multiple of this argument
|
||||||
(by adding samples).
|
(by adding samples).
|
||||||
|
padding_index (:obj:`int`, `optional`, defaults to -100):
|
||||||
|
The padding index to use if the arrays don't all have the same sequence length.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, world_size, num_samples, make_multiple_of=None):
|
def __init__(self, world_size, num_samples, make_multiple_of=None, padding_index=-100):
|
||||||
self.world_size = world_size
|
self.world_size = world_size
|
||||||
self.num_samples = num_samples
|
self.num_samples = num_samples
|
||||||
total_size = world_size if make_multiple_of is None else world_size * make_multiple_of
|
total_size = world_size if make_multiple_of is None else world_size * make_multiple_of
|
||||||
@@ -257,6 +309,7 @@ class DistributedTensorGatherer:
|
|||||||
self.process_length = self.total_samples // world_size
|
self.process_length = self.total_samples // world_size
|
||||||
self._storage = None
|
self._storage = None
|
||||||
self._offsets = None
|
self._offsets = None
|
||||||
|
self.padding_index = padding_index
|
||||||
|
|
||||||
def add_arrays(self, arrays):
|
def add_arrays(self, arrays):
|
||||||
"""
|
"""
|
||||||
@@ -266,8 +319,14 @@ class DistributedTensorGatherer:
|
|||||||
if arrays is None:
|
if arrays is None:
|
||||||
return
|
return
|
||||||
if self._storage is None:
|
if self._storage is None:
|
||||||
self._storage = nested_new_like(arrays, self.total_samples)
|
self._storage = nested_new_like(arrays, self.total_samples, padding_index=self.padding_index)
|
||||||
self._offsets = list(range(0, self.total_samples, self.process_length))
|
self._offsets = list(range(0, self.total_samples, self.process_length))
|
||||||
|
else:
|
||||||
|
storage_shape = _get_first_shape(self._storage)
|
||||||
|
arrays_shape = _get_first_shape(arrays)
|
||||||
|
if len(storage_shape) > 1 and storage_shape[1] < arrays_shape[1]:
|
||||||
|
# If we get new arrays that are too big to fit, we expand the shape of the storage
|
||||||
|
self._storage = nested_expand_like(self._storage, arrays_shape[1], padding_index=self.padding_index)
|
||||||
slice_len = self._nested_set_tensors(self._storage, arrays)
|
slice_len = self._nested_set_tensors(self._storage, arrays)
|
||||||
for i in range(self.world_size):
|
for i in range(self.world_size):
|
||||||
self._offsets[i] += slice_len
|
self._offsets[i] += slice_len
|
||||||
@@ -283,7 +342,12 @@ class DistributedTensorGatherer:
|
|||||||
|
|
||||||
slice_len = arrays.shape[0] // self.world_size
|
slice_len = arrays.shape[0] // self.world_size
|
||||||
for i in range(self.world_size):
|
for i in range(self.world_size):
|
||||||
storage[self._offsets[i] : self._offsets[i] + slice_len] = arrays[i * slice_len : (i + 1) * slice_len]
|
if len(arrays.shape) == 1:
|
||||||
|
storage[self._offsets[i] : self._offsets[i] + slice_len] = arrays[i * slice_len : (i + 1) * slice_len]
|
||||||
|
else:
|
||||||
|
storage[self._offsets[i] : self._offsets[i] + slice_len, : arrays.shape[1]] = arrays[
|
||||||
|
i * slice_len : (i + 1) * slice_len
|
||||||
|
]
|
||||||
return slice_len
|
return slice_len
|
||||||
|
|
||||||
def finalize(self):
|
def finalize(self):
|
||||||
|
|||||||
@@ -73,6 +73,22 @@ class RegressionDataset:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
class DynamicShapesDataset:
    """
    Dataset of random 1-D samples whose length changes from one batch to the next.

    Args:
        length (:obj:`int`, `optional`, defaults to 64): The number of samples in the dataset.
        seed (:obj:`int`, `optional`, defaults to 42): The seed passed to :obj:`np.random.seed` before sampling.
        batch_size (:obj:`int`, `optional`, defaults to 8):
            The number of consecutive samples that share the same size.
    """

    def __init__(self, length=64, seed=42, batch_size=8):
        self.length = length
        np.random.seed(seed)
        # Round up so a trailing partial batch still gets a size; with floor division the dataset would hold
        # fewer than `length` samples whenever `length` is not a multiple of `batch_size`.
        num_batches = -(-length // batch_size)
        sizes = np.random.randint(1, 20, (num_batches,))
        # For easy batching, we make every batch_size consecutive samples the same size.
        sample_sizes = sizes.repeat(batch_size)[:length]
        self.xs = [np.random.normal(size=(s,)) for s in sample_sizes]
        self.ys = [np.random.normal(size=(s,)) for s in sample_sizes]

    def __len__(self):
        return self.length

    def __getitem__(self, i):
        return {"input_x": self.xs[i], "labels": self.ys[i]}
|
||||||
|
|
||||||
|
|
||||||
class AlmostAccuracy:
|
class AlmostAccuracy:
|
||||||
def __init__(self, thresh=0.25):
|
def __init__(self, thresh=0.25):
|
||||||
self.thresh = thresh
|
self.thresh = thresh
|
||||||
@@ -282,7 +298,7 @@ class TrainerIntegrationTest(unittest.TestCase):
|
|||||||
self.assertEqual(len(trainer.get_train_dataloader()), 66 // (16 * n_gpu))
|
self.assertEqual(len(trainer.get_train_dataloader()), 66 // (16 * n_gpu))
|
||||||
self.assertEqual(len(trainer.get_eval_dataloader()), 74 // (32 * n_gpu))
|
self.assertEqual(len(trainer.get_eval_dataloader()), 74 // (32 * n_gpu))
|
||||||
|
|
||||||
# Check passing a new dataset for evaluation wors
|
# Check passing a new dataset for evaluation works
|
||||||
new_eval_dataset = RegressionDataset(length=128)
|
new_eval_dataset = RegressionDataset(length=128)
|
||||||
self.assertEqual(len(trainer.get_eval_dataloader(new_eval_dataset)), 128 // (32 * n_gpu))
|
self.assertEqual(len(trainer.get_eval_dataloader(new_eval_dataset)), 128 // (32 * n_gpu))
|
||||||
|
|
||||||
@@ -340,6 +356,42 @@ class TrainerIntegrationTest(unittest.TestCase):
|
|||||||
self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
|
self.assertTrue(np.array_equal(labels[0], trainer.eval_dataset.ys[0]))
|
||||||
self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))
|
self.assertTrue(np.array_equal(labels[1], trainer.eval_dataset.ys[1]))
|
||||||
|
|
||||||
|
def test_dynamic_shapes(self):
    """Check that evaluation and prediction handle batches of different lengths, padding the outputs with -100."""
    eval_dataset = DynamicShapesDataset(batch_size=self.batch_size)
    model = RegressionModel(a=2, b=1)

    # The first pass uses the default arguments; the second repeats the same checks with eval accumulation.
    for extra_args in ({}, {"eval_accumulation_steps": 2}):
        args = TrainingArguments("./regression", **extra_args)
        trainer = Trainer(model, args, eval_dataset=eval_dataset)

        # Check evaluation can run to completion
        _ = trainer.evaluate()

        # Check predictions: labels come back as stored, predictions follow the model's 2 * x + 1.
        preds = trainer.predict(eval_dataset)
        checks = (
            (eval_dataset.ys, preds.label_ids),
            ([2 * x + 1 for x in eval_dataset.xs], preds.predictions),
        )
        for targets, outputs in checks:
            for expected, seen in zip(targets, outputs):
                real_length = expected.shape[0]
                self.assertTrue(np.array_equal(expected, seen[:real_length]))
                # Everything past the real length of the sample must be padding.
                self.assertTrue(np.all(seen[real_length:] == -100))
|
||||||
|
|
||||||
@require_datasets
|
@require_datasets
|
||||||
def test_trainer_with_datasets(self):
|
def test_trainer_with_datasets(self):
|
||||||
import datasets
|
import datasets
|
||||||
|
|||||||
Reference in New Issue
Block a user