Seq2SeqDataset uses linecache to save memory by @Pradhy729 (#5792)
Co-authored-by: Pradhy729 <49659913+Pradhy729@users.noreply.github.com>
This commit is contained in:
@@ -21,7 +21,6 @@ try:
|
||||
from .utils import (
|
||||
assert_all_frozen,
|
||||
use_task_specific_params,
|
||||
SummarizationDataset,
|
||||
lmap,
|
||||
flatten_list,
|
||||
pickle_save,
|
||||
@@ -32,12 +31,17 @@ try:
|
||||
get_git_info,
|
||||
ROUGE_KEYS,
|
||||
calculate_bleu_score,
|
||||
Seq2SeqDataset,
|
||||
MBartDataset,
|
||||
)
|
||||
|
||||
from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
|
||||
except ImportError:
|
||||
from utils import (
|
||||
Seq2SeqDataset,
|
||||
MBartDataset,
|
||||
assert_all_frozen,
|
||||
use_task_specific_params,
|
||||
SummarizationDataset,
|
||||
lmap,
|
||||
flatten_list,
|
||||
pickle_save,
|
||||
@@ -48,7 +52,6 @@ except ImportError:
|
||||
get_git_info,
|
||||
ROUGE_KEYS,
|
||||
calculate_bleu_score,
|
||||
assert_all_frozen,
|
||||
)
|
||||
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback
|
||||
|
||||
@@ -100,6 +103,7 @@ class SummarizationModule(BaseTransformer):
|
||||
self.hparams.git_sha = get_git_info()["repo_sha"]
|
||||
self.num_workers = hparams.num_workers
|
||||
self.decoder_start_token_id = None
|
||||
self.dataset_class = Seq2SeqDataset
|
||||
|
||||
def freeze_embeds(self):
|
||||
"""Freeze token embeddings and positional embeddings for bart, just token embeddings for t5."""
|
||||
@@ -163,7 +167,7 @@ class SummarizationModule(BaseTransformer):
|
||||
|
||||
def _generative_step(self, batch: dict) -> dict:
|
||||
pad_token_id = self.tokenizer.pad_token_id
|
||||
source_ids, source_mask, y = SummarizationDataset.trim_seq2seq_batch(batch, pad_token_id)
|
||||
source_ids, source_mask, y = Seq2SeqDataset.trim_seq2seq_batch(batch, pad_token_id)
|
||||
t0 = time.time()
|
||||
generated_ids = self.model.generate(
|
||||
input_ids=source_ids,
|
||||
@@ -187,10 +191,10 @@ class SummarizationModule(BaseTransformer):
|
||||
def test_epoch_end(self, outputs):
|
||||
return self.validation_epoch_end(outputs, prefix="test")
|
||||
|
||||
def get_dataset(self, type_path) -> SummarizationDataset:
|
||||
def get_dataset(self, type_path) -> Seq2SeqDataset:
|
||||
n_obs = self.n_obs[type_path]
|
||||
max_target_length = self.target_lens[type_path]
|
||||
dataset = SummarizationDataset(
|
||||
dataset = self.dataset_class(
|
||||
self.tokenizer,
|
||||
type_path=type_path,
|
||||
n_obs=n_obs,
|
||||
@@ -303,6 +307,8 @@ class TranslationModule(SummarizationModule):
|
||||
self.dataset_kwargs["tgt_lang"] = hparams.tgt_lang
|
||||
if self.model.config.decoder_start_token_id is None and isinstance(self.tokenizer, MBartTokenizer):
|
||||
self.decoder_start_token_id = self.tokenizer.lang_code_to_id[hparams.tgt_lang]
|
||||
if isinstance(self.tokenizer, MBartTokenizer):
|
||||
self.dataset_class = MBartDataset
|
||||
|
||||
def calc_generative_metrics(self, preds, target) -> dict:
|
||||
return calculate_bleu_score(preds, target)
|
||||
|
||||
Reference in New Issue
Block a user