Reorganize repo (#8580)
* Put models in subfolders * Styling * Fix imports in tests * More fixes in test imports * Sneaky hidden imports * Fix imports in doc files * More sneaky imports * Finish fixing tests * Fix examples * Fix path for copies * More fixes for examples * Fix dummy files * More fixes for example * More model import fixes * Is this why you're unhappy GitHub? * Fix imports in conver command
This commit is contained in:
@@ -10,10 +10,14 @@ import numpy as np
|
||||
from datasets import Dataset
|
||||
|
||||
from transformers import is_faiss_available
|
||||
from transformers.configuration_bart import BartConfig
|
||||
from transformers.configuration_dpr import DPRConfig
|
||||
from transformers.configuration_rag import RagConfig
|
||||
from transformers.retrieval_rag import CustomHFIndex, RagRetriever
|
||||
from transformers.models.bart.configuration_bart import BartConfig
|
||||
from transformers.models.bart.tokenization_bart import BartTokenizer
|
||||
from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
|
||||
from transformers.models.dpr.configuration_dpr import DPRConfig
|
||||
from transformers.models.dpr.tokenization_dpr import DPRQuestionEncoderTokenizer
|
||||
from transformers.models.rag.configuration_rag import RagConfig
|
||||
from transformers.models.rag.retrieval_rag import CustomHFIndex, RagRetriever
|
||||
from transformers.models.roberta.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
|
||||
from transformers.testing_utils import (
|
||||
require_datasets,
|
||||
require_faiss,
|
||||
@@ -21,10 +25,6 @@ from transformers.testing_utils import (
|
||||
require_tokenizers,
|
||||
require_torch,
|
||||
)
|
||||
from transformers.tokenization_bart import BartTokenizer
|
||||
from transformers.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
|
||||
from transformers.tokenization_dpr import DPRQuestionEncoderTokenizer
|
||||
from transformers.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
|
||||
|
||||
|
||||
if is_faiss_available():
|
||||
@@ -126,7 +126,7 @@ class RagRetrieverTest(TestCase):
|
||||
question_encoder=DPRConfig().to_dict(),
|
||||
generator=BartConfig().to_dict(),
|
||||
)
|
||||
with patch("transformers.retrieval_rag.load_dataset") as mock_load_dataset:
|
||||
with patch("transformers.models.rag.retrieval_rag.load_dataset") as mock_load_dataset:
|
||||
mock_load_dataset.return_value = dataset
|
||||
retriever = RagRetriever(
|
||||
config,
|
||||
@@ -213,7 +213,7 @@ class RagRetrieverTest(TestCase):
|
||||
def test_canonical_hf_index_retriever_save_and_from_pretrained(self):
|
||||
retriever = self.get_dummy_canonical_hf_index_retriever()
|
||||
with tempfile.TemporaryDirectory() as tmp_dirname:
|
||||
with patch("transformers.retrieval_rag.load_dataset") as mock_load_dataset:
|
||||
with patch("transformers.models.rag.retrieval_rag.load_dataset") as mock_load_dataset:
|
||||
mock_load_dataset.return_value = self.get_dummy_dataset()
|
||||
retriever.save_pretrained(tmp_dirname)
|
||||
retriever = RagRetriever.from_pretrained(tmp_dirname)
|
||||
|
||||
Reference in New Issue
Block a user