From 447490015a1abd933bf237e0bf5abfead51ebb22 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Fri, 24 Jun 2022 16:26:14 +0200 Subject: [PATCH] Fix Splinter test (#17854) * fix Co-authored-by: ydshieh --- .../models/splinter/test_modeling_splinter.py | 38 ++++++++++++++++++- 1 file changed, 37 insertions(+), 1 deletion(-) diff --git a/tests/models/splinter/test_modeling_splinter.py b/tests/models/splinter/test_modeling_splinter.py index bc355bd2cd..f064611b6a 100644 --- a/tests/models/splinter/test_modeling_splinter.py +++ b/tests/models/splinter/test_modeling_splinter.py @@ -18,7 +18,7 @@ import copy import unittest from transformers import is_torch_available -from transformers.testing_utils import require_torch, slow, torch_device +from transformers.testing_utils import require_torch, require_torch_multi_gpu, slow, torch_device from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask @@ -316,6 +316,42 @@ class SplinterModelTest(ModelTesterMixin, unittest.TestCase): model = SplinterModel.from_pretrained(model_name) self.assertIsNotNone(model) + # overwrite from common since `SplinterForPreTraining` could contain different number of question tokens in inputs. + # When the batch is distributed to multiple devices, each replica could get different values for the maximal number + # of question tokens (see `SplinterForPreTraining._prepare_question_positions()`), and the model returns different + # shape along dimension 1 (i.e. `num_questions`) that could not be combined into a single tensor as an output. + @require_torch_multi_gpu + def test_multi_gpu_data_parallel_forward(self): + from torch import nn + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + # some params shouldn't be scattered by nn.DataParallel + # so just remove them if they are present. 
+ blacklist_non_batched_params = ["head_mask", "decoder_head_mask", "cross_attn_head_mask"] + for k in blacklist_non_batched_params: + inputs_dict.pop(k, None) + # move input tensors to cuda:0 + for k, v in inputs_dict.items(): + if torch.is_tensor(v): + inputs_dict[k] = v.to(0) + for model_class in self.all_model_classes: + # Skip this case since it will fail sometimes, as described above. + if model_class == SplinterForPreTraining: + continue + model = model_class(config=config) + model.to(0) + model.eval() + # Wrap model in nn.DataParallel + model = nn.DataParallel(model) + with torch.no_grad(): + _ = model(**self._prepare_for_class(inputs_dict, model_class)) + @require_torch class SplinterModelIntegrationTest(unittest.TestCase):