[Switch Transformers] Fix failing slow test (#20346)
* run slow test on GPU * remove unnecessary device assignment * use `torch_device` instead
This commit is contained in:
@@ -19,7 +19,7 @@ import tempfile
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import SwitchTransformersConfig, is_torch_available
|
from transformers import SwitchTransformersConfig, is_torch_available
|
||||||
from transformers.testing_utils import require_tokenizers, require_torch, slow, torch_device
|
from transformers.testing_utils import require_tokenizers, require_torch, require_torch_gpu, slow, torch_device
|
||||||
|
|
||||||
from ...generation.test_utils import GenerationTesterMixin
|
from ...generation.test_utils import GenerationTesterMixin
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
@@ -1104,15 +1104,18 @@ class SwitchTransformerRouterTest(unittest.TestCase):
|
|||||||
@require_torch
|
@require_torch
|
||||||
@require_tokenizers
|
@require_tokenizers
|
||||||
class SwitchTransformerModelIntegrationTests(unittest.TestCase):
|
class SwitchTransformerModelIntegrationTests(unittest.TestCase):
|
||||||
|
@require_torch_gpu
|
||||||
def test_small_logits(self):
|
def test_small_logits(self):
|
||||||
r"""
|
r"""
|
||||||
Logits testing to check implementation consistency between `t5x` implementation
|
Logits testing to check implementation consistency between `t5x` implementation
|
||||||
and `transformers` implementation of Switch-C transformers. We only check the logits
|
and `transformers` implementation of Switch-C transformers. We only check the logits
|
||||||
of the first batch.
|
of the first batch.
|
||||||
"""
|
"""
|
||||||
model = SwitchTransformersModel.from_pretrained("google/switch-base-8", torch_dtype=torch.bfloat16).eval()
|
model = SwitchTransformersModel.from_pretrained("google/switch-base-8", torch_dtype=torch.bfloat16).to(
|
||||||
input_ids = torch.ones((32, 64), dtype=torch.long)
|
torch_device
|
||||||
decoder_input_ids = torch.ones((32, 64), dtype=torch.long)
|
)
|
||||||
|
input_ids = torch.ones((32, 64), dtype=torch.long).to(torch_device)
|
||||||
|
decoder_input_ids = torch.ones((32, 64), dtype=torch.long).to(torch_device)
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
EXPECTED_MEAN_LOGITS = torch.Tensor(
|
EXPECTED_MEAN_LOGITS = torch.Tensor(
|
||||||
@@ -1126,8 +1129,7 @@ class SwitchTransformerModelIntegrationTests(unittest.TestCase):
|
|||||||
]
|
]
|
||||||
).to(torch.bfloat16)
|
).to(torch.bfloat16)
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
hf_logits = model(input_ids, decoder_input_ids=decoder_input_ids).last_hidden_state.cpu()
|
||||||
hf_logits = model(input_ids, decoder_input_ids=decoder_input_ids).last_hidden_state
|
|
||||||
hf_logits = hf_logits[0, 0, :30]
|
hf_logits = hf_logits[0, 0, :30]
|
||||||
|
|
||||||
torch.testing.assert_allclose(hf_logits, EXPECTED_MEAN_LOGITS, rtol=6e-3, atol=9e-3)
|
torch.testing.assert_allclose(hf_logits, EXPECTED_MEAN_LOGITS, rtol=6e-3, atol=9e-3)
|
||||||
|
|||||||
Reference in New Issue
Block a user