From e238e3d55ab5df326c40b240d3d6c1f6bec7d641 Mon Sep 17 00:00:00 2001
From: Sam Shleifer
Date: Fri, 17 Jul 2020 01:53:25 -0400
Subject: [PATCH] [seq2seq] Don't copy self.source in sortishsampler (#5818)

---
 examples/seq2seq/utils.py | 13 +++----------
 1 file changed, 3 insertions(+), 10 deletions(-)

diff --git a/examples/seq2seq/utils.py b/examples/seq2seq/utils.py
index 0938f0f373..2c1f9aebf5 100644
--- a/examples/seq2seq/utils.py
+++ b/examples/seq2seq/utils.py
@@ -144,16 +144,9 @@ class SummarizationDataset(Dataset):
         batch = {"input_ids": source_ids, "attention_mask": source_mask, "decoder_input_ids": y}
         return batch
 
-    @property
-    def src_lens(self):  # Can delete?
-        return lmap(len, self.source)
-
-    @property
-    def tgt_lens(self):
-        return lmap(len, self.target)
-
     def make_sortish_sampler(self, batch_size):
-        return SortishSampler(self.source, batch_size)
+        lens = [x["input_ids"].ne(self.pad_token_id).sum() for x in self.source]
+        return SortishSampler(lens, batch_size)
 
 
 class SortishSampler(Sampler):
@@ -163,7 +156,7 @@ class SortishSampler(Sampler):
         self.data, self.bs = data, batch_size
 
     def key(self, i):
-        return len(self.data[i])
+        return self.data[i]
 
     def __len__(self) -> int:
         return len(self.data)