Update quality tooling for formatting (#21480)

* Result of black 23.1

* Update target to Python 3.7

* Switch flake8 to ruff

* Configure isort

* Configure isort

* Apply isort with line limit

* Put the right black version

* adapt black in check copies

* Fix copies
This commit is contained in:
Sylvain Gugger
2023-02-06 18:10:56 -05:00
committed by GitHub
parent b7bb2b59f7
commit 6f79d26442
1211 changed files with 1532 additions and 2687 deletions

View File

@@ -4,8 +4,6 @@ from dataclasses import dataclass
from functools import partial
from typing import Callable
from tqdm.auto import tqdm
import flax.linen as nn
import jax
import jax.numpy as jnp
@@ -16,6 +14,8 @@ from flax import jax_utils, struct, traverse_util
from flax.serialization import from_bytes, to_bytes
from flax.training import train_state
from flax.training.common_utils import shard
from tqdm.auto import tqdm
from transformers import BigBirdConfig, FlaxBigBirdForQuestionAnswering
from transformers.models.big_bird.modeling_flax_big_bird import FlaxBigBirdForQuestionAnsweringModule
@@ -98,7 +98,6 @@ class Args:
@dataclass
class DataCollator:
pad_id: int
max_length: int = 4096 # no dynamic padding on TPUs

View File

@@ -1,8 +1,8 @@
from datasets import load_from_disk
import jax
import jax.numpy as jnp
from bigbird_flax import FlaxBigBirdForNaturalQuestions
from datasets import load_from_disk
from transformers import BigBirdTokenizerFast

View File

@@ -1,10 +1,9 @@
import os
import jsonlines
import numpy as np
from tqdm import tqdm
import jsonlines
DOC_STRIDE = 2048
MAX_LENGTH = 4096

View File

@@ -1,12 +1,12 @@
import os
from dataclasses import replace
from datasets import load_dataset
import jax
import wandb
from bigbird_flax import Args, DataCollator, FlaxBigBirdForNaturalQuestions, Trainer, build_tx, train_step, val_step
from datasets import load_dataset
from flax import jax_utils
from transformers import BigBirdTokenizerFast

View File

@@ -32,17 +32,17 @@ from pathlib import Path
from typing import Dict, List, Optional, Tuple
import datasets
import numpy as np
from datasets import load_dataset
from tqdm import tqdm
import flax
import jax
import jax.numpy as jnp
import numpy as np
import optax
from datasets import load_dataset
from flax import jax_utils, traverse_util
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from tqdm import tqdm
from transformers import (
CONFIG_MAPPING,
FLAX_MODEL_FOR_MASKED_LM_MAPPING,

View File

@@ -20,6 +20,7 @@ import jax
import jax.numpy as jnp
from configuration_hybrid_clip import HybridCLIPConfig
from flax.core.frozen_dict import FrozenDict
from transformers import FLAX_MODEL_MAPPING, FlaxCLIPVisionModel
from transformers.modeling_flax_utils import FlaxPreTrainedModel
from transformers.models.clip.modeling_flax_clip import FlaxCLIPOutput
@@ -132,7 +133,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
input_shape: Optional[Tuple] = None,
seed: int = 0,
dtype: jnp.dtype = jnp.float32,
**kwargs
**kwargs,
):
if input_shape is None:
input_shape = ((1, 1), (1, config.vision_config.image_size, config.vision_config.image_size, 3))

View File

@@ -32,22 +32,22 @@ from dataclasses import dataclass, field
from pathlib import Path
from typing import Callable, Optional
import jax
import jax.numpy as jnp
import optax
import torch
from flax import jax_utils
from flax.jax_utils import unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, shard, shard_prng_key
from modeling_hybrid_clip import FlaxHybridCLIP
from torchvision.datasets import VisionDataset
from torchvision.io import ImageReadMode, read_image
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
from torchvision.transforms.functional import InterpolationMode
from tqdm import tqdm
import jax
import jax.numpy as jnp
import optax
import transformers
from flax import jax_utils
from flax.jax_utils import unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, shard, shard_prng_key
from modeling_hybrid_clip import FlaxHybridCLIP
from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments, is_tensorboard_available, set_seed

View File

@@ -28,19 +28,19 @@ from pathlib import Path
from typing import Callable, Optional
import datasets
import numpy as np
from datasets import Dataset, load_dataset
from tqdm import tqdm
import jax
import jax.numpy as jnp
import numpy as np
import optax
import transformers
from datasets import Dataset, load_dataset
from flax.core.frozen_dict import freeze, unfreeze
from flax.training.common_utils import onehot, stack_forest
from jax.experimental.maps import mesh
from jax.experimental.pjit import pjit
from partitions import set_partitions
from tqdm import tqdm
import transformers
from transformers import (
CONFIG_MAPPING,
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,

View File

@@ -6,18 +6,18 @@ from dataclasses import field
from pathlib import Path
from typing import Dict, List, Optional, Union
import numpy as np
from datasets import DatasetDict, load_dataset
from tqdm import tqdm
import flax
import jax
import jax.numpy as jnp
import librosa
import numpy as np
import optax
from datasets import DatasetDict, load_dataset
from flax import jax_utils, traverse_util
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from tqdm import tqdm
from transformers import (
FlaxWav2Vec2ForPreTraining,
HfArgumentParser,