Update quality tooling for formatting (#21480)
* Result of black 23.1 * Update target to Python 3.7 * Switch flake8 to ruff * Configure isort * Configure isort * Apply isort with line limit * Put the right black version * adapt black in check copies * Fix copies
This commit is contained in:
@@ -15,13 +15,13 @@
|
||||
|
||||
from typing import Callable, Dict, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from jax.random import PRNGKey
|
||||
from modeling_flax_performer_utils import make_fast_softmax_attention
|
||||
|
||||
from transformers.file_utils import add_start_docstrings
|
||||
from transformers.modeling_flax_utils import ACT2FN
|
||||
from transformers.models.bert.configuration_bert import BertConfig
|
||||
@@ -366,7 +366,6 @@ class FlaxPerformerModel(FlaxBertPreTrainedModel):
|
||||
|
||||
# SelfAttention needs also to replace "weight" by "kernel"
|
||||
if {"query", "key", "value"} & key_parts:
|
||||
|
||||
# Flax SelfAttention decomposes the heads (num_head, size // num_heads)
|
||||
if "bias" in key:
|
||||
jax_state[key] = tensor.reshape((config.num_attention_heads, -1))
|
||||
@@ -443,7 +442,6 @@ class FlaxPerformerModel(FlaxBertPreTrainedModel):
|
||||
def __call__(
|
||||
self, input_ids, token_type_ids=None, position_ids=None, dropout_rng: PRNGKey = None, attention_mask=None
|
||||
):
|
||||
|
||||
input_ids, attention_mask, token_type_ids, position_ids = self._check_inputs(
|
||||
input_ids, attention_mask, token_type_ids, position_ids
|
||||
)
|
||||
|
||||
@@ -30,11 +30,10 @@ import abc
|
||||
import functools
|
||||
from collections.abc import Iterable # pylint: disable=g-importing-member
|
||||
|
||||
import numpy as onp
|
||||
from absl import logging
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as onp
|
||||
from absl import logging
|
||||
from jax import lax, random
|
||||
|
||||
|
||||
@@ -524,7 +523,6 @@ class FastAttentionviaLowRankDecomposition(FastAttention):
|
||||
deterministic=False,
|
||||
precision=None,
|
||||
):
|
||||
|
||||
assert key.shape[:-1] == value.shape[:-1]
|
||||
assert query.shape[0:1] == key.shape[0:1] and query.shape[-1] == key.shape[-1]
|
||||
if axis is None:
|
||||
|
||||
@@ -28,18 +28,18 @@ from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
from flax import jax_utils
|
||||
from flax.optim import Adam
|
||||
from flax.training import common_utils
|
||||
from flax.training.common_utils import get_metrics
|
||||
from jax.nn import log_softmax
|
||||
from modeling_flax_performer import FlaxPerformerForMaskedLM
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers import (
|
||||
MODEL_FOR_MASKED_LM_MAPPING,
|
||||
AutoTokenizer,
|
||||
@@ -632,7 +632,6 @@ if __name__ == "__main__":
|
||||
|
||||
epochs = tqdm(range(nb_epochs), desc=f"Epoch ... (1/{nb_epochs})", position=0)
|
||||
for epoch in epochs:
|
||||
|
||||
# ======================== Training ================================
|
||||
# Create sampling rng
|
||||
rng, training_rng, eval_rng = jax.random.split(rng, 3)
|
||||
|
||||
Reference in New Issue
Block a user