[s2s] Delete useless method, log tokens_per_batch (#6081)
This commit is contained in:
@@ -128,12 +128,6 @@ class Seq2SeqDataset(Dataset):
|
||||
def get_char_lens(data_file):
|
||||
return [len(x) for x in Path(data_file).open().readlines()]
|
||||
|
||||
@staticmethod
|
||||
def trim_seq2seq_batch(batch, pad_token_id) -> tuple:
|
||||
y = trim_batch(batch["decoder_input_ids"], pad_token_id)
|
||||
source_ids, source_mask = trim_batch(batch["input_ids"], pad_token_id, attention_mask=batch["attention_mask"])
|
||||
return source_ids, source_mask, y
|
||||
|
||||
def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
|
||||
input_ids = torch.stack([x["input_ids"] for x in batch])
|
||||
masks = torch.stack([x["attention_mask"] for x in batch])
|
||||
|
||||
Reference in New Issue
Block a user