[s2s] distill t5-large -> t5-small (#8376)

Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
Sumithra Bhakthavatsalam
2020-11-11 14:58:45 -08:00
committed by GitHub
parent a5b682329c
commit 81ebd70671
4 changed files with 108 additions and 67 deletions

View File

@@ -9,7 +9,7 @@ import pytorch_lightning as pl
import timeout_decorator
import torch
from distillation import BartSummarizationDistiller, distill_main
from distillation import SummarizationDistiller, distill_main
from finetune import SummarizationModule, main
from transformers import MarianMTModel
from transformers.file_utils import cached_path
@@ -170,7 +170,7 @@ class TestDistilMarianNoTeacher(TestCasePlus):
with patch.object(sys, "argv", testargs):
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser = BartSummarizationDistiller.add_model_specific_args(parser, os.getcwd())
parser = SummarizationDistiller.add_model_specific_args(parser, os.getcwd())
args = parser.parse_args()
# assert args.gpus == gpus THIS BREAKS for multi_gpu