Update quality tooling for formatting (#21480)
* Result of black 23.1 * Update target to Python 3.7 * Switch flake8 to ruff * Configure isort * Configure isort * Apply isort with line limit * Put the right black version * adapt black in check copies * Fix copies
This commit is contained in:
@@ -5,6 +5,7 @@ import sys
|
||||
from pathlib import Path
|
||||
|
||||
import finetune_rag
|
||||
|
||||
from transformers.file_utils import is_apex_available
|
||||
from transformers.testing_utils import (
|
||||
TestCasePlus,
|
||||
|
||||
@@ -6,7 +6,6 @@ import pytorch_lightning as pl
|
||||
import torch
|
||||
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
|
||||
from pytorch_lightning.utilities import rank_zero_only
|
||||
|
||||
from utils_rag import save_json
|
||||
|
||||
|
||||
|
||||
@@ -17,7 +17,6 @@ def consolidate(
|
||||
generator_tokenizer_name_or_path: str = None,
|
||||
question_encoder_tokenizer_name_or_path: str = None,
|
||||
):
|
||||
|
||||
if config_name_or_path is None:
|
||||
config_name_or_path = "facebook/rag-token-base" if model_type == "rag_token" else "facebook/rag-sequence-base"
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ import logging
|
||||
import random
|
||||
|
||||
import ray
|
||||
|
||||
from transformers import RagConfig, RagRetriever, RagTokenizer
|
||||
from transformers.models.rag.retrieval_rag import CustomHFIndex
|
||||
|
||||
|
||||
@@ -69,7 +69,7 @@ class BaseTransformer(pl.LightningModule):
|
||||
config=None,
|
||||
tokenizer=None,
|
||||
model=None,
|
||||
**config_kwargs
|
||||
**config_kwargs,
|
||||
):
|
||||
"""Initialize a model, tokenizer and config."""
|
||||
super().__init__()
|
||||
@@ -356,7 +356,7 @@ def generic_train(
|
||||
extra_callbacks=[],
|
||||
checkpoint_callback=None,
|
||||
logging_callback=None,
|
||||
**extra_train_kwargs
|
||||
**extra_train_kwargs,
|
||||
):
|
||||
pl.seed_everything(args.seed)
|
||||
|
||||
|
||||
@@ -7,10 +7,10 @@ import unittest
|
||||
from unittest import TestCase
|
||||
from unittest.mock import patch
|
||||
|
||||
import faiss
|
||||
import numpy as np
|
||||
from datasets import Dataset
|
||||
|
||||
import faiss
|
||||
from transformers import BartConfig, BartTokenizer, DPRConfig, DPRQuestionEncoderTokenizer, RagConfig
|
||||
from transformers.file_utils import is_datasets_available, is_faiss_available, is_psutil_available, is_torch_available
|
||||
from transformers.integrations import is_ray_available
|
||||
|
||||
@@ -6,10 +6,10 @@ from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import List, Optional
|
||||
|
||||
import faiss
|
||||
import torch
|
||||
from datasets import Features, Sequence, Value, load_dataset
|
||||
|
||||
import faiss
|
||||
from transformers import (
|
||||
DPRContextEncoder,
|
||||
DPRContextEncoderTokenizerFast,
|
||||
@@ -56,7 +56,6 @@ def main(
|
||||
processing_args: "ProcessingArguments",
|
||||
index_hnsw_args: "IndexHnswArguments",
|
||||
):
|
||||
|
||||
######################################
|
||||
logger.info("Step 1 - Create the dataset")
|
||||
######################################
|
||||
|
||||
Reference in New Issue
Block a user