Seq2SeqDataset uses linecache to save memory by @Pradhy729 (#5792)

Co-authored-by: Pradhy729 <49659913+Pradhy729@users.noreply.github.com>
This commit is contained in:
Sam Shleifer
2020-07-18 13:57:33 -04:00
committed by GitHub
parent 4b506a37e3
commit 09a2f40684
6 changed files with 182 additions and 170 deletions

View File

@@ -21,7 +21,6 @@ try:
from .utils import (
assert_all_frozen,
use_task_specific_params,
SummarizationDataset,
lmap,
flatten_list,
pickle_save,
@@ -32,12 +31,17 @@ try:
get_git_info,
ROUGE_KEYS,
calculate_bleu_score,
Seq2SeqDataset,
MBartDataset,
)
from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
except ImportError:
from utils import (
Seq2SeqDataset,
MBartDataset,
assert_all_frozen,
use_task_specific_params,
SummarizationDataset,
lmap,
flatten_list,
pickle_save,
@@ -48,7 +52,6 @@ except ImportError:
get_git_info,
ROUGE_KEYS,
calculate_bleu_score,
assert_all_frozen,
)
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
@@ -100,6 +103,7 @@ class SummarizationModule(BaseTransformer):
self.hparams.git_sha = get_git_info()["repo_sha"]
self.num_workers = hparams.num_workers
self.decoder_start_token_id = None
self.dataset_class = Seq2SeqDataset
def freeze_embeds(self):
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
@@ -163,7 +167,7 @@ class SummarizationModule(BaseTransformer):
def _generative_step(self, batch: dict) -> dict:
pad_token_id = self.tokenizer.pad_token_id
source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
source_ids, source_mask, y = Seq2SeqDataset.trim_seq2seq_batch(batch, pad_token_id)
t0 = time.time()
generated_ids = self.model.generate(
input_ids=source_ids,
@@ -187,10 +191,10 @@ class SummarizationModule(BaseTransformer):
def test_epoch_end(self, outputs):
return self.validation_epoch_end(outputs, prefix="test")
def get_dataset(self, type_path) -> SummarizationDataset:
def get_dataset(self, type_path) -> Seq2SeqDataset:
n_obs = self.n_obs[type_path]
max_target_length = self.target_lens[type_path]
dataset = SummarizationDataset(
dataset = self.dataset_class(
self.tokenizer,
type_path=type_path,
n_obs=n_obs,
@@ -303,6 +307,8 @@ class TranslationModule(SummarizationModule):
self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
if isinstance(self.tokenizer, MBartTokenizer):
self.dataset_class = MBartDataset
def calc_generative_metrics(self, preds, target) -> dict:
return calculate_bleu_score(preds, target)