Seq2SeqDataset uses linecache to save memory by @Pradhy729 (#5792)

Co-authored-by: Pradhy729 <49659913+Pradhy729@users.noreply.github.com>
This commit is contained in:
Sam Shleifer
2020-07-18 13:57:33 -04:00
committed by GitHub
parent 4b506a37e3
commit 09a2f40684
6 changed files with 182 additions and 170 deletions

View File

@@ -15,28 +15,15 @@ from transformers import AdamW, BartConfig, BartForConditionalGeneration, T5Conf
try:
from .finetune import SummarizationModule
from .initialization_utils import init_student, copy_layers
from .utils import (
use_task_specific_params,
SummarizationDataset,
pickle_load,
freeze_params,
assert_all_frozen,
any_requires_grad,
)
from .finetune import main as ft_main
from .initialization_utils import init_student, copy_layers
from .utils import use_task_specific_params, pickle_load, freeze_params, assert_all_frozen, any_requires_grad
except ImportError:
from finetune import SummarizationModule
from finetune import main as ft_main
from initialization_utils import init_student, copy_layers
from utils import (
use_task_specific_params,
SummarizationDataset,
pickle_load,
freeze_params,
assert_all_frozen,
any_requires_grad,
)
from utils import use_task_specific_params, pickle_load, freeze_params, assert_all_frozen, any_requires_grad
class BartSummarizationDistiller(SummarizationModule):
@@ -115,11 +102,6 @@ class BartSummarizationDistiller(SummarizationModule):
if self.different_encoder:
copy_layers(teacher.encoder.block, student.encoder.block, e_layers_to_copy)
def get_dataset(self, type_path) -> SummarizationDataset:
    """Construct the ``SummarizationDataset`` for the given ``type_path``.

    Args:
        type_path: Key used to look up ``self.n_obs`` and forwarded to the
            dataset constructor as ``type_path``.

    Returns:
        A ``SummarizationDataset`` built from ``self.tokenizer``, the
        ``self.n_obs`` entry for ``type_path`` (passed as ``n_obs``), and
        every keyword argument held in ``self.dataset_kwargs``.
    """
    # Resolve the per-split value first so a missing key fails before
    # the dataset constructor is touched.
    split_n_obs = self.n_obs[type_path]
    return SummarizationDataset(
        self.tokenizer,
        type_path=type_path,
        n_obs=split_n_obs,
        **self.dataset_kwargs,
    )
def calc_mse_loss(self, teacher_outputs: torch.Tensor, student_outputs: torch.Tensor, mask) -> torch.FloatTensor:
if mask is not None:
# mask has False at padding_idx