Update repo to isort v5 (#6686)
* Run new isort * More changes * Update CI, CONTRIBUTING and benchmarks
This commit is contained in:
@@ -44,9 +44,10 @@ def evaluate(args):
|
||||
reference_summaries = []
|
||||
generated_summaries = []
|
||||
|
||||
import rouge
|
||||
import nltk
|
||||
|
||||
import rouge
|
||||
|
||||
nltk.download("punkt")
|
||||
rouge_evaluator = rouge.Rouge(
|
||||
metrics=["rouge-n", "rouge-l"],
|
||||
|
||||
@@ -15,27 +15,27 @@ from transformers import BartConfig, BartForConditionalGeneration, MBartTokenize
|
||||
|
||||
try:
|
||||
from .finetune import SummarizationModule, TranslationModule
|
||||
from .initialization_utils import init_student, copy_layers
|
||||
from .utils import (
|
||||
use_task_specific_params,
|
||||
pickle_load,
|
||||
freeze_params,
|
||||
assert_all_frozen,
|
||||
any_requires_grad,
|
||||
calculate_bleu_score,
|
||||
)
|
||||
from .finetune import main as ft_main
|
||||
from .initialization_utils import copy_layers, init_student
|
||||
from .utils import (
|
||||
any_requires_grad,
|
||||
assert_all_frozen,
|
||||
calculate_bleu_score,
|
||||
freeze_params,
|
||||
pickle_load,
|
||||
use_task_specific_params,
|
||||
)
|
||||
except ImportError:
|
||||
from finetune import SummarizationModule, TranslationModule
|
||||
from finetune import main as ft_main
|
||||
from initialization_utils import init_student, copy_layers
|
||||
from initialization_utils import copy_layers, init_student
|
||||
from utils import (
|
||||
use_task_specific_params,
|
||||
pickle_load,
|
||||
freeze_params,
|
||||
assert_all_frozen,
|
||||
any_requires_grad,
|
||||
assert_all_frozen,
|
||||
calculate_bleu_score,
|
||||
freeze_params,
|
||||
pickle_load,
|
||||
use_task_specific_params,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -17,44 +17,43 @@ from transformers import MarianTokenizer, MBartTokenizer, T5ForConditionalGenera
|
||||
|
||||
|
||||
try:
|
||||
from .utils import (
|
||||
assert_all_frozen,
|
||||
use_task_specific_params,
|
||||
lmap,
|
||||
flatten_list,
|
||||
pickle_save,
|
||||
save_git_info,
|
||||
save_json,
|
||||
freeze_params,
|
||||
calculate_rouge,
|
||||
get_git_info,
|
||||
ROUGE_KEYS,
|
||||
calculate_bleu_score,
|
||||
Seq2SeqDataset,
|
||||
TranslationDataset,
|
||||
label_smoothed_nll_loss,
|
||||
)
|
||||
|
||||
from .callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
|
||||
except ImportError:
|
||||
from utils import (
|
||||
from .utils import (
|
||||
ROUGE_KEYS,
|
||||
Seq2SeqDataset,
|
||||
TranslationDataset,
|
||||
assert_all_frozen,
|
||||
use_task_specific_params,
|
||||
lmap,
|
||||
calculate_bleu_score,
|
||||
calculate_rouge,
|
||||
flatten_list,
|
||||
freeze_params,
|
||||
get_git_info,
|
||||
label_smoothed_nll_loss,
|
||||
lmap,
|
||||
pickle_save,
|
||||
save_git_info,
|
||||
save_json,
|
||||
freeze_params,
|
||||
calculate_rouge,
|
||||
get_git_info,
|
||||
ROUGE_KEYS,
|
||||
calculate_bleu_score,
|
||||
label_smoothed_nll_loss,
|
||||
use_task_specific_params,
|
||||
)
|
||||
except ImportError:
|
||||
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
|
||||
from utils import (
|
||||
ROUGE_KEYS,
|
||||
Seq2SeqDataset,
|
||||
TranslationDataset,
|
||||
assert_all_frozen,
|
||||
calculate_bleu_score,
|
||||
calculate_rouge,
|
||||
flatten_list,
|
||||
freeze_params,
|
||||
get_git_info,
|
||||
label_smoothed_nll_loss,
|
||||
lmap,
|
||||
pickle_save,
|
||||
save_git_info,
|
||||
save_json,
|
||||
use_task_specific_params,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -9,9 +9,9 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
|
||||
|
||||
|
||||
try:
|
||||
from .utils import calculate_rouge, use_task_specific_params, calculate_bleu_score, trim_batch
|
||||
from .utils import calculate_bleu_score, calculate_rouge, trim_batch, use_task_specific_params
|
||||
except ImportError:
|
||||
from utils import calculate_rouge, use_task_specific_params, calculate_bleu_score, trim_batch
|
||||
from utils import calculate_bleu_score, calculate_rouge, trim_batch, use_task_specific_params
|
||||
|
||||
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user