Seq2SeqDataset uses linecache to save memory by @Pradhy729 (#5792)
Co-authored-by: Pradhy729 <49659913+Pradhy729@users.noreply.github.com>
This commit is contained in:
@@ -15,28 +15,15 @@ from transformers import AdamW, BartConfig, BartForConditionalGeneration, T5Conf
|
||||
|
||||
try:
|
||||
from .finetune import SummarizationModule
|
||||
from .initialization_utils import init_student, copy_layers
|
||||
from .utils import (
|
||||
use_task_specific_params,
|
||||
SummarizationDataset,
|
||||
pickle_load,
|
||||
freeze_params,
|
||||
assert_all_frozen,
|
||||
any_requires_grad,
|
||||
)
|
||||
from .finetune import main as ft_main
|
||||
from .initialization_utils import init_student, copy_layers
|
||||
from .utils import use_task_specific_params, pickle_load, freeze_params, assert_all_frozen, any_requires_grad
|
||||
|
||||
except ImportError:
|
||||
from finetune import SummarizationModule
|
||||
from finetune import main as ft_main
|
||||
from initialization_utils import init_student, copy_layers
|
||||
from utils import (
|
||||
use_task_specific_params,
|
||||
SummarizationDataset,
|
||||
pickle_load,
|
||||
freeze_params,
|
||||
assert_all_frozen,
|
||||
any_requires_grad,
|
||||
)
|
||||
from utils import use_task_specific_params, pickle_load, freeze_params, assert_all_frozen, any_requires_grad
|
||||
|
||||
|
||||
class BartSummarizationDistiller(SummarizationModule):
|
||||
@@ -115,11 +102,6 @@ class BartSummarizationDistiller(SummarizationModule):
|
||||
if self.different_encoder:
|
||||
copy_layers(teacher.encoder.block, student.encoder.block, e_layers_to_copy)
|
||||
|
||||
def get_dataset(self, type_path) -> SummarizationDataset:
    """Construct the ``SummarizationDataset`` for the given ``type_path``.

    Args:
        type_path: Key used both to look up the matching entry in
            ``self.n_obs`` and forwarded to the dataset as ``type_path``
            (presumably a split name such as train/val/test -- confirm
            against callers).

    Returns:
        A ``SummarizationDataset`` built from ``self.tokenizer``, the
        ``n_obs`` entry for ``type_path`` and ``self.dataset_kwargs``.
    """
    # Resolve the per-split observation setting first, before touching
    # the dataset constructor.
    split_n_obs = self.n_obs[type_path]
    return SummarizationDataset(
        self.tokenizer,
        type_path=type_path,
        n_obs=split_n_obs,
        **self.dataset_kwargs,
    )
|
||||
|
||||
def calc_mse_loss(self, teacher_outputs: torch.Tensor, student_outputs: torch.Tensor, mask) -> torch.FloatTensor:
|
||||
if mask is not None:
|
||||
# mask has False at padding_idx
|
||||
|
||||
Reference in New Issue
Block a user