Update quality tooling for formatting (#21480)

* Result of black 23.1

* Update target to Python 3.7

* Switch flake8 to ruff

* Configure isort

* Configure isort

* Apply isort with line limit

* Put the right black version

* adapt black in check copies

* Fix copies
This commit is contained in:
Sylvain Gugger
2023-02-06 18:10:56 -05:00
committed by GitHub
parent b7bb2b59f7
commit 6f79d26442
1211 changed files with 1532 additions and 2687 deletions

View File

@@ -29,23 +29,23 @@ from pathlib import Path
from typing import Callable, Optional
import datasets
import nltk # Here to have a nice missing dependency error message early on
import numpy as np
from datasets import Dataset, load_dataset
from PIL import Image
from tqdm import tqdm
import evaluate
import jax
import jax.numpy as jnp
import nltk # Here to have a nice missing dependency error message early on
import numpy as np
import optax
import transformers
from datasets import Dataset, load_dataset
from filelock import FileLock
from flax import jax_utils, traverse_util
from flax.jax_utils import unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from huggingface_hub import Repository, create_repo
from PIL import Image
from tqdm import tqdm
import transformers
from transformers import (
AutoImageProcessor,
AutoTokenizer,

View File

@@ -32,20 +32,20 @@ from itertools import chain
from pathlib import Path
from typing import Dict, List, Optional
import nltk
import numpy as np
from datasets import load_dataset
from tqdm import tqdm
import flax
import jax
import jax.numpy as jnp
import nltk
import numpy as np
import optax
from datasets import load_dataset
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository, create_repo
from tqdm import tqdm
from transformers import (
CONFIG_MAPPING,
FLAX_MODEL_FOR_MASKED_LM_MAPPING,

View File

@@ -34,19 +34,19 @@ from pathlib import Path
from typing import Callable, Optional
import datasets
import numpy as np
from datasets import Dataset, load_dataset
from tqdm import tqdm
import jax
import jax.numpy as jnp
import numpy as np
import optax
import transformers
from datasets import Dataset, load_dataset
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from huggingface_hub import Repository, create_repo
from tqdm import tqdm
import transformers
from transformers import (
CONFIG_MAPPING,
FLAX_MODEL_FOR_CAUSAL_LM_MAPPING,

View File

@@ -34,19 +34,19 @@ from itertools import chain
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
from datasets import load_dataset
from tqdm import tqdm
import flax
import jax
import jax.numpy as jnp
import numpy as np
import optax
from datasets import load_dataset
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository, create_repo
from tqdm import tqdm
from transformers import (
CONFIG_MAPPING,
FLAX_MODEL_FOR_MASKED_LM_MAPPING,

View File

@@ -33,19 +33,19 @@ from itertools import chain
from pathlib import Path
from typing import Dict, List, Optional
import numpy as np
from datasets import load_dataset
from tqdm import tqdm
import flax
import jax
import jax.numpy as jnp
import numpy as np
import optax
from datasets import load_dataset
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository, create_repo
from tqdm import tqdm
from transformers import (
CONFIG_MAPPING,
FLAX_MODEL_FOR_MASKED_LM_MAPPING,

View File

@@ -31,20 +31,21 @@ from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple
import datasets
import numpy as np
from datasets import load_dataset
from tqdm import tqdm
import evaluate
import jax
import jax.numpy as jnp
import numpy as np
import optax
import transformers
from datasets import load_dataset
from flax import struct, traverse_util
from flax.jax_utils import pad_shard_unpad, replicate, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository, create_repo
from tqdm import tqdm
from utils_qa import postprocess_qa_predictions
import transformers
from transformers import (
AutoConfig,
AutoTokenizer,
@@ -55,7 +56,6 @@ from transformers import (
is_tensorboard_available,
)
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
from utils_qa import postprocess_qa_predictions
logger = logging.getLogger(__name__)
@@ -301,6 +301,7 @@ class DataTrainingArguments:
# endregion
# region Create a train state
def create_train_state(
model: FlaxAutoModelForQuestionAnswering,
@@ -387,6 +388,7 @@ def create_learning_rate_fn(
# endregion
# region train data iterator
def train_data_collator(rng: PRNGKey, dataset: Dataset, batch_size: int):
"""Returns shuffled batches of size `batch_size` from truncated `train dataset`, sharded over all local devices."""
@@ -405,6 +407,7 @@ def train_data_collator(rng: PRNGKey, dataset: Dataset, batch_size: int):
# endregion
# region eval data iterator
def eval_data_collator(dataset: Dataset, batch_size: int):
"""Returns batches of size `batch_size` from `eval dataset`. Sharding handled by `pad_shard_unpad` in the eval loop."""
@@ -934,7 +937,6 @@ def main():
total_steps = step_per_epoch * num_epochs
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
for epoch in epochs:
train_start = time.time()
train_metrics = []
@@ -975,7 +977,6 @@ def main():
and (cur_step % training_args.eval_steps == 0 or cur_step % step_per_epoch == 0)
and cur_step > 0
):
eval_metrics = {}
all_start_logits = []
all_end_logits = []

View File

@@ -31,22 +31,22 @@ from pathlib import Path
from typing import Callable, Optional
import datasets
import nltk # Here to have a nice missing dependency error message early on
import numpy as np
from datasets import Dataset, load_dataset
from tqdm import tqdm
import evaluate
import jax
import jax.numpy as jnp
import nltk # Here to have a nice missing dependency error message early on
import numpy as np
import optax
import transformers
from datasets import Dataset, load_dataset
from filelock import FileLock
from flax import jax_utils, traverse_util
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from huggingface_hub import Repository, create_repo
from tqdm import tqdm
import transformers
from transformers import (
CONFIG_MAPPING,
FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,

View File

@@ -26,20 +26,20 @@ from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple
import datasets
import numpy as np
from datasets import load_dataset
from tqdm import tqdm
import evaluate
import jax
import jax.numpy as jnp
import numpy as np
import optax
import transformers
from datasets import load_dataset
from flax import struct, traverse_util
from flax.jax_utils import pad_shard_unpad, replicate, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository, create_repo
from tqdm import tqdm
import transformers
from transformers import (
AutoConfig,
AutoTokenizer,
@@ -586,7 +586,6 @@ def main():
total_steps = steps_per_epoch * num_epochs
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (0/{num_epochs})", position=0)
for epoch in epochs:
train_start = time.time()
train_metrics = []
@@ -623,7 +622,6 @@ def main():
train_metrics = []
if (cur_step % training_args.eval_steps == 0 or cur_step % steps_per_epoch == 0) and cur_step > 0:
# evaluate
eval_loader = glue_eval_data_collator(eval_dataset, eval_batch_size)
for batch in tqdm(

View File

@@ -28,20 +28,20 @@ from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple
import datasets
import numpy as np
from datasets import ClassLabel, load_dataset
from tqdm import tqdm
import evaluate
import jax
import jax.numpy as jnp
import numpy as np
import optax
import transformers
from datasets import ClassLabel, load_dataset
from flax import struct, traverse_util
from flax.jax_utils import pad_shard_unpad, replicate, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard
from huggingface_hub import Repository, create_repo
from tqdm import tqdm
import transformers
from transformers import (
AutoConfig,
AutoTokenizer,
@@ -695,7 +695,6 @@ def main():
total_steps = step_per_epoch * num_epochs
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
for epoch in epochs:
train_start = time.time()
train_metrics = []
@@ -731,7 +730,6 @@ def main():
train_metrics = []
if cur_step % training_args.eval_steps == 0 and cur_step > 0:
eval_metrics = {}
# evaluate
for batch in tqdm(

View File

@@ -29,21 +29,22 @@ from enum import Enum
from pathlib import Path
from typing import Callable, Optional
import jax
import jax.numpy as jnp
import optax
# for dataset and preprocessing
import torch
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
import jax
import jax.numpy as jnp
import optax
import transformers
from flax import jax_utils
from flax.jax_utils import pad_shard_unpad, unreplicate
from flax.training import train_state
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
from huggingface_hub import Repository, create_repo
from tqdm import tqdm
import transformers
from transformers import (
CONFIG_MAPPING,
FLAX_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,