Update quality tooling for formatting (#21480)
* Result of black 23.1 * Update target to Python 3.7 * Switch flake8 to ruff * Configure isort * Configure isort * Apply isort with line limit * Put the right black version * adapt black in check copies * Fix copies
This commit is contained in:
@@ -48,10 +48,10 @@ TOLERANCE = 1e-3
|
||||
|
||||
T5_SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
|
||||
if is_torch_available() and is_datasets_available() and is_faiss_available():
|
||||
import faiss
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
|
||||
import faiss
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModel,
|
||||
@@ -99,7 +99,6 @@ def require_retrieval(test_case):
|
||||
@require_retrieval
|
||||
@require_sentencepiece
|
||||
class RagTestMixin:
|
||||
|
||||
all_model_classes = (
|
||||
(RagModel, RagTokenForGeneration, RagSequenceForGeneration)
|
||||
if is_torch_available() and is_datasets_available() and is_faiss_available()
|
||||
@@ -493,7 +492,7 @@ class RagTestMixin:
|
||||
decoder_attention_mask,
|
||||
retriever_n_docs,
|
||||
generator_n_docs,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
self.assertIsNotNone(config.question_encoder)
|
||||
self.assertIsNotNone(config.generator)
|
||||
|
||||
@@ -16,9 +16,9 @@ from transformers.utils import cached_property, is_datasets_available, is_faiss_
|
||||
|
||||
|
||||
if is_tf_available() and is_datasets_available() and is_faiss_available():
|
||||
import faiss
|
||||
import tensorflow as tf
|
||||
from datasets import Dataset
|
||||
import faiss
|
||||
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
@@ -31,7 +31,6 @@ if is_tf_available() and is_datasets_available() and is_faiss_available():
|
||||
TFRagSequenceForGeneration,
|
||||
TFRagTokenForGeneration,
|
||||
)
|
||||
|
||||
from transformers.modeling_tf_outputs import TFBaseModelOutput
|
||||
|
||||
from ..bart.test_modeling_tf_bart import TFBartModelTester
|
||||
@@ -58,7 +57,6 @@ def require_retrieval(test_case):
|
||||
@require_retrieval
|
||||
@require_sentencepiece
|
||||
class TFRagTestMixin:
|
||||
|
||||
all_model_classes = (
|
||||
(TFRagModel, TFRagTokenForGeneration, TFRagSequenceForGeneration)
|
||||
if is_tf_available() and is_datasets_available() and is_faiss_available()
|
||||
@@ -392,7 +390,7 @@ class TFRagTestMixin:
|
||||
decoder_attention_mask,
|
||||
retriever_n_docs,
|
||||
generator_n_docs,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
self.assertIsNotNone(config.question_encoder)
|
||||
self.assertIsNotNone(config.generator)
|
||||
|
||||
@@ -360,7 +360,6 @@ class RagRetrieverTest(TestCase):
|
||||
@require_tokenizers
|
||||
@require_sentencepiece
|
||||
def test_custom_hf_index_end2end_retriever_call(self):
|
||||
|
||||
context_encoder_tokenizer = self.get_dpr_ctx_encoder_tokenizer()
|
||||
n_docs = 1
|
||||
retriever = self.get_dummy_custom_hf_index_retriever(from_disk=False)
|
||||
|
||||
@@ -110,7 +110,6 @@ class RagTokenizerTest(TestCase):
|
||||
|
||||
@require_tokenizers
|
||||
def test_save_load_pretrained_with_saved_config(self):
|
||||
|
||||
save_dir = os.path.join(self.tmpdirname, "rag_tokenizer")
|
||||
rag_config = RagConfig(question_encoder=DPRConfig().to_dict(), generator=BartConfig().to_dict())
|
||||
rag_tokenizer = RagTokenizer(question_encoder=self.get_dpr_tokenizer(), generator=self.get_bart_tokenizer())
|
||||
|
||||
Reference in New Issue
Block a user