[seq2seq] make it easier to run the scripts (#7274)
This commit is contained in:
9
examples/seq2seq/distillation.py
Normal file → Executable file
9
examples/seq2seq/distillation.py
Normal file → Executable file
@@ -1,6 +1,9 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
@@ -13,7 +16,6 @@ from torch.nn import functional as F
|
||||
from finetune import SummarizationModule, TranslationModule
|
||||
from finetune import main as ft_main
|
||||
from initialization_utils import copy_layers, init_student
|
||||
from lightning_base import generic_train
|
||||
from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5Config, T5ForConditionalGeneration
|
||||
from transformers.modeling_bart import shift_tokens_right
|
||||
from utils import (
|
||||
@@ -27,6 +29,11 @@ from utils import (
|
||||
)
|
||||
|
||||
|
||||
# need the parent dir module
|
||||
sys.path.insert(2, str(Path(__file__).resolve().parents[1]))
|
||||
from lightning_base import generic_train # noqa
|
||||
|
||||
|
||||
class BartSummarizationDistiller(SummarizationModule):
|
||||
"""Supports Bart, Pegasus and other models that inherit from Bart."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user