Update quality tooling for formatting (#21480)
* Result of black 23.1 * Update target to Python 3.7 * Switch flake8 to ruff * Configure isort * Configure isort * Apply isort with line limit * Put the right black version * adapt black in check copies * Fix copies
This commit is contained in:
@@ -8,9 +8,9 @@ from unittest.mock import patch
|
||||
import pytorch_lightning as pl
|
||||
import timeout_decorator
|
||||
import torch
|
||||
|
||||
from distillation import SummarizationDistiller, distill_main
|
||||
from finetune import SummarizationModule, main
|
||||
|
||||
from transformers import MarianMTModel
|
||||
from transformers.file_utils import cached_path
|
||||
from transformers.testing_utils import TestCasePlus, require_torch_gpu, slow
|
||||
|
||||
@@ -2,6 +2,7 @@ import tempfile
|
||||
import unittest
|
||||
|
||||
from make_student import create_student_by_copying_alternating_layers
|
||||
|
||||
from transformers import AutoConfig
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.testing_utils import require_torch
|
||||
|
||||
@@ -5,18 +5,18 @@ import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import lightning_base
|
||||
import pytest
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
import lightning_base
|
||||
from convert_pl_checkpoint_to_hf import convert_pl_to_hf
|
||||
from distillation import distill_main
|
||||
from finetune import SummarizationModule, main
|
||||
from huggingface_hub import list_models
|
||||
from parameterized import parameterized
|
||||
from run_eval import generate_summaries_or_translations
|
||||
from torch import nn
|
||||
|
||||
from transformers import AutoConfig, AutoModelForSeq2SeqLM
|
||||
from transformers.testing_utils import CaptureStderr, CaptureStdout, TestCasePlus, require_torch_gpu, slow
|
||||
from utils import label_smoothed_nll_loss, lmap, load_json
|
||||
|
||||
@@ -98,7 +98,6 @@ class TestSummarizationDistillerMultiGPU(TestCasePlus):
|
||||
|
||||
@require_torch_multi_gpu
|
||||
def test_multi_gpu(self):
|
||||
|
||||
updates = dict(
|
||||
no_teacher=True,
|
||||
freeze_encoder=True,
|
||||
|
||||
@@ -9,11 +9,11 @@ from typing import List # noqa: F401
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from finetune import SummarizationModule, TranslationModule
|
||||
from finetune import main as ft_main
|
||||
from make_student import create_student_by_copying_alternating_layers, get_layers_to_supervise
|
||||
from torch import nn
|
||||
|
||||
from transformers import AutoModelForSeq2SeqLM, MBartTokenizer, T5ForConditionalGeneration
|
||||
from transformers.models.bart.modeling_bart import shift_tokens_right
|
||||
from utils import calculate_bleu, check_output_dir, freeze_params, label_smoothed_nll_loss, use_task_specific_params
|
||||
|
||||
@@ -13,10 +13,10 @@ from typing import Dict, List, Tuple
|
||||
import numpy as np
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
|
||||
from torch import nn
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from callbacks import Seq2SeqLoggingCallback, get_checkpoint_callback, get_early_stopping_callback
|
||||
from transformers import MBartTokenizer, T5ForConditionalGeneration
|
||||
from transformers.models.bart.modeling_bart import shift_tokens_right
|
||||
from utils import (
|
||||
|
||||
@@ -69,7 +69,7 @@ class BaseTransformer(pl.LightningModule):
|
||||
config=None,
|
||||
tokenizer=None,
|
||||
model=None,
|
||||
**config_kwargs
|
||||
**config_kwargs,
|
||||
):
|
||||
"""Initialize a model, tokenizer and config."""
|
||||
super().__init__()
|
||||
@@ -346,7 +346,7 @@ def generic_train(
|
||||
extra_callbacks=[],
|
||||
checkpoint_callback=None,
|
||||
logging_callback=None,
|
||||
**extra_train_kwargs
|
||||
**extra_train_kwargs,
|
||||
):
|
||||
pl.seed_everything(args.seed)
|
||||
|
||||
|
||||
@@ -84,7 +84,7 @@ def create_student_by_copying_alternating_layers(
|
||||
copy_first_teacher_layers=False,
|
||||
e_layers_to_copy=None,
|
||||
d_layers_to_copy=None,
|
||||
**extra_config_kwargs
|
||||
**extra_config_kwargs,
|
||||
) -> Tuple[PreTrainedModel, List[int], List[int]]:
|
||||
"""Make a student by copying alternating layers from a teacher, save it to save_path.
|
||||
Args:
|
||||
@@ -107,7 +107,6 @@ def create_student_by_copying_alternating_layers(
|
||||
AutoTokenizer.from_pretrained(teacher).save_pretrained(save_path) # purely for convenience
|
||||
teacher = AutoModelForSeq2SeqLM.from_pretrained(teacher).eval()
|
||||
else:
|
||||
|
||||
assert isinstance(teacher, PreTrainedModel), f"teacher must be a model or string got type {type(teacher)}"
|
||||
init_kwargs = teacher.config.to_diff_dict()
|
||||
|
||||
|
||||
@@ -15,10 +15,10 @@ import torch
|
||||
import torch.distributed as dist
|
||||
from rouge_score import rouge_scorer, scoring
|
||||
from sacrebleu import corpus_bleu
|
||||
from sentence_splitter import add_newline_to_end_of_each_sentence
|
||||
from torch import nn
|
||||
from torch.utils.data import Dataset, Sampler
|
||||
|
||||
from sentence_splitter import add_newline_to_end_of_each_sentence
|
||||
from transformers import BartTokenizer, EvalPrediction, PreTrainedTokenizer, T5Tokenizer
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.models.bart.modeling_bart import shift_tokens_right
|
||||
@@ -115,7 +115,7 @@ class AbstractSeq2SeqDataset(Dataset):
|
||||
type_path="train",
|
||||
n_obs=None,
|
||||
prefix="",
|
||||
**dataset_kwargs
|
||||
**dataset_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.src_file = Path(data_dir).joinpath(type_path + ".source")
|
||||
|
||||
Reference in New Issue
Block a user