From 74297d0a55f9854e9a4e635b9c41a11d700d0a79 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Mon, 21 Nov 2022 15:36:49 +0100 Subject: [PATCH] [Switch Transformers] Fix failing slow test (#20346) * run slow test on GPU * remove unnecessary device assignment * use `torch_device` instead --- .../test_modeling_switch_transformers.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/models/switch_transformers/test_modeling_switch_transformers.py b/tests/models/switch_transformers/test_modeling_switch_transformers.py index 0b90fc3713..1afeb2e484 100644 --- a/tests/models/switch_transformers/test_modeling_switch_transformers.py +++ b/tests/models/switch_transformers/test_modeling_switch_transformers.py @@ -19,7 +19,7 @@ import tempfile import unittest from transformers import SwitchTransformersConfig, is_torch_available -from transformers.testing_utils import require_tokenizers, require_torch, slow, torch_device +from transformers.testing_utils import require_tokenizers, require_torch, require_torch_gpu, slow, torch_device from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester @@ -1104,15 +1104,18 @@ class SwitchTransformerRouterTest(unittest.TestCase): @require_torch @require_tokenizers class SwitchTransformerModelIntegrationTests(unittest.TestCase): + @require_torch_gpu def test_small_logits(self): r""" Logits testing to check implementation consistency between `t5x` implementation and `transformers` implementation of Switch-C transformers. We only check the logits of the first batch. """ - model = SwitchTransformersModel.from_pretrained("google/switch-base-8", torch_dtype=torch.bfloat16).eval() - input_ids = torch.ones((32, 64), dtype=torch.long) - decoder_input_ids = torch.ones((32, 64), dtype=torch.long) + model = SwitchTransformersModel.from_pretrained("google/switch-base-8", torch_dtype=torch.bfloat16).to( + torch_device + ) + input_ids = torch.ones((32, 64), dtype=torch.long).to(torch_device) + decoder_input_ids = torch.ones((32, 64), dtype=torch.long).to(torch_device) # fmt: off EXPECTED_MEAN_LOGITS = torch.Tensor( @@ -1126,8 +1129,7 @@ class SwitchTransformerModelIntegrationTests(unittest.TestCase): ] ).to(torch.bfloat16) # fmt: on - - hf_logits = model(input_ids, decoder_input_ids=decoder_input_ids).last_hidden_state + hf_logits = model(input_ids, decoder_input_ids=decoder_input_ids).last_hidden_state.cpu() hf_logits = hf_logits[0, 0, :30] torch.testing.assert_allclose(hf_logits, EXPECTED_MEAN_LOGITS, rtol=6e-3, atol=9e-3)