examples/seq2seq/__init__.py mutates sys.path (#7194)
This commit is contained in:
@@ -0,0 +1,5 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.insert(1, os.path.dirname(os.path.realpath(__file__)))
|
||||
|
||||
@@ -10,28 +10,12 @@ import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from lightning_base import generic_train
|
||||
from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5Config, T5ForConditionalGeneration
|
||||
from transformers.modeling_bart import shift_tokens_right
|
||||
|
||||
|
||||
try:
|
||||
from .finetune import SummarizationModule, TranslationModule
|
||||
from .finetune import main as ft_main
|
||||
from .initialization_utils import copy_layers, init_student
|
||||
from .utils import (
|
||||
any_requires_grad,
|
||||
assert_all_frozen,
|
||||
calculate_bleu,
|
||||
freeze_params,
|
||||
label_smoothed_nll_loss,
|
||||
pickle_load,
|
||||
use_task_specific_params,
|
||||
)
|
||||
except ImportError:
|
||||
from finetune import SummarizationModule, TranslationModule
|
||||
from finetune import main as ft_main
|
||||
from initialization_utils import copy_layers, init_student
|
||||
from lightning_base import generic_train
|
||||
from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5Config, T5ForConditionalGeneration
|
||||
from transformers.modeling_bart import shift_tokens_right
|
||||
from utils import (
|
||||
any_requires_grad,
|
||||
assert_all_frozen,
|
||||
|
||||
@@ -12,32 +12,10 @@ import pytorch_lightning as pl
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
|
||||
from lightning_base import BaseTransformer, add_generic_args, generic_train
|
||||
from transformers import MBartTokenizer, T5ForConditionalGeneration
|
||||
from transformers.modeling_bart import shift_tokens_right
|
||||
|
||||
|
||||
try:
|
||||
from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
|
||||
from .utils import (
|
||||
ROUGE_KEYS,
|
||||
LegacySeq2SeqDataset,
|
||||
Seq2SeqDataset,
|
||||
assert_all_frozen,
|
||||
calculate_bleu,
|
||||
calculate_rouge,
|
||||
flatten_list,
|
||||
freeze_params,
|
||||
get_git_info,
|
||||
label_smoothed_nll_loss,
|
||||
lmap,
|
||||
pickle_save,
|
||||
save_git_info,
|
||||
save_json,
|
||||
use_task_specific_params,
|
||||
)
|
||||
except ImportError:
|
||||
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
|
||||
from utils import (
|
||||
ROUGE_KEYS,
|
||||
LegacySeq2SeqDataset,
|
||||
@@ -56,6 +34,7 @@ except ImportError:
|
||||
use_task_specific_params,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
@@ -11,23 +11,6 @@ from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
||||
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
try:
|
||||
from .utils import (
|
||||
Seq2SeqDataset,
|
||||
calculate_bleu,
|
||||
calculate_rouge,
|
||||
lmap,
|
||||
load_json,
|
||||
parse_numeric_n_bool_cl_kwargs,
|
||||
save_json,
|
||||
use_task_specific_params,
|
||||
write_txt_file,
|
||||
)
|
||||
except ImportError:
|
||||
from utils import (
|
||||
Seq2SeqDataset,
|
||||
calculate_bleu,
|
||||
@@ -41,6 +24,9 @@ except ImportError:
|
||||
)
|
||||
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def eval_data_dir(
|
||||
data_dir,
|
||||
save_dir: str,
|
||||
|
||||
@@ -11,14 +11,11 @@ import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
||||
from utils import calculate_bleu, calculate_rouge, parse_numeric_n_bool_cl_kwargs, use_task_specific_params
|
||||
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
try:
|
||||
from .utils import calculate_bleu, calculate_rouge, parse_numeric_n_bool_cl_kwargs, use_task_specific_params
|
||||
except ImportError:
|
||||
from utils import calculate_bleu, calculate_rouge, parse_numeric_n_bool_cl_kwargs, use_task_specific_params
|
||||
|
||||
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
@@ -4,10 +4,6 @@ import operator
|
||||
import sys
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
try:
|
||||
from .run_eval import datetime_now, run_generate
|
||||
except ImportError:
|
||||
from run_eval import datetime_now, run_generate
|
||||
|
||||
|
||||
|
||||
@@ -10,13 +10,12 @@ import pytorch_lightning as pl
|
||||
import timeout_decorator
|
||||
import torch
|
||||
|
||||
from distillation import BartSummarizationDistiller, distill_main
|
||||
from finetune import SummarizationModule, main
|
||||
from test_seq2seq_examples import CUDA_AVAILABLE, MBART_TINY
|
||||
from transformers import BartForConditionalGeneration, MarianMTModel
|
||||
from transformers.testing_utils import slow
|
||||
|
||||
from .distillation import BartSummarizationDistiller, distill_main
|
||||
from .finetune import SummarizationModule, main
|
||||
from .test_seq2seq_examples import CUDA_AVAILABLE, MBART_TINY
|
||||
from .utils import load_json
|
||||
from utils import load_json
|
||||
|
||||
|
||||
MODEL_NAME = MBART_TINY
|
||||
|
||||
@@ -12,16 +12,15 @@ import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
import lightning_base
|
||||
from convert_pl_checkpoint_to_hf import convert_pl_to_hf
|
||||
from distillation import distill_main, evaluate_checkpoint
|
||||
from finetune import SummarizationModule, main
|
||||
from run_eval import generate_summaries_or_translations, run_generate
|
||||
from run_eval_search import run_search
|
||||
from transformers import AutoConfig, AutoModelForSeq2SeqLM
|
||||
from transformers.hf_api import HfApi
|
||||
from transformers.testing_utils import CaptureStderr, CaptureStdout, require_multigpu, require_torch_and_cuda, slow
|
||||
|
||||
from .convert_pl_checkpoint_to_hf import convert_pl_to_hf
|
||||
from .distillation import distill_main, evaluate_checkpoint
|
||||
from .finetune import SummarizationModule, main
|
||||
from .run_eval import generate_summaries_or_translations, run_generate
|
||||
from .run_eval_search import run_search
|
||||
from .utils import label_smoothed_nll_loss, lmap, load_json
|
||||
from utils import label_smoothed_nll_loss, lmap, load_json
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
|
||||
Reference in New Issue
Block a user