Update quality tooling for formatting (#21480)
* Result of black 23.1 * Update target to Python 3.7 * Switch flake8 to ruff * Configure isort * Configure isort * Apply isort with line limit * Put the right black version * adapt black in check copies * Fix copies
This commit is contained in:
@@ -4,8 +4,6 @@ from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Callable
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import flax.linen as nn
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
@@ -16,6 +14,8 @@ from flax import jax_utils, struct, traverse_util
|
||||
from flax.serialization import from_bytes, to_bytes
|
||||
from flax.training import train_state
|
||||
from flax.training.common_utils import shard
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from transformers import BigBirdConfig, FlaxBigBirdForQuestionAnswering
|
||||
from transformers.models.big_bird.modeling_flax_big_bird import FlaxBigBirdForQuestionAnsweringModule
|
||||
|
||||
@@ -98,7 +98,6 @@ class Args:
|
||||
|
||||
@dataclass
|
||||
class DataCollator:
|
||||
|
||||
pad_id: int
|
||||
max_length: int = 4096 # no dynamic padding on TPUs
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from datasets import load_from_disk
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from bigbird_flax import FlaxBigBirdForNaturalQuestions
|
||||
from datasets import load_from_disk
|
||||
|
||||
from transformers import BigBirdTokenizerFast
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import os
|
||||
|
||||
import jsonlines
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
import jsonlines
|
||||
|
||||
|
||||
DOC_STRIDE = 2048
|
||||
MAX_LENGTH = 4096
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import os
|
||||
from dataclasses import replace
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
import jax
|
||||
import wandb
|
||||
from bigbird_flax import Args, DataCollator, FlaxBigBirdForNaturalQuestions, Trainer, build_tx, train_step, val_step
|
||||
from datasets import load_dataset
|
||||
from flax import jax_utils
|
||||
|
||||
from transformers import BigBirdTokenizerFast
|
||||
|
||||
|
||||
|
||||
@@ -32,17 +32,17 @@ from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
import flax
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import optax
|
||||
from datasets import load_dataset
|
||||
from flax import jax_utils, traverse_util
|
||||
from flax.training import train_state
|
||||
from flax.training.common_utils import get_metrics, onehot, shard
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers import (
|
||||
CONFIG_MAPPING,
|
||||
FLAX_MODEL_FOR_MASKED_LM_MAPPING,
|
||||
|
||||
@@ -20,6 +20,7 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
from configuration_hybrid_clip import HybridCLIPConfig
|
||||
from flax.core.frozen_dict import FrozenDict
|
||||
|
||||
from transformers import FLAX_MODEL_MAPPING, FlaxCLIPVisionModel
|
||||
from transformers.modeling_flax_utils import FlaxPreTrainedModel
|
||||
from transformers.models.clip.modeling_flax_clip import FlaxCLIPOutput
|
||||
@@ -132,7 +133,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
|
||||
input_shape: Optional[Tuple] = None,
|
||||
seed: int = 0,
|
||||
dtype: jnp.dtype = jnp.float32,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
if input_shape is None:
|
||||
input_shape = ((1, 1), (1, config.vision_config.image_size, config.vision_config.image_size, 3))
|
||||
|
||||
@@ -32,22 +32,22 @@ from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import optax
|
||||
import torch
|
||||
from flax import jax_utils
|
||||
from flax.jax_utils import unreplicate
|
||||
from flax.training import train_state
|
||||
from flax.training.common_utils import get_metrics, shard, shard_prng_key
|
||||
from modeling_hybrid_clip import FlaxHybridCLIP
|
||||
from torchvision.datasets import VisionDataset
|
||||
from torchvision.io import ImageReadMode, read_image
|
||||
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
from tqdm import tqdm
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import optax
|
||||
import transformers
|
||||
from flax import jax_utils
|
||||
from flax.jax_utils import unreplicate
|
||||
from flax.training import train_state
|
||||
from flax.training.common_utils import get_metrics, shard, shard_prng_key
|
||||
from modeling_hybrid_clip import FlaxHybridCLIP
|
||||
from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments, is_tensorboard_available, set_seed
|
||||
|
||||
|
||||
|
||||
@@ -28,19 +28,19 @@ from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
from datasets import Dataset, load_dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import optax
|
||||
import transformers
|
||||
from datasets import Dataset, load_dataset
|
||||
from flax.core.frozen_dict import freeze, unfreeze
|
||||
from flax.training.common_utils import onehot, stack_forest
|
||||
from jax.experimental.maps import mesh
|
||||
from jax.experimental.pjit import pjit
|
||||
from partitions import set_partitions
|
||||
from tqdm import tqdm
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
CONFIG_MAPPING,
|
||||
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
|
||||
@@ -6,18 +6,18 @@ from dataclasses import field
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from datasets import DatasetDict, load_dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
import flax
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import librosa
|
||||
import numpy as np
|
||||
import optax
|
||||
from datasets import DatasetDict, load_dataset
|
||||
from flax import jax_utils, traverse_util
|
||||
from flax.training import train_state
|
||||
from flax.training.common_utils import get_metrics, onehot, shard
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers import (
|
||||
FlaxWav2Vec2ForPreTraining,
|
||||
HfArgumentParser,
|
||||
|
||||
Reference in New Issue
Block a user