Upstream (and rename) sortish sampler (#9574)
* Upstream (and rename) sortish sampler
* Use proper sampler
* Update src/transformers/trainer_pt_utils.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
@@ -25,7 +25,12 @@ if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers.modeling_outputs import SequenceClassifierOutput
|
||||
from transformers.trainer_pt_utils import DistributedTensorGatherer, LabelSmoother
|
||||
from transformers.trainer_pt_utils import (
|
||||
DistributedLengthGroupedSampler,
|
||||
DistributedTensorGatherer,
|
||||
LabelSmoother,
|
||||
LengthGroupedSampler,
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
@@ -87,3 +92,28 @@ class TrainerUtilsTest(unittest.TestCase):
|
||||
log_probs[2, 3] = 0.0
|
||||
expected_loss = (1 - epsilon) * loss + epsilon * log_probs.sum() / (num_labels * 17)
|
||||
self.assertTrue(torch.allclose(label_smoothed_loss, expected_loss))
|
||||
|
||||
def test_group_by_length(self):
    """Check that LengthGroupedSampler emits the longest sample first and every index exactly once."""
    # Draw 100 random lengths in [0, 25), then plant a single outlier (50) at
    # position 32 so that the longest sample is known in advance.
    sample_lengths = torch.randint(0, 25, (100,)).tolist()
    sample_lengths[32] = 50

    # The list of lengths is passed both as the first positional argument and
    # as the explicit `lengths` keyword.
    sampler = LengthGroupedSampler(sample_lengths, 4, lengths=sample_lengths)
    order = list(sampler)

    # The first index produced must point at the planted outlier.
    self.assertEqual(sample_lengths[order[0]], 50)
    # Sorting the produced indices must give back 0..99: a full permutation.
    self.assertEqual(sorted(order), list(range(100)))
|
||||
|
||||
def test_distributed_length_grouped(self):
    """Check that two DistributedLengthGroupedSampler instances jointly cover every index once,
    with the longest sample first on the instance built with index 0."""
    # Draw 100 random lengths in [0, 25), then plant a single outlier (50) at
    # position 32 so that the longest sample is known in advance.
    sample_lengths = torch.randint(0, 25, (100,)).tolist()
    sample_lengths[32] = 50

    # Build and fully iterate one sampler per process index (0, then 1); only
    # the fourth positional argument differs between the two.
    shard_0, shard_1 = [
        list(DistributedLengthGroupedSampler(sample_lengths, 4, 2, process_index, lengths=sample_lengths))
        for process_index in (0, 1)
    ]

    # The first index seen by process 0 must point at the planted outlier.
    self.assertEqual(sample_lengths[shard_0[0]], 50)
    # Taken together, both shards must form a permutation of 0..99.
    self.assertEqual(sorted(shard_0 + shard_1), list(range(100)))
|
||||
|
||||
Reference in New Issue
Block a user