[Switch Transformers] Fix failing slow test (#20346)
* run slow test on GPU
* remove unnecessary device assignment
* use `torch_device` instead
This commit is contained in:
@@ -19,7 +19,7 @@ import tempfile
|
||||
import unittest
|
||||
|
||||
from transformers import SwitchTransformersConfig, is_torch_available
|
||||
from transformers.testing_utils import require_tokenizers, require_torch, slow, torch_device
|
||||
from transformers.testing_utils import require_tokenizers, require_torch, require_torch_gpu, slow, torch_device
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
@@ -1104,15 +1104,18 @@ class SwitchTransformerRouterTest(unittest.TestCase):
|
||||
@require_torch
|
||||
@require_tokenizers
|
||||
class SwitchTransformerModelIntegrationTests(unittest.TestCase):
|
||||
@require_torch_gpu
|
||||
def test_small_logits(self):
|
||||
r"""
|
||||
Logits testing to check implementation consistency between `t5x` implementation
|
||||
and `transformers` implementation of Switch-C transformers. We only check the logits
|
||||
of the first batch.
|
||||
"""
|
||||
model = SwitchTransformersModel.from_pretrained("google/switch-base-8", torch_dtype=torch.bfloat16).eval()
|
||||
input_ids = torch.ones((32, 64), dtype=torch.long)
|
||||
decoder_input_ids = torch.ones((32, 64), dtype=torch.long)
|
||||
model = SwitchTransformersModel.from_pretrained("google/switch-base-8", torch_dtype=torch.bfloat16).to(
|
||||
torch_device
|
||||
)
|
||||
input_ids = torch.ones((32, 64), dtype=torch.long).to(torch_device)
|
||||
decoder_input_ids = torch.ones((32, 64), dtype=torch.long).to(torch_device)
|
||||
|
||||
# fmt: off
|
||||
EXPECTED_MEAN_LOGITS = torch.Tensor(
|
||||
@@ -1126,8 +1129,7 @@ class SwitchTransformerModelIntegrationTests(unittest.TestCase):
|
||||
]
|
||||
).to(torch.bfloat16)
|
||||
# fmt: on
|
||||
|
||||
hf_logits = model(input_ids, decoder_input_ids=decoder_input_ids).last_hidden_state
|
||||
hf_logits = model(input_ids, decoder_input_ids=decoder_input_ids).last_hidden_state.cpu()
|
||||
hf_logits = hf_logits[0, 0, :30]
|
||||
|
||||
torch.testing.assert_allclose(hf_logits, EXPECTED_MEAN_LOGITS, rtol=6e-3, atol=9e-3)
|
||||
|
||||
Reference in New Issue
Block a user