[examples tests on multigpu] resolving require_torch_non_multi_gpu_but_fix_me (#10561)

* batch 1

* this is tpu

* deebert attempt

* the rest
This commit is contained in:
Stas Bekman
2021-03-08 11:11:40 -08:00
committed by GitHub
parent dfd16af832
commit f284089ec4
9 changed files with 35 additions and 62 deletions

View File

@@ -17,7 +17,7 @@ from transformers.integrations import is_ray_available
from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
from transformers.models.rag.retrieval_rag import CustomHFIndex, RagRetriever
from transformers.models.roberta.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
from transformers.testing_utils import require_ray, require_torch_non_multi_gpu_but_fix_me
from transformers.testing_utils import require_ray
sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # noqa: E402 # isort:skip
@@ -265,7 +265,6 @@ class RagRetrieverTest(TestCase):
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
@require_torch_non_multi_gpu_but_fix_me
def test_pytorch_distributed_retriever_retrieve(self):
n_docs = 1
hidden_states = np.array(
@@ -276,7 +275,6 @@ class RagRetrieverTest(TestCase):
self.get_dummy_pytorch_distributed_retriever(init_retrieval=True), hidden_states, n_docs
)
@require_torch_non_multi_gpu_but_fix_me
def test_custom_hf_index_pytorch_retriever_retrieve(self):
n_docs = 1
hidden_states = np.array(
@@ -289,7 +287,6 @@ class RagRetrieverTest(TestCase):
n_docs,
)
@require_torch_non_multi_gpu_but_fix_me
def test_custom_pytorch_distributed_retriever_retrieve_from_disk(self):
n_docs = 1
hidden_states = np.array(