Remove datasets requirement (#14795)
This commit is contained in:
@@ -31,7 +31,6 @@ from transformers import logging as transformers_logging
|
|||||||
|
|
||||||
from .deepspeed import is_deepspeed_available
|
from .deepspeed import is_deepspeed_available
|
||||||
from .file_utils import (
|
from .file_utils import (
|
||||||
is_datasets_available,
|
|
||||||
is_detectron2_available,
|
is_detectron2_available,
|
||||||
is_faiss_available,
|
is_faiss_available,
|
||||||
is_flax_available,
|
is_flax_available,
|
||||||
@@ -513,15 +512,6 @@ def require_torch_tf32(test_case):
|
|||||||
return test_case
|
return test_case
|
||||||
|
|
||||||
|
|
||||||
def require_datasets(test_case):
|
|
||||||
"""Decorator marking a test that requires datasets."""
|
|
||||||
|
|
||||||
if not is_datasets_available():
|
|
||||||
return unittest.skip("test requires `datasets`")(test_case)
|
|
||||||
else:
|
|
||||||
return test_case
|
|
||||||
|
|
||||||
|
|
||||||
def require_detectron2(test_case):
|
def require_detectron2(test_case):
|
||||||
"""Decorator marking a test that requires detectron2."""
|
"""Decorator marking a test that requires detectron2."""
|
||||||
if not is_detectron2_available():
|
if not is_detectron2_available():
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ from transformers import Wav2Vec2Config, is_flax_available
|
|||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
is_librosa_available,
|
is_librosa_available,
|
||||||
is_pyctcdecode_available,
|
is_pyctcdecode_available,
|
||||||
require_datasets,
|
|
||||||
require_flax,
|
require_flax,
|
||||||
require_librosa,
|
require_librosa,
|
||||||
require_pyctcdecode,
|
require_pyctcdecode,
|
||||||
@@ -367,7 +366,6 @@ class FlaxWav2Vec2UtilsTest(unittest.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
@require_flax
|
@require_flax
|
||||||
@require_datasets
|
|
||||||
@require_soundfile
|
@require_soundfile
|
||||||
@slow
|
@slow
|
||||||
class FlaxWav2Vec2ModelIntegrationTest(unittest.TestCase):
|
class FlaxWav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ import pytest
|
|||||||
|
|
||||||
from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
|
from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
|
||||||
from transformers import HubertConfig, is_torch_available
|
from transformers import HubertConfig, is_torch_available
|
||||||
from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device
|
from transformers.testing_utils import require_soundfile, require_torch, slow, torch_device
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_common import ModelTesterMixin, _config_zero_init
|
from .test_modeling_common import ModelTesterMixin, _config_zero_init
|
||||||
@@ -606,7 +606,6 @@ class HubertUtilsTest(unittest.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_datasets
|
|
||||||
@require_soundfile
|
@require_soundfile
|
||||||
@slow
|
@slow
|
||||||
class HubertModelIntegrationTest(unittest.TestCase):
|
class HubertModelIntegrationTest(unittest.TestCase):
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ import pytest
|
|||||||
|
|
||||||
from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
|
from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
|
||||||
from transformers import SEWConfig, is_torch_available
|
from transformers import SEWConfig, is_torch_available
|
||||||
from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device
|
from transformers.testing_utils import require_soundfile, require_torch, slow, torch_device
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_common import ModelTesterMixin, _config_zero_init
|
from .test_modeling_common import ModelTesterMixin, _config_zero_init
|
||||||
@@ -462,7 +462,6 @@ class SEWUtilsTest(unittest.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_datasets
|
|
||||||
@require_soundfile
|
@require_soundfile
|
||||||
@slow
|
@slow
|
||||||
class SEWModelIntegrationTest(unittest.TestCase):
|
class SEWModelIntegrationTest(unittest.TestCase):
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ import pytest
|
|||||||
|
|
||||||
from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
|
from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
|
||||||
from transformers import SEWDConfig, is_torch_available
|
from transformers import SEWDConfig, is_torch_available
|
||||||
from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device
|
from transformers.testing_utils import require_soundfile, require_torch, slow, torch_device
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_common import ModelTesterMixin, _config_zero_init
|
from .test_modeling_common import ModelTesterMixin, _config_zero_init
|
||||||
@@ -475,7 +475,6 @@ class SEWDUtilsTest(unittest.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_datasets
|
|
||||||
@require_soundfile
|
@require_soundfile
|
||||||
@slow
|
@slow
|
||||||
class SEWDModelIntegrationTest(unittest.TestCase):
|
class SEWDModelIntegrationTest(unittest.TestCase):
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ import numpy as np
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from transformers import is_tf_available
|
from transformers import is_tf_available
|
||||||
from transformers.testing_utils import require_datasets, require_soundfile, require_tf, slow
|
from transformers.testing_utils import require_soundfile, require_tf, slow
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
|
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
|
||||||
@@ -473,7 +473,6 @@ class TFHubertUtilsTest(unittest.TestCase):
|
|||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
@slow
|
@slow
|
||||||
@require_datasets
|
|
||||||
@require_soundfile
|
@require_soundfile
|
||||||
class TFHubertModelIntegrationTest(unittest.TestCase):
|
class TFHubertModelIntegrationTest(unittest.TestCase):
|
||||||
def _load_datasamples(self, num_samples):
|
def _load_datasamples(self, num_samples):
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from datasets import load_dataset
|
|||||||
|
|
||||||
from transformers import Wav2Vec2Config, is_tf_available
|
from transformers import Wav2Vec2Config, is_tf_available
|
||||||
from transformers.file_utils import is_librosa_available, is_pyctcdecode_available
|
from transformers.file_utils import is_librosa_available, is_pyctcdecode_available
|
||||||
from transformers.testing_utils import require_datasets, require_librosa, require_pyctcdecode, require_tf, slow
|
from transformers.testing_utils import require_librosa, require_pyctcdecode, require_tf, slow
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
|
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
|
||||||
@@ -483,7 +483,6 @@ class TFWav2Vec2UtilsTest(unittest.TestCase):
|
|||||||
|
|
||||||
@require_tf
|
@require_tf
|
||||||
@slow
|
@slow
|
||||||
@require_datasets
|
|
||||||
class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
|
class TFWav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||||
def _load_datasamples(self, num_samples):
|
def _load_datasamples(self, num_samples):
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ from datasets import load_dataset
|
|||||||
|
|
||||||
from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
|
from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
|
||||||
from transformers import UniSpeechConfig, is_torch_available
|
from transformers import UniSpeechConfig, is_torch_available
|
||||||
from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device
|
from transformers.testing_utils import require_soundfile, require_torch, slow, torch_device
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_common import ModelTesterMixin, _config_zero_init
|
from .test_modeling_common import ModelTesterMixin, _config_zero_init
|
||||||
@@ -525,7 +525,6 @@ class UniSpeechRobustModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_datasets
|
|
||||||
@require_soundfile
|
@require_soundfile
|
||||||
@slow
|
@slow
|
||||||
class UniSpeechModelIntegrationTest(unittest.TestCase):
|
class UniSpeechModelIntegrationTest(unittest.TestCase):
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ from datasets import load_dataset
|
|||||||
|
|
||||||
from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
|
from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask
|
||||||
from transformers import UniSpeechSatConfig, is_torch_available
|
from transformers import UniSpeechSatConfig, is_torch_available
|
||||||
from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device
|
from transformers.testing_utils import require_soundfile, require_torch, slow, torch_device
|
||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_common import ModelTesterMixin, _config_zero_init
|
from .test_modeling_common import ModelTesterMixin, _config_zero_init
|
||||||
@@ -783,7 +783,6 @@ class UniSpeechSatRobustModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_datasets
|
|
||||||
@require_soundfile
|
@require_soundfile
|
||||||
@slow
|
@slow
|
||||||
class UniSpeechSatModelIntegrationTest(unittest.TestCase):
|
class UniSpeechSatModelIntegrationTest(unittest.TestCase):
|
||||||
|
|||||||
@@ -26,7 +26,6 @@ from transformers.testing_utils import (
|
|||||||
is_pt_flax_cross_test,
|
is_pt_flax_cross_test,
|
||||||
is_pyctcdecode_available,
|
is_pyctcdecode_available,
|
||||||
is_torchaudio_available,
|
is_torchaudio_available,
|
||||||
require_datasets,
|
|
||||||
require_pyctcdecode,
|
require_pyctcdecode,
|
||||||
require_soundfile,
|
require_soundfile,
|
||||||
require_torch,
|
require_torch,
|
||||||
@@ -1060,7 +1059,6 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_datasets
|
|
||||||
@require_soundfile
|
@require_soundfile
|
||||||
@slow
|
@slow
|
||||||
class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ from transformers.pipelines import AudioClassificationPipeline, pipeline
|
|||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
is_pipeline_test,
|
is_pipeline_test,
|
||||||
nested_simplify,
|
nested_simplify,
|
||||||
require_datasets,
|
|
||||||
require_tf,
|
require_tf,
|
||||||
require_torch,
|
require_torch,
|
||||||
require_torchaudio,
|
require_torchaudio,
|
||||||
@@ -65,7 +64,6 @@ class AudioClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
|
|||||||
|
|
||||||
self.run_torchaudio(audio_classifier)
|
self.run_torchaudio(audio_classifier)
|
||||||
|
|
||||||
@require_datasets
|
|
||||||
@require_torchaudio
|
@require_torchaudio
|
||||||
def run_torchaudio(self, audio_classifier):
|
def run_torchaudio(self, audio_classifier):
|
||||||
import datasets
|
import datasets
|
||||||
@@ -101,7 +99,6 @@ class AudioClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
|
|||||||
)
|
)
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_datasets
|
|
||||||
@slow
|
@slow
|
||||||
def test_large_model_pt(self):
|
def test_large_model_pt(self):
|
||||||
import datasets
|
import datasets
|
||||||
|
|||||||
@@ -26,14 +26,7 @@ from transformers import (
|
|||||||
Wav2Vec2ForCTC,
|
Wav2Vec2ForCTC,
|
||||||
)
|
)
|
||||||
from transformers.pipelines import AutomaticSpeechRecognitionPipeline, pipeline
|
from transformers.pipelines import AutomaticSpeechRecognitionPipeline, pipeline
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import is_pipeline_test, require_tf, require_torch, require_torchaudio, slow
|
||||||
is_pipeline_test,
|
|
||||||
require_datasets,
|
|
||||||
require_tf,
|
|
||||||
require_torch,
|
|
||||||
require_torchaudio,
|
|
||||||
slow,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .test_pipelines_common import ANY, PipelineTestCaseMeta
|
from .test_pipelines_common import ANY, PipelineTestCaseMeta
|
||||||
|
|
||||||
@@ -105,7 +98,6 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
|||||||
framework="pt",
|
framework="pt",
|
||||||
)
|
)
|
||||||
|
|
||||||
@require_datasets
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@slow
|
@slow
|
||||||
def test_torch_large(self):
|
def test_torch_large(self):
|
||||||
@@ -128,7 +120,6 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
|||||||
output = speech_recognizer(filename)
|
output = speech_recognizer(filename)
|
||||||
self.assertEqual(output, {"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST"})
|
self.assertEqual(output, {"text": "A MAN SAID TO THE UNIVERSE SIR I EXIST"})
|
||||||
|
|
||||||
@require_datasets
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@slow
|
@slow
|
||||||
def test_torch_speech_encoder_decoder(self):
|
def test_torch_speech_encoder_decoder(self):
|
||||||
@@ -148,7 +139,6 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
|||||||
|
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_datasets
|
|
||||||
def test_simple_wav2vec2(self):
|
def test_simple_wav2vec2(self):
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
@@ -177,7 +167,6 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
|||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_torchaudio
|
@require_torchaudio
|
||||||
@require_datasets
|
|
||||||
def test_simple_s2t(self):
|
def test_simple_s2t(self):
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
@@ -207,7 +196,6 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
|||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_torchaudio
|
@require_torchaudio
|
||||||
@require_datasets
|
|
||||||
def test_xls_r_to_en(self):
|
def test_xls_r_to_en(self):
|
||||||
speech_recognizer = pipeline(
|
speech_recognizer = pipeline(
|
||||||
task="automatic-speech-recognition",
|
task="automatic-speech-recognition",
|
||||||
@@ -226,7 +214,6 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
|||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_torchaudio
|
@require_torchaudio
|
||||||
@require_datasets
|
|
||||||
def test_xls_r_from_en(self):
|
def test_xls_r_from_en(self):
|
||||||
speech_recognizer = pipeline(
|
speech_recognizer = pipeline(
|
||||||
task="automatic-speech-recognition",
|
task="automatic-speech-recognition",
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ from transformers.pipelines import ImageClassificationPipeline, pipeline
|
|||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
is_pipeline_test,
|
is_pipeline_test,
|
||||||
nested_simplify,
|
nested_simplify,
|
||||||
require_datasets,
|
|
||||||
require_tf,
|
require_tf,
|
||||||
require_torch,
|
require_torch,
|
||||||
require_vision,
|
require_vision,
|
||||||
@@ -53,7 +52,6 @@ class ImageClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
|
|||||||
]
|
]
|
||||||
return image_classifier, examples
|
return image_classifier, examples
|
||||||
|
|
||||||
@require_datasets
|
|
||||||
def run_pipeline_test(self, image_classifier, examples):
|
def run_pipeline_test(self, image_classifier, examples):
|
||||||
outputs = image_classifier("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
outputs = image_classifier("./tests/fixtures/tests_samples/COCO/000000039769.png")
|
||||||
|
|
||||||
|
|||||||
@@ -26,7 +26,6 @@ from transformers import (
|
|||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
is_pipeline_test,
|
is_pipeline_test,
|
||||||
nested_simplify,
|
nested_simplify,
|
||||||
require_datasets,
|
|
||||||
require_tf,
|
require_tf,
|
||||||
require_timm,
|
require_timm,
|
||||||
require_torch,
|
require_torch,
|
||||||
@@ -61,7 +60,6 @@ class ImageSegmentationPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
|
|||||||
"./tests/fixtures/tests_samples/COCO/000000039769.png",
|
"./tests/fixtures/tests_samples/COCO/000000039769.png",
|
||||||
]
|
]
|
||||||
|
|
||||||
@require_datasets
|
|
||||||
def run_pipeline_test(self, image_segmenter, examples):
|
def run_pipeline_test(self, image_segmenter, examples):
|
||||||
outputs = image_segmenter("./tests/fixtures/tests_samples/COCO/000000039769.png", threshold=0.0)
|
outputs = image_segmenter("./tests/fixtures/tests_samples/COCO/000000039769.png", threshold=0.0)
|
||||||
self.assertEqual(outputs, [{"score": ANY(float), "label": ANY(str), "mask": ANY(str)}] * 12)
|
self.assertEqual(outputs, [{"score": ANY(float), "label": ANY(str), "mask": ANY(str)}] * 12)
|
||||||
|
|||||||
@@ -25,7 +25,6 @@ from transformers import (
|
|||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
is_pipeline_test,
|
is_pipeline_test,
|
||||||
nested_simplify,
|
nested_simplify,
|
||||||
require_datasets,
|
|
||||||
require_tf,
|
require_tf,
|
||||||
require_timm,
|
require_timm,
|
||||||
require_torch,
|
require_torch,
|
||||||
@@ -57,7 +56,6 @@ class ObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCase
|
|||||||
object_detector = ObjectDetectionPipeline(model=model, feature_extractor=feature_extractor)
|
object_detector = ObjectDetectionPipeline(model=model, feature_extractor=feature_extractor)
|
||||||
return object_detector, ["./tests/fixtures/tests_samples/COCO/000000039769.png"]
|
return object_detector, ["./tests/fixtures/tests_samples/COCO/000000039769.png"]
|
||||||
|
|
||||||
@require_datasets
|
|
||||||
def run_pipeline_test(self, object_detector, examples):
|
def run_pipeline_test(self, object_detector, examples):
|
||||||
outputs = object_detector("./tests/fixtures/tests_samples/COCO/000000039769.png", threshold=0.0)
|
outputs = object_detector("./tests/fixtures/tests_samples/COCO/000000039769.png", threshold=0.0)
|
||||||
|
|
||||||
|
|||||||
@@ -32,13 +32,7 @@ from transformers.models.dpr.tokenization_dpr import DPRContextEncoderTokenizer,
|
|||||||
from transformers.models.rag.configuration_rag import RagConfig
|
from transformers.models.rag.configuration_rag import RagConfig
|
||||||
from transformers.models.rag.retrieval_rag import CustomHFIndex, RagRetriever
|
from transformers.models.rag.retrieval_rag import CustomHFIndex, RagRetriever
|
||||||
from transformers.models.roberta.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
|
from transformers.models.roberta.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import require_faiss, require_sentencepiece, require_tokenizers, require_torch
|
||||||
require_datasets,
|
|
||||||
require_faiss,
|
|
||||||
require_sentencepiece,
|
|
||||||
require_tokenizers,
|
|
||||||
require_torch,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if is_faiss_available():
|
if is_faiss_available():
|
||||||
@@ -46,7 +40,6 @@ if is_faiss_available():
|
|||||||
|
|
||||||
|
|
||||||
@require_faiss
|
@require_faiss
|
||||||
@require_datasets
|
|
||||||
class RagRetrieverTest(TestCase):
|
class RagRetrieverTest(TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.tmpdirname = tempfile.mkdtemp()
|
self.tmpdirname = tempfile.mkdtemp()
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from transformers.models.bart.configuration_bart import BartConfig
|
|||||||
from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
|
from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
|
||||||
from transformers.models.dpr.configuration_dpr import DPRConfig
|
from transformers.models.dpr.configuration_dpr import DPRConfig
|
||||||
from transformers.models.roberta.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
|
from transformers.models.roberta.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
|
||||||
from transformers.testing_utils import require_datasets, require_faiss, require_tokenizers, require_torch, slow
|
from transformers.testing_utils import require_faiss, require_tokenizers, require_torch, slow
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available() and is_datasets_available() and is_faiss_available():
|
if is_torch_available() and is_datasets_available() and is_faiss_available():
|
||||||
@@ -33,7 +33,6 @@ if is_torch_available() and is_datasets_available() and is_faiss_available():
|
|||||||
|
|
||||||
|
|
||||||
@require_faiss
|
@require_faiss
|
||||||
@require_datasets
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class RagTokenizerTest(TestCase):
|
class RagTokenizerTest(TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
|
|||||||
@@ -46,7 +46,6 @@ from transformers.testing_utils import (
|
|||||||
get_gpu_count,
|
get_gpu_count,
|
||||||
get_tests_dir,
|
get_tests_dir,
|
||||||
is_staging_test,
|
is_staging_test,
|
||||||
require_datasets,
|
|
||||||
require_optuna,
|
require_optuna,
|
||||||
require_ray,
|
require_ray,
|
||||||
require_sentencepiece,
|
require_sentencepiece,
|
||||||
@@ -391,7 +390,6 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
trainer.train()
|
trainer.train()
|
||||||
self.check_trained_model(trainer.model, alternate_seed=True)
|
self.check_trained_model(trainer.model, alternate_seed=True)
|
||||||
|
|
||||||
@require_datasets
|
|
||||||
def test_trainer_with_datasets(self):
|
def test_trainer_with_datasets(self):
|
||||||
import datasets
|
import datasets
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,7 @@
|
|||||||
|
|
||||||
from transformers import BertTokenizer, EncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments
|
from transformers import BertTokenizer, EncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments
|
||||||
from transformers.file_utils import is_datasets_available
|
from transformers.file_utils import is_datasets_available
|
||||||
from transformers.testing_utils import TestCasePlus, require_datasets, require_torch, slow
|
from transformers.testing_utils import TestCasePlus, require_torch, slow
|
||||||
|
|
||||||
|
|
||||||
if is_datasets_available():
|
if is_datasets_available():
|
||||||
@@ -25,7 +25,6 @@ if is_datasets_available():
|
|||||||
class Seq2seqTrainerTester(TestCasePlus):
|
class Seq2seqTrainerTester(TestCasePlus):
|
||||||
@slow
|
@slow
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_datasets
|
|
||||||
def test_finetune_bert2bert(self):
|
def test_finetune_bert2bert(self):
|
||||||
bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("prajjwal1/bert-tiny", "prajjwal1/bert-tiny")
|
bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("prajjwal1/bert-tiny", "prajjwal1/bert-tiny")
|
||||||
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
||||||
|
|||||||
Reference in New Issue
Block a user