From 7cbf0f722d23440f3342aafc27697b50ead5996b Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Sun, 20 Sep 2020 13:54:42 -0700 Subject: [PATCH] examples/seq2seq/__init__.py mutates sys.path (#7194) --- examples/seq2seq/__init__.py | 5 ++ examples/seq2seq/distillation.py | 40 +++++----------- examples/seq2seq/finetune.py | 57 +++++++---------------- examples/seq2seq/run_distributed_eval.py | 36 +++++--------- examples/seq2seq/run_eval.py | 5 +- examples/seq2seq/run_eval_search.py | 6 +-- examples/seq2seq/test_bash_script.py | 9 ++-- examples/seq2seq/test_seq2seq_examples.py | 13 +++--- 8 files changed, 58 insertions(+), 113 deletions(-) diff --git a/examples/seq2seq/__init__.py b/examples/seq2seq/__init__.py index e69de29bb2..3cee09bb7f 100644 --- a/examples/seq2seq/__init__.py +++ b/examples/seq2seq/__init__.py @@ -0,0 +1,5 @@ +import os +import sys + + +sys.path.insert(1, os.path.dirname(os.path.realpath(__file__))) diff --git a/examples/seq2seq/distillation.py b/examples/seq2seq/distillation.py index 8d4611591c..3b1ce10d0d 100644 --- a/examples/seq2seq/distillation.py +++ b/examples/seq2seq/distillation.py @@ -10,37 +10,21 @@ import torch from torch import nn from torch.nn import functional as F +from finetune import SummarizationModule, TranslationModule +from finetune import main as ft_main +from initialization_utils import copy_layers, init_student from lightning_base import generic_train from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5Config, T5ForConditionalGeneration from transformers.modeling_bart import shift_tokens_right - - -try: - from .finetune import SummarizationModule, TranslationModule - from .finetune import main as ft_main - from .initialization_utils import copy_layers, init_student - from .utils import ( - any_requires_grad, - assert_all_frozen, - calculate_bleu, - freeze_params, - label_smoothed_nll_loss, - pickle_load, - use_task_specific_params, - ) -except ImportError: - from finetune import SummarizationModule, TranslationModule - from finetune import main as ft_main - from initialization_utils import copy_layers, init_student - from utils import ( - any_requires_grad, - assert_all_frozen, - calculate_bleu, - freeze_params, - label_smoothed_nll_loss, - pickle_load, - use_task_specific_params, - ) +from utils import ( + any_requires_grad, + assert_all_frozen, + calculate_bleu, + freeze_params, + label_smoothed_nll_loss, + pickle_load, + use_task_specific_params, +) class BartSummarizationDistiller(SummarizationModule): diff --git a/examples/seq2seq/finetune.py b/examples/seq2seq/finetune.py index 4835e58439..f54f15c1d5 100644 --- a/examples/seq2seq/finetune.py +++ b/examples/seq2seq/finetune.py @@ -12,50 +12,29 @@ import pytorch_lightning as pl import torch from torch.utils.data import DataLoader +from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback from lightning_base import BaseTransformer, add_generic_args, generic_train from transformers import MBartTokenizer, T5ForConditionalGeneration from transformers.modeling_bart import shift_tokens_right +from utils import ( + ROUGE_KEYS, + LegacySeq2SeqDataset, + Seq2SeqDataset, + assert_all_frozen, + calculate_bleu, + calculate_rouge, + flatten_list, + freeze_params, + get_git_info, + label_smoothed_nll_loss, + lmap, + pickle_save, + save_git_info, + save_json, + use_task_specific_params, +) -try: - from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback - from .utils import ( - ROUGE_KEYS, - LegacySeq2SeqDataset, - Seq2SeqDataset, - assert_all_frozen, - calculate_bleu, - calculate_rouge, - flatten_list, - freeze_params, - get_git_info, - label_smoothed_nll_loss, - lmap, - pickle_save, - save_git_info, - save_json, - use_task_specific_params, - ) -except ImportError: - from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback - from utils import ( - ROUGE_KEYS, - LegacySeq2SeqDataset, - Seq2SeqDataset, - assert_all_frozen, - calculate_bleu, - calculate_rouge, - flatten_list, - freeze_params, - get_git_info, - label_smoothed_nll_loss, - lmap, - pickle_save, - save_git_info, - save_json, - use_task_specific_params, - ) - logger = logging.getLogger(__name__) diff --git a/examples/seq2seq/run_distributed_eval.py b/examples/seq2seq/run_distributed_eval.py index 4b25db2149..e8218e1917 100644 --- a/examples/seq2seq/run_distributed_eval.py +++ b/examples/seq2seq/run_distributed_eval.py @@ -11,35 +11,21 @@ from torch.utils.data import DataLoader from tqdm import tqdm from transformers import AutoModelForSeq2SeqLM, AutoTokenizer +from utils import ( + Seq2SeqDataset, + calculate_bleu, + calculate_rouge, + lmap, + load_json, + parse_numeric_n_bool_cl_kwargs, + save_json, + use_task_specific_params, + write_txt_file, +) logger = getLogger(__name__) -try: - from .utils import ( - Seq2SeqDataset, - calculate_bleu, - calculate_rouge, - lmap, - load_json, - parse_numeric_n_bool_cl_kwargs, - save_json, - use_task_specific_params, - write_txt_file, - ) -except ImportError: - from utils import ( - Seq2SeqDataset, - calculate_bleu, - calculate_rouge, - lmap, - load_json, - parse_numeric_n_bool_cl_kwargs, - save_json, - use_task_specific_params, - write_txt_file, - ) - def eval_data_dir( data_dir, diff --git a/examples/seq2seq/run_eval.py b/examples/seq2seq/run_eval.py index 4b2a551e7a..09ff4c9a53 100644 --- a/examples/seq2seq/run_eval.py +++ b/examples/seq2seq/run_eval.py @@ -11,14 +11,11 @@ import torch from tqdm import tqdm from transformers import AutoModelForSeq2SeqLM, AutoTokenizer +from utils import calculate_bleu, calculate_rouge, parse_numeric_n_bool_cl_kwargs, use_task_specific_params logger = getLogger(__name__) -try: - from .utils import calculate_bleu, calculate_rouge, parse_numeric_n_bool_cl_kwargs, use_task_specific_params -except ImportError: - from utils import calculate_bleu, calculate_rouge, parse_numeric_n_bool_cl_kwargs, use_task_specific_params DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu" diff --git a/examples/seq2seq/run_eval_search.py b/examples/seq2seq/run_eval_search.py index ae221d37b6..2a819e169f 100644 --- a/examples/seq2seq/run_eval_search.py +++ b/examples/seq2seq/run_eval_search.py @@ -4,11 +4,7 @@ import operator import sys from collections import OrderedDict - -try: - from .run_eval import datetime_now, run_generate -except ImportError: - from run_eval import datetime_now, run_generate +from run_eval import datetime_now, run_generate # A table of supported tasks and the list of scores in the order of importance to be sorted by. diff --git a/examples/seq2seq/test_bash_script.py b/examples/seq2seq/test_bash_script.py index 4f20b055b6..7d163d1c35 100644 --- a/examples/seq2seq/test_bash_script.py +++ b/examples/seq2seq/test_bash_script.py @@ -10,13 +10,12 @@ import pytorch_lightning as pl import timeout_decorator import torch +from distillation import BartSummarizationDistiller, distill_main +from finetune import SummarizationModule, main +from test_seq2seq_examples import CUDA_AVAILABLE, MBART_TINY from transformers import BartForConditionalGeneration, MarianMTModel from transformers.testing_utils import slow - -from .distillation import BartSummarizationDistiller, distill_main -from .finetune import SummarizationModule, main -from .test_seq2seq_examples import CUDA_AVAILABLE, MBART_TINY -from .utils import load_json +from utils import load_json MODEL_NAME = MBART_TINY diff --git a/examples/seq2seq/test_seq2seq_examples.py b/examples/seq2seq/test_seq2seq_examples.py index 7fdd7bf152..0f06959183 100644 --- a/examples/seq2seq/test_seq2seq_examples.py +++ b/examples/seq2seq/test_seq2seq_examples.py @@ -12,16 +12,15 @@ import pytorch_lightning as pl import torch import lightning_base +from convert_pl_checkpoint_to_hf import convert_pl_to_hf +from distillation import distill_main, evaluate_checkpoint +from finetune import SummarizationModule, main +from run_eval import generate_summaries_or_translations, run_generate +from run_eval_search import run_search from transformers import AutoConfig, AutoModelForSeq2SeqLM from transformers.hf_api import HfApi from transformers.testing_utils import CaptureStderr, CaptureStdout, require_multigpu, require_torch_and_cuda, slow - -from .convert_pl_checkpoint_to_hf import convert_pl_to_hf -from .distillation import distill_main, evaluate_checkpoint -from .finetune import SummarizationModule, main -from .run_eval import generate_summaries_or_translations, run_generate -from .run_eval_search import run_search -from .utils import label_smoothed_nll_loss, lmap, load_json +from utils import label_smoothed_nll_loss, lmap, load_json logging.basicConfig(level=logging.DEBUG)