[seq2seq] make it easier to run the scripts (#7274)

This commit is contained in:
Stas Bekman
2020-09-24 12:23:48 -07:00
committed by GitHub
parent 8d3bb781ee
commit eadd870b2f
18 changed files with 50 additions and 31 deletions

9
examples/seq2seq/distillation.py Normal file → Executable file
View File

@@ -1,6 +1,9 @@
#!/usr/bin/env python
import argparse
import gc
import os
import sys
import warnings
from pathlib import Path
from typing import List
@@ -13,7 +16,6 @@ from torch.nn import functional as F
from finetune import SummarizationModule, TranslationModule
from finetune import main as ft_main
from initialization_utils import copy_layers, init_student
from lightning_base import generic_train
from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5Config, T5ForConditionalGeneration
from transformers.modeling_bart import shift_tokens_right
from utils import (
@@ -27,6 +29,11 @@ from utils import (
)
# need the parent dir module
sys.path.insert(2, str(Path(__file__).resolve().parents[1]))
from lightning_base import generic_train # noqa
class BartSummarizationDistiller(SummarizationModule):
"""Supports Bart, Pegasus and other models that inherit from Bart."""