Fix Splinter test (#17854)
* fix Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -18,7 +18,7 @@ import copy
|
|||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
from transformers.testing_utils import require_torch, slow, torch_device
|
from transformers.testing_utils import require_torch, require_torch_multi_gpu, slow, torch_device
|
||||||
|
|
||||||
from ...test_configuration_common import ConfigTester
|
from ...test_configuration_common import ConfigTester
|
||||||
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
|
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
|
||||||
@@ -316,6 +316,42 @@ class SplinterModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
model = SplinterModel.from_pretrained(model_name)
|
model = SplinterModel.from_pretrained(model_name)
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
# overwrite from common since `SplinterForPreTraining` could contain different number of question tokens in inputs.
|
||||||
|
# When the batch is distributed to multiple devices, each replica could get different values for the maximal number
|
||||||
|
# of question tokens (see `SplinterForPreTraining._prepare_question_positions()`), and the model returns different
|
||||||
|
# shape along dimension 1 (i.e. `num_questions`) that could not be combined into a single tensor as an output.
|
||||||
|
@require_torch_multi_gpu
|
||||||
|
def test_multi_gpu_data_parallel_forward(self):
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
# some params shouldn't be scattered by nn.DataParallel
|
||||||
|
# so just remove them if they are present.
|
||||||
|
blacklist_non_batched_params = ["head_mask", "decoder_head_mask", "cross_attn_head_mask"]
|
||||||
|
for k in blacklist_non_batched_params:
|
||||||
|
inputs_dict.pop(k, None)
|
||||||
|
|
||||||
|
# move input tensors to cuda:O
|
||||||
|
for k, v in inputs_dict.items():
|
||||||
|
if torch.is_tensor(v):
|
||||||
|
inputs_dict[k] = v.to(0)
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
|
||||||
|
# Skip this case since it will fail sometimes, as described above.
|
||||||
|
if model_class == SplinterForPreTraining:
|
||||||
|
continue
|
||||||
|
|
||||||
|
model = model_class(config=config)
|
||||||
|
model.to(0)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# Wrap model in nn.DataParallel
|
||||||
|
model = nn.DataParallel(model)
|
||||||
|
with torch.no_grad():
|
||||||
|
_ = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class SplinterModelIntegrationTest(unittest.TestCase):
|
class SplinterModelIntegrationTest(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user