examples/seq2seq/__init__.py mutates sys.path (#7194)
This commit is contained in:
@@ -0,0 +1,5 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
|
sys.path.insert(1, os.path.dirname(os.path.realpath(__file__)))
|
||||||
|
|||||||
@@ -10,16 +10,13 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
from finetune import SummarizationModule, TranslationModule
|
||||||
|
from finetune import main as ft_main
|
||||||
|
from initialization_utils import copy_layers, init_student
|
||||||
from lightning_base import generic_train
|
from lightning_base import generic_train
|
||||||
from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5Config, T5ForConditionalGeneration
|
from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5Config, T5ForConditionalGeneration
|
||||||
from transformers.modeling_bart import shift_tokens_right
|
from transformers.modeling_bart import shift_tokens_right
|
||||||
|
from utils import (
|
||||||
|
|
||||||
try:
|
|
||||||
from .finetune import SummarizationModule, TranslationModule
|
|
||||||
from .finetune import main as ft_main
|
|
||||||
from .initialization_utils import copy_layers, init_student
|
|
||||||
from .utils import (
|
|
||||||
any_requires_grad,
|
any_requires_grad,
|
||||||
assert_all_frozen,
|
assert_all_frozen,
|
||||||
calculate_bleu,
|
calculate_bleu,
|
||||||
@@ -27,20 +24,7 @@ try:
|
|||||||
label_smoothed_nll_loss,
|
label_smoothed_nll_loss,
|
||||||
pickle_load,
|
pickle_load,
|
||||||
use_task_specific_params,
|
use_task_specific_params,
|
||||||
)
|
)
|
||||||
except ImportError:
|
|
||||||
from finetune import SummarizationModule, TranslationModule
|
|
||||||
from finetune import main as ft_main
|
|
||||||
from initialization_utils import copy_layers, init_student
|
|
||||||
from utils import (
|
|
||||||
any_requires_grad,
|
|
||||||
assert_all_frozen,
|
|
||||||
calculate_bleu,
|
|
||||||
freeze_params,
|
|
||||||
label_smoothed_nll_loss,
|
|
||||||
pickle_load,
|
|
||||||
use_task_specific_params,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class BartSummarizationDistiller(SummarizationModule):
|
class BartSummarizationDistiller(SummarizationModule):
|
||||||
|
|||||||
@@ -12,14 +12,11 @@ import pytorch_lightning as pl
|
|||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
|
||||||
from lightning_base import BaseTransformer, add_generic_args, generic_train
|
from lightning_base import BaseTransformer, add_generic_args, generic_train
|
||||||
from transformers import MBartTokenizer, T5ForConditionalGeneration
|
from transformers import MBartTokenizer, T5ForConditionalGeneration
|
||||||
from transformers.modeling_bart import shift_tokens_right
|
from transformers.modeling_bart import shift_tokens_right
|
||||||
|
from utils import (
|
||||||
|
|
||||||
try:
|
|
||||||
from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
|
|
||||||
from .utils import (
|
|
||||||
ROUGE_KEYS,
|
ROUGE_KEYS,
|
||||||
LegacySeq2SeqDataset,
|
LegacySeq2SeqDataset,
|
||||||
Seq2SeqDataset,
|
Seq2SeqDataset,
|
||||||
@@ -35,26 +32,8 @@ try:
|
|||||||
save_git_info,
|
save_git_info,
|
||||||
save_json,
|
save_json,
|
||||||
use_task_specific_params,
|
use_task_specific_params,
|
||||||
)
|
)
|
||||||
except ImportError:
|
|
||||||
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
|
|
||||||
from utils import (
|
|
||||||
ROUGE_KEYS,
|
|
||||||
LegacySeq2SeqDataset,
|
|
||||||
Seq2SeqDataset,
|
|
||||||
assert_all_frozen,
|
|
||||||
calculate_bleu,
|
|
||||||
calculate_rouge,
|
|
||||||
flatten_list,
|
|
||||||
freeze_params,
|
|
||||||
get_git_info,
|
|
||||||
label_smoothed_nll_loss,
|
|
||||||
lmap,
|
|
||||||
pickle_save,
|
|
||||||
save_git_info,
|
|
||||||
save_json,
|
|
||||||
use_task_specific_params,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|||||||
@@ -11,35 +11,21 @@ from torch.utils.data import DataLoader
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
||||||
|
from utils import (
|
||||||
|
Seq2SeqDataset,
|
||||||
|
calculate_bleu,
|
||||||
|
calculate_rouge,
|
||||||
|
lmap,
|
||||||
|
load_json,
|
||||||
|
parse_numeric_n_bool_cl_kwargs,
|
||||||
|
save_json,
|
||||||
|
use_task_specific_params,
|
||||||
|
write_txt_file,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
try:
|
|
||||||
from .utils import (
|
|
||||||
Seq2SeqDataset,
|
|
||||||
calculate_bleu,
|
|
||||||
calculate_rouge,
|
|
||||||
lmap,
|
|
||||||
load_json,
|
|
||||||
parse_numeric_n_bool_cl_kwargs,
|
|
||||||
save_json,
|
|
||||||
use_task_specific_params,
|
|
||||||
write_txt_file,
|
|
||||||
)
|
|
||||||
except ImportError:
|
|
||||||
from utils import (
|
|
||||||
Seq2SeqDataset,
|
|
||||||
calculate_bleu,
|
|
||||||
calculate_rouge,
|
|
||||||
lmap,
|
|
||||||
load_json,
|
|
||||||
parse_numeric_n_bool_cl_kwargs,
|
|
||||||
save_json,
|
|
||||||
use_task_specific_params,
|
|
||||||
write_txt_file,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def eval_data_dir(
|
def eval_data_dir(
|
||||||
data_dir,
|
data_dir,
|
||||||
|
|||||||
@@ -11,14 +11,11 @@ import torch
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
||||||
|
from utils import calculate_bleu, calculate_rouge, parse_numeric_n_bool_cl_kwargs, use_task_specific_params
|
||||||
|
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
try:
|
|
||||||
from .utils import calculate_bleu, calculate_rouge, parse_numeric_n_bool_cl_kwargs, use_task_specific_params
|
|
||||||
except ImportError:
|
|
||||||
from utils import calculate_bleu, calculate_rouge, parse_numeric_n_bool_cl_kwargs, use_task_specific_params
|
|
||||||
|
|
||||||
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||||
|
|
||||||
|
|||||||
@@ -4,11 +4,7 @@ import operator
|
|||||||
import sys
|
import sys
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
|
from run_eval import datetime_now, run_generate
|
||||||
try:
|
|
||||||
from .run_eval import datetime_now, run_generate
|
|
||||||
except ImportError:
|
|
||||||
from run_eval import datetime_now, run_generate
|
|
||||||
|
|
||||||
|
|
||||||
# A table of supported tasks and the list of scores in the order of importance to be sorted by.
|
# A table of supported tasks and the list of scores in the order of importance to be sorted by.
|
||||||
|
|||||||
@@ -10,13 +10,12 @@ import pytorch_lightning as pl
|
|||||||
import timeout_decorator
|
import timeout_decorator
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from distillation import BartSummarizationDistiller, distill_main
|
||||||
|
from finetune import SummarizationModule, main
|
||||||
|
from test_seq2seq_examples import CUDA_AVAILABLE, MBART_TINY
|
||||||
from transformers import BartForConditionalGeneration, MarianMTModel
|
from transformers import BartForConditionalGeneration, MarianMTModel
|
||||||
from transformers.testing_utils import slow
|
from transformers.testing_utils import slow
|
||||||
|
from utils import load_json
|
||||||
from .distillation import BartSummarizationDistiller, distill_main
|
|
||||||
from .finetune import SummarizationModule, main
|
|
||||||
from .test_seq2seq_examples import CUDA_AVAILABLE, MBART_TINY
|
|
||||||
from .utils import load_json
|
|
||||||
|
|
||||||
|
|
||||||
MODEL_NAME = MBART_TINY
|
MODEL_NAME = MBART_TINY
|
||||||
|
|||||||
@@ -12,16 +12,15 @@ import pytorch_lightning as pl
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
import lightning_base
|
import lightning_base
|
||||||
|
from convert_pl_checkpoint_to_hf import convert_pl_to_hf
|
||||||
|
from distillation import distill_main, evaluate_checkpoint
|
||||||
|
from finetune import SummarizationModule, main
|
||||||
|
from run_eval import generate_summaries_or_translations, run_generate
|
||||||
|
from run_eval_search import run_search
|
||||||
from transformers import AutoConfig, AutoModelForSeq2SeqLM
|
from transformers import AutoConfig, AutoModelForSeq2SeqLM
|
||||||
from transformers.hf_api import HfApi
|
from transformers.hf_api import HfApi
|
||||||
from transformers.testing_utils import CaptureStderr, CaptureStdout, require_multigpu, require_torch_and_cuda, slow
|
from transformers.testing_utils import CaptureStderr, CaptureStdout, require_multigpu, require_torch_and_cuda, slow
|
||||||
|
from utils import label_smoothed_nll_loss, lmap, load_json
|
||||||
from .convert_pl_checkpoint_to_hf import convert_pl_to_hf
|
|
||||||
from .distillation import distill_main, evaluate_checkpoint
|
|
||||||
from .finetune import SummarizationModule, main
|
|
||||||
from .run_eval import generate_summaries_or_translations, run_generate
|
|
||||||
from .run_eval_search import run_search
|
|
||||||
from .utils import label_smoothed_nll_loss, lmap, load_json
|
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.DEBUG)
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
|||||||
Reference in New Issue
Block a user