examples/seq2seq/__init__.py mutates sys.path (#7194)

This commit is contained in:
Stas Bekman
2020-09-20 13:54:42 -07:00
committed by GitHub
parent a4faeceaed
commit 7cbf0f722d
8 changed files with 58 additions and 113 deletions

View File

@@ -0,0 +1,5 @@
import os
import sys
sys.path.insert(1, os.path.dirname(os.path.realpath(__file__)))

View File

@@ -10,28 +10,12 @@ import torch
from torch import nn
from torch.nn import functional as F
from lightning_base import generic_train
from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5Config, T5ForConditionalGeneration
from transformers.modeling_bart import shift_tokens_right
try:
from .finetune import SummarizationModule, TranslationModule
from .finetune import main as ft_main
from .initialization_utils import copy_layers, init_student
from .utils import (
any_requires_grad,
assert_all_frozen,
calculate_bleu,
freeze_params,
label_smoothed_nll_loss,
pickle_load,
use_task_specific_params,
)
except ImportError:
from finetune import SummarizationModule, TranslationModule
from finetune import main as ft_main
from initialization_utils import copy_layers, init_student
from lightning_base import generic_train
from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5Config, T5ForConditionalGeneration
from transformers.modeling_bart import shift_tokens_right
from utils import (
any_requires_grad,
assert_all_frozen,

View File

@@ -12,32 +12,10 @@ import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
from lightning_base import BaseTransformer, add_generic_args, generic_train
from transformers import MBartTokenizer, T5ForConditionalGeneration
from transformers.modeling_bart import shift_tokens_right
try:
from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
from .utils import (
ROUGE_KEYS,
LegacySeq2SeqDataset,
Seq2SeqDataset,
assert_all_frozen,
calculate_bleu,
calculate_rouge,
flatten_list,
freeze_params,
get_git_info,
label_smoothed_nll_loss,
lmap,
pickle_save,
save_git_info,
save_json,
use_task_specific_params,
)
except ImportError:
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
from utils import (
ROUGE_KEYS,
LegacySeq2SeqDataset,
@@ -56,6 +34,7 @@ except ImportError:
use_task_specific_params,
)
logger = logging.getLogger(__name__)

View File

@@ -11,23 +11,6 @@ from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
logger = getLogger(__name__)
try:
from .utils import (
Seq2SeqDataset,
calculate_bleu,
calculate_rouge,
lmap,
load_json,
parse_numeric_n_bool_cl_kwargs,
save_json,
use_task_specific_params,
write_txt_file,
)
except ImportError:
from utils import (
Seq2SeqDataset,
calculate_bleu,
@@ -41,6 +24,9 @@ except ImportError:
)
logger = getLogger(__name__)
def eval_data_dir(
data_dir,
save_dir: str,

View File

@@ -11,14 +11,11 @@ import torch
from tqdm import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from utils import calculate_bleu, calculate_rouge, parse_numeric_n_bool_cl_kwargs, use_task_specific_params
logger = getLogger(__name__)
try:
from .utils import calculate_bleu, calculate_rouge, parse_numeric_n_bool_cl_kwargs, use_task_specific_params
except ImportError:
from utils import calculate_bleu, calculate_rouge, parse_numeric_n_bool_cl_kwargs, use_task_specific_params
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

View File

@@ -4,10 +4,6 @@ import operator
import sys
from collections import OrderedDict
try:
from .run_eval import datetime_now, run_generate
except ImportError:
from run_eval import datetime_now, run_generate

View File

@@ -10,13 +10,12 @@ import pytorch_lightning as pl
import timeout_decorator
import torch
from distillation import BartSummarizationDistiller, distill_main
from finetune import SummarizationModule, main
from test_seq2seq_examples import CUDA_AVAILABLE, MBART_TINY
from transformers import BartForConditionalGeneration, MarianMTModel
from transformers.testing_utils import slow
from .distillation import BartSummarizationDistiller, distill_main
from .finetune import SummarizationModule, main
from .test_seq2seq_examples import CUDA_AVAILABLE, MBART_TINY
from .utils import load_json
from utils import load_json
MODEL_NAME = MBART_TINY

View File

@@ -12,16 +12,15 @@ import pytorch_lightning as pl
import torch
import lightning_base
from convert_pl_checkpoint_to_hf import convert_pl_to_hf
from distillation import distill_main, evaluate_checkpoint
from finetune import SummarizationModule, main
from run_eval import generate_summaries_or_translations, run_generate
from run_eval_search import run_search
from transformers import AutoConfig, AutoModelForSeq2SeqLM
from transformers.hf_api import HfApi
from transformers.testing_utils import CaptureStderr, CaptureStdout, require_multigpu, require_torch_and_cuda, slow
from .convert_pl_checkpoint_to_hf import convert_pl_to_hf
from .distillation import distill_main, evaluate_checkpoint
from .finetune import SummarizationModule, main
from .run_eval import generate_summaries_or_translations, run_generate
from .run_eval_search import run_search
from .utils import label_smoothed_nll_loss, lmap, load_json
from utils import label_smoothed_nll_loss, lmap, load_json
logging.basicConfig(level=logging.DEBUG)