Update repo to isort v5 (#6686)
* Run new isort * More changes * Update CI, CONTRIBUTING and benchmarks
This commit is contained in:
@@ -20,8 +20,8 @@ from dataclasses import dataclass
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import tqdm
|
||||
from filelock import FileLock
|
||||
|
||||
from filelock import FileLock
|
||||
from transformers import (
|
||||
BartTokenizer,
|
||||
BartTokenizerFast,
|
||||
|
||||
@@ -26,8 +26,8 @@ from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
import tqdm
|
||||
from filelock import FileLock
|
||||
|
||||
from filelock import FileLock
|
||||
from transformers import PreTrainedTokenizer, is_tf_available, is_torch_available
|
||||
|
||||
|
||||
|
||||
@@ -44,9 +44,10 @@ def evaluate(args):
|
||||
reference_summaries = []
|
||||
generated_summaries = []
|
||||
|
||||
import rouge
|
||||
import nltk
|
||||
|
||||
import rouge
|
||||
|
||||
nltk.download("punkt")
|
||||
rouge_evaluator = rouge.Rouge(
|
||||
metrics=["rouge-n", "rouge-l"],
|
||||
|
||||
@@ -15,27 +15,27 @@ from transformers import BartConfig, BartForConditionalGeneration, MBartTokenize
|
||||
|
||||
try:
|
||||
from .finetune import SummarizationModule, TranslationModule
|
||||
from .initialization_utils import init_student, copy_layers
|
||||
from .utils import (
|
||||
use_task_specific_params,
|
||||
pickle_load,
|
||||
freeze_params,
|
||||
assert_all_frozen,
|
||||
any_requires_grad,
|
||||
calculate_bleu_score,
|
||||
)
|
||||
from .finetune import main as ft_main
|
||||
from .initialization_utils import copy_layers, init_student
|
||||
from .utils import (
|
||||
any_requires_grad,
|
||||
assert_all_frozen,
|
||||
calculate_bleu_score,
|
||||
freeze_params,
|
||||
pickle_load,
|
||||
use_task_specific_params,
|
||||
)
|
||||
except ImportError:
|
||||
from finetune import SummarizationModule, TranslationModule
|
||||
from finetune import main as ft_main
|
||||
from initialization_utils import init_student, copy_layers
|
||||
from initialization_utils import copy_layers, init_student
|
||||
from utils import (
|
||||
use_task_specific_params,
|
||||
pickle_load,
|
||||
freeze_params,
|
||||
assert_all_frozen,
|
||||
any_requires_grad,
|
||||
assert_all_frozen,
|
||||
calculate_bleu_score,
|
||||
freeze_params,
|
||||
pickle_load,
|
||||
use_task_specific_params,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -17,44 +17,43 @@ from transformers import MarianTokenizer, MBartTokenizer, T5ForConditionalGenera
|
||||
|
||||
|
||||
try:
|
||||
from .utils import (
|
||||
assert_all_frozen,
|
||||
use_task_specific_params,
|
||||
lmap,
|
||||
flatten_list,
|
||||
pickle_save,
|
||||
save_git_info,
|
||||
save_json,
|
||||
freeze_params,
|
||||
calculate_rouge,
|
||||
get_git_info,
|
||||
ROUGE_KEYS,
|
||||
calculate_bleu_score,
|
||||
Seq2SeqDataset,
|
||||
TranslationDataset,
|
||||
label_smoothed_nll_loss,
|
||||
)
|
||||
|
||||
from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
|
||||
except ImportError:
|
||||
from utils import (
|
||||
from .utils import (
|
||||
ROUGE_KEYS,
|
||||
Seq2SeqDataset,
|
||||
TranslationDataset,
|
||||
assert_all_frozen,
|
||||
use_task_specific_params,
|
||||
lmap,
|
||||
calculate_bleu_score,
|
||||
calculate_rouge,
|
||||
flatten_list,
|
||||
freeze_params,
|
||||
get_git_info,
|
||||
label_smoothed_nll_loss,
|
||||
lmap,
|
||||
pickle_save,
|
||||
save_git_info,
|
||||
save_json,
|
||||
freeze_params,
|
||||
calculate_rouge,
|
||||
get_git_info,
|
||||
ROUGE_KEYS,
|
||||
calculate_bleu_score,
|
||||
label_smoothed_nll_loss,
|
||||
use_task_specific_params,
|
||||
)
|
||||
except ImportError:
|
||||
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
|
||||
from utils import (
|
||||
ROUGE_KEYS,
|
||||
Seq2SeqDataset,
|
||||
TranslationDataset,
|
||||
assert_all_frozen,
|
||||
calculate_bleu_score,
|
||||
calculate_rouge,
|
||||
flatten_list,
|
||||
freeze_params,
|
||||
get_git_info,
|
||||
label_smoothed_nll_loss,
|
||||
lmap,
|
||||
pickle_save,
|
||||
save_git_info,
|
||||
save_json,
|
||||
use_task_specific_params,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -9,9 +9,9 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
||||
|
||||
|
||||
try:
|
||||
from .utils import calculate_rouge, use_task_specific_params, calculate_bleu_score, trim_batch
|
||||
from .utils import calculate_bleu_score, calculate_rouge, trim_batch, use_task_specific_params
|
||||
except ImportError:
|
||||
from utils import calculate_rouge, use_task_specific_params, calculate_bleu_score, trim_batch
|
||||
from utils import calculate_bleu_score, calculate_rouge, trim_batch, use_task_specific_params
|
||||
|
||||
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
@@ -35,8 +35,8 @@ sys.path.extend(SRC_DIRS)
|
||||
if SRC_DIRS is not None:
|
||||
import run_generation
|
||||
import run_glue
|
||||
import run_pl_glue
|
||||
import run_language_modeling
|
||||
import run_pl_glue
|
||||
import run_squad
|
||||
|
||||
|
||||
|
||||
@@ -23,7 +23,6 @@ from enum import Enum
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from filelock import FileLock
|
||||
|
||||
from transformers import PreTrainedTokenizer, is_tf_available, is_torch_available
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user