Update quality tooling for formatting (#21480)
* Result of black 23.1 * Update target to Python 3.7 * Switch flake8 to ruff * Configure isort * Configure isort * Apply isort with line limit * Put the right black version * adapt black in check copies * Fix copies
This commit is contained in:
@@ -31,20 +31,21 @@ from pathlib import Path
|
||||
from typing import Any, Callable, Dict, Optional, Tuple
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
from datasets import load_dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
import evaluate
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import optax
|
||||
import transformers
|
||||
from datasets import load_dataset
|
||||
from flax import struct, traverse_util
|
||||
from flax.jax_utils import pad_shard_unpad, replicate, unreplicate
|
||||
from flax.training import train_state
|
||||
from flax.training.common_utils import get_metrics, onehot, shard
|
||||
from huggingface_hub import Repository, create_repo
|
||||
from tqdm import tqdm
|
||||
from utils_qa import postprocess_qa_predictions
|
||||
|
||||
import transformers
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoTokenizer,
|
||||
@@ -55,7 +56,6 @@ from transformers import (
|
||||
is_tensorboard_available,
|
||||
)
|
||||
from transformers.utils import check_min_version, get_full_repo_name, send_example_telemetry
|
||||
from utils_qa import postprocess_qa_predictions
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -301,6 +301,7 @@ class DataTrainingArguments:
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region Create a train state
|
||||
def create_train_state(
|
||||
model: FlaxAutoModelForQuestionAnswering,
|
||||
@@ -387,6 +388,7 @@ def create_learning_rate_fn(
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region train data iterator
|
||||
def train_data_collator(rng: PRNGKey, dataset: Dataset, batch_size: int):
|
||||
"""Returns shuffled batches of size `batch_size` from truncated `train dataset`, sharded over all local devices."""
|
||||
@@ -405,6 +407,7 @@ def train_data_collator(rng: PRNGKey, dataset: Dataset, batch_size: int):
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region eval data iterator
|
||||
def eval_data_collator(dataset: Dataset, batch_size: int):
|
||||
"""Returns batches of size `batch_size` from `eval dataset`. Sharding handled by `pad_shard_unpad` in the eval loop."""
|
||||
@@ -934,7 +937,6 @@ def main():
|
||||
total_steps = step_per_epoch * num_epochs
|
||||
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|
||||
for epoch in epochs:
|
||||
|
||||
train_start = time.time()
|
||||
train_metrics = []
|
||||
|
||||
@@ -975,7 +977,6 @@ def main():
|
||||
and (cur_step % training_args.eval_steps == 0 or cur_step % step_per_epoch == 0)
|
||||
and cur_step > 0
|
||||
):
|
||||
|
||||
eval_metrics = {}
|
||||
all_start_logits = []
|
||||
all_end_logits = []
|
||||
|
||||
Reference in New Issue
Block a user