[s2s] Delete useless method, log tokens_per_batch (#6081)

This commit is contained in:
Sam Shleifer
2020-07-28 11:24:23 -04:00
committed by GitHub
parent dc4755c6d5
commit dafa296c95
2 changed files with 14 additions and 15 deletions

View File

@@ -128,12 +128,6 @@ class Seq2SeqDataset(Dataset):
def get_char_lens(data_file):
    """Return the character length of every line in ``data_file``.

    Args:
        data_file: Path (``str`` or ``Path``) of a text file to measure.

    Returns:
        A list with one ``int`` per line, in file order. Each length counts
        the line's trailing newline, if present. An empty file gives ``[]``.
    """
    # Use a context manager so the handle is closed deterministically instead
    # of being left for the garbage collector.
    with Path(data_file).open() as handle:
        # Iterating the file yields the same lines as ``readlines()``.
        return [len(line) for line in handle]
@staticmethod
def trim_seq2seq_batch(batch, pad_token_id) -> tuple:
    """Run ``trim_batch`` over the decoder and encoder tensors of ``batch``.

    Args:
        batch: Mapping providing ``"decoder_input_ids"``, ``"input_ids"`` and
            ``"attention_mask"`` entries.
        pad_token_id: Padding token id forwarded to ``trim_batch``.

    Returns:
        A ``(source_ids, source_mask, y)`` tuple: the trimmed encoder ids, the
        matching trimmed attention mask, and the trimmed decoder input ids.
    """
    # NOTE(review): ``trim_batch`` is defined elsewhere; presumably it drops
    # padding-only columns -- confirm against its definition.
    trimmed_decoder_ids = trim_batch(batch["decoder_input_ids"], pad_token_id)
    trimmed_source = trim_batch(
        batch["input_ids"],
        pad_token_id,
        attention_mask=batch["attention_mask"],
    )
    source_ids, source_mask = trimmed_source
    return source_ids, source_mask, trimmed_decoder_ids
def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
input_ids = torch.stack([x["input_ids"] for x in batch])
masks = torch.stack([x["attention_mask"] for x in batch])