Update quality tooling for formatting (#21480)

* Result of black 23.1

* Update target to Python 3.7

* Switch flake8 to ruff

* Configure isort

* Configure isort

* Apply isort with line limit

* Put the right black version

* adapt black in check copies

* Fix copies
This commit is contained in:
Sylvain Gugger
2023-02-06 18:10:56 -05:00
committed by GitHub
parent b7bb2b59f7
commit 6f79d26442
1211 changed files with 1532 additions and 2687 deletions

View File

@@ -15,13 +15,13 @@
from typing import Callable, Dict, Tuple
import numpy as np
import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from jax.random import PRNGKey
from modeling_flax_performer_utils import make_fast_softmax_attention
from transformers.file_utils import add_start_docstrings
from transformers.modeling_flax_utils import ACT2FN
from transformers.models.bert.configuration_bert import BertConfig
@@ -366,7 +366,6 @@ class FlaxPerformerModel(FlaxBertPreTrainedModel):
# SelfAttention needs also to replace "weight" by "kernel"
if {"query", "key", "value"} & key_parts:
# Flax SelfAttention decomposes the heads (num_head, size // num_heads)
if "bias" in key:
jax_state[key] = tensor.reshape((config.num_attention_heads, -1))
@@ -443,7 +442,6 @@ class FlaxPerformerModel(FlaxBertPreTrainedModel):
def __call__(
self, input_ids, token_type_ids=None, position_ids=None, dropout_rng: PRNGKey = None, attention_mask=None
):
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
input_ids, attention_mask, token_type_ids, position_ids
)

View File

@@ -30,11 +30,10 @@ import abc
import functools
from collections.abc import Iterable # pylint: disable=g-importing-member
import numpy as onp
from absl import logging
import jax
import jax.numpy as jnp
import numpy as onp
from absl import logging
from jax import lax, random
@@ -524,7 +523,6 @@ class FastAttentionviaLowRankDecomposition(FastAttention):
deterministic=False,
precision=None,
):
assert key.shape[:-1] == value.shape[:-1]
assert query.shape[0:1] == key.shape[0:1] and query.shape[-1] == key.shape[-1]
if axis is None:

View File

@@ -28,18 +28,18 @@ from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
from datasets import load_dataset
from tqdm import tqdm
import jax
import jax.numpy as jnp
import numpy as np
from datasets import load_dataset
from flax import jax_utils
from flax.optim import Adam
from flax.training import common_utils
from flax.training.common_utils import get_metrics
from jax.nn import log_softmax
from modeling_flax_performer import FlaxPerformerForMaskedLM
from tqdm import tqdm
from transformers import (
MODEL_FOR_MASKED_LM_MAPPING,
AutoTokenizer,
@@ -632,7 +632,6 @@ if __name__ == "__main__":
epochs = tqdm(range(nb_epochs), desc=f"Epoch ... (1/{nb_epochs})", position=0)
for epoch in epochs:
# ======================== Training ================================
# Create sampling rng
rng, training_rng, eval_rng = jax.random.split(rng, 3)