[examples tests on multigpu] resolving require_torch_non_multi_gpu_but_fix_me (#10561)
* batch 1 * this is tpu * deebert attempt * the rest
This commit is contained in:
@@ -17,7 +17,7 @@ from transformers.integrations import is_ray_available
|
||||
from transformers.models.bert.tokenization_bert import VOCAB_FILES_NAMES as DPR_VOCAB_FILES_NAMES
|
||||
from transformers.models.rag.retrieval_rag import CustomHFIndex, RagRetriever
|
||||
from transformers.models.roberta.tokenization_roberta import VOCAB_FILES_NAMES as BART_VOCAB_FILES_NAMES
|
||||
from transformers.testing_utils import require_ray, require_torch_non_multi_gpu_but_fix_me
|
||||
from transformers.testing_utils import require_ray
|
||||
|
||||
|
||||
sys.path.append(os.path.join(os.getcwd())) # noqa: E402 # noqa: E402 # isort:skip
|
||||
@@ -265,7 +265,6 @@ class RagRetrieverTest(TestCase):
|
||||
self.assertEqual(doc_dicts[1]["id"][0], "0") # max inner product is reached with first doc
|
||||
self.assertListEqual(doc_ids.tolist(), [[1], [0]])
|
||||
|
||||
@require_torch_non_multi_gpu_but_fix_me
|
||||
def test_pytorch_distributed_retriever_retrieve(self):
|
||||
n_docs = 1
|
||||
hidden_states = np.array(
|
||||
@@ -276,7 +275,6 @@ class RagRetrieverTest(TestCase):
|
||||
self.get_dummy_pytorch_distributed_retriever(init_retrieval=True), hidden_states, n_docs
|
||||
)
|
||||
|
||||
@require_torch_non_multi_gpu_but_fix_me
|
||||
def test_custom_hf_index_pytorch_retriever_retrieve(self):
|
||||
n_docs = 1
|
||||
hidden_states = np.array(
|
||||
@@ -289,7 +287,6 @@ class RagRetrieverTest(TestCase):
|
||||
n_docs,
|
||||
)
|
||||
|
||||
@require_torch_non_multi_gpu_but_fix_me
|
||||
def test_custom_pytorch_distributed_retriever_retrieve_from_disk(self):
|
||||
n_docs = 1
|
||||
hidden_states = np.array(
|
||||
|
||||
Reference in New Issue
Block a user