Update quality tooling for formatting (#21480)
* Result of black 23.1 * Update target to Python 3.7 * Switch flake8 to ruff * Configure isort * Configure isort * Apply isort with line limit * Put the right black version * adapt black in check copies * Fix copies
This commit is contained in:
@@ -30,11 +30,11 @@ from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
import transformers
|
||||
from huggingface_hub import HfFolder, delete_repo, set_access_token
|
||||
from huggingface_hub.file_download import http_get
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModel,
|
||||
@@ -108,9 +108,9 @@ if is_accelerate_available():
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from test_module.custom_modeling import CustomModel, NoSuperInitModel
|
||||
from torch import nn
|
||||
|
||||
from test_module.custom_modeling import CustomModel, NoSuperInitModel
|
||||
from transformers import (
|
||||
BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
MODEL_MAPPING,
|
||||
@@ -160,6 +160,7 @@ if is_tf_available():
|
||||
|
||||
if is_flax_available():
|
||||
import jax.numpy as jnp
|
||||
|
||||
from transformers.modeling_flax_pytorch_utils import (
|
||||
convert_pytorch_state_dict_to_flax,
|
||||
load_flax_weights_in_pytorch_model,
|
||||
@@ -183,7 +184,6 @@ TINY_BERT_FOR_TOKEN_CLASSIFICATION = "hf-internal-testing/tiny-bert-for-token-cl
|
||||
|
||||
@require_torch
|
||||
class ModelTesterMixin:
|
||||
|
||||
model_tester = None
|
||||
all_model_classes = ()
|
||||
all_generative_model_classes = ()
|
||||
@@ -417,7 +417,6 @@ class ModelTesterMixin:
|
||||
base_class = base_class[0]
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
|
||||
if model_class == base_class:
|
||||
continue
|
||||
|
||||
@@ -706,7 +705,6 @@ class ModelTesterMixin:
|
||||
|
||||
# This is copied from `torch/testing/_internal/jit_utils.py::clear_class_registry`
|
||||
def clear_torch_jit_class_registry(self):
|
||||
|
||||
torch._C._jit_clear_class_registry()
|
||||
torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()
|
||||
# torch 1.8 has no `_clear_class_state` in `torch.jit._state`
|
||||
@@ -1512,7 +1510,6 @@ class ModelTesterMixin:
|
||||
base_model_prefix = model.base_model_prefix
|
||||
|
||||
if hasattr(model, base_model_prefix):
|
||||
|
||||
extra_params = {k: v for k, v in model.named_parameters() if not k.startswith(base_model_prefix)}
|
||||
extra_params.update({k: v for k, v in model.named_buffers() if not k.startswith(base_model_prefix)})
|
||||
# Some models define this as None
|
||||
@@ -1854,7 +1851,6 @@ class ModelTesterMixin:
|
||||
)
|
||||
|
||||
def prepare_tf_inputs_from_pt_inputs(self, pt_inputs_dict):
|
||||
|
||||
tf_inputs_dict = {}
|
||||
for key, tensor in pt_inputs_dict.items():
|
||||
# skip key that does not exist in tf
|
||||
@@ -1875,7 +1871,6 @@ class ModelTesterMixin:
|
||||
return tf_inputs_dict
|
||||
|
||||
def check_pt_tf_models(self, tf_model, pt_model, pt_inputs_dict):
|
||||
|
||||
tf_inputs_dict = self.prepare_tf_inputs_from_pt_inputs(pt_inputs_dict)
|
||||
|
||||
# send pytorch inputs to the correct device
|
||||
@@ -1907,7 +1902,6 @@ class ModelTesterMixin:
|
||||
import transformers
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
tf_model_class_name = "TF" + model_class.__name__ # Add the "TF" at the beginning
|
||||
@@ -2544,7 +2538,6 @@ class ModelTesterMixin:
|
||||
|
||||
for problem_type in problem_types:
|
||||
with self.subTest(msg=f"Testing {model_class} with {problem_type['title']}"):
|
||||
|
||||
config.problem_type = problem_type["title"]
|
||||
config.num_labels = problem_type["num_labels"]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user