From 12c1b5b8f448d652f5e1fa0f069b9569f4540948 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 19 Jan 2021 09:06:24 +0100 Subject: [PATCH] fix test (#9669) --- tests/test_modeling_bart.py | 4 ++-- tests/test_modeling_blenderbot.py | 4 ++-- tests/test_modeling_blenderbot_small.py | 4 ++-- tests/test_modeling_common.py | 2 +- tests/test_modeling_marian.py | 4 ++-- tests/test_modeling_mbart.py | 4 ++-- tests/test_modeling_pegasus.py | 4 ++-- 7 files changed, 13 insertions(+), 13 deletions(-) diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py index 7bf7bde03b..363122147d 100644 --- a/tests/test_modeling_bart.py +++ b/tests/test_modeling_bart.py @@ -61,9 +61,9 @@ def prepare_bart_inputs_dict( if decoder_attention_mask is None: decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id) if head_mask is None: - head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads) + head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device) if decoder_head_mask is None: - decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads) + decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device) return { "input_ids": input_ids, "decoder_input_ids": decoder_input_ids, diff --git a/tests/test_modeling_blenderbot.py b/tests/test_modeling_blenderbot.py index 826d43afff..b75e147d9a 100644 --- a/tests/test_modeling_blenderbot.py +++ b/tests/test_modeling_blenderbot.py @@ -48,9 +48,9 @@ def prepare_blenderbot_inputs_dict( if decoder_attention_mask is None: decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id) if head_mask is None: - head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads) + head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device) if decoder_head_mask is None: - decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads) + decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device) return { "input_ids": input_ids, "decoder_input_ids": decoder_input_ids, diff --git a/tests/test_modeling_blenderbot_small.py b/tests/test_modeling_blenderbot_small.py index abb223413a..f5cf5b1945 100644 --- a/tests/test_modeling_blenderbot_small.py +++ b/tests/test_modeling_blenderbot_small.py @@ -56,9 +56,9 @@ def prepare_blenderbot_small_inputs_dict( if decoder_attention_mask is None: decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id) if head_mask is None: - head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads) + head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device) if decoder_head_mask is None: - decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads) + decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device) return { "input_ids": input_ids, "decoder_input_ids": decoder_input_ids, diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index d14cab8e69..4a0cf5e1c5 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -1073,7 +1073,7 @@ class ModelTesterMixin: # some params shouldn't be scattered by nn.DataParallel # so just remove them if they are present. - blacklist_non_batched_params = ["head_mask"] + blacklist_non_batched_params = ["head_mask", "decoder_head_mask"] for k in blacklist_non_batched_params: inputs_dict.pop(k, None) diff --git a/tests/test_modeling_marian.py b/tests/test_modeling_marian.py index 02313b20dc..e4892f2bb0 100644 --- a/tests/test_modeling_marian.py +++ b/tests/test_modeling_marian.py @@ -62,9 +62,9 @@ def prepare_marian_inputs_dict( if decoder_attention_mask is None: decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id) if head_mask is None: - head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads) + head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device) if decoder_head_mask is None: - decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads) + decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device) return { "input_ids": input_ids, "decoder_input_ids": decoder_input_ids, diff --git a/tests/test_modeling_mbart.py b/tests/test_modeling_mbart.py index 7629e2beeb..44ec71e168 100644 --- a/tests/test_modeling_mbart.py +++ b/tests/test_modeling_mbart.py @@ -57,9 +57,9 @@ def prepare_mbart_inputs_dict( if decoder_attention_mask is None: decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id) if head_mask is None: - head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads) + head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device) if decoder_head_mask is None: - decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads) + decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device) return { "input_ids": input_ids, "decoder_input_ids": decoder_input_ids, diff --git a/tests/test_modeling_pegasus.py b/tests/test_modeling_pegasus.py index 823cb00453..1b5f5632fe 100644 --- a/tests/test_modeling_pegasus.py +++ b/tests/test_modeling_pegasus.py @@ -49,9 +49,9 @@ def prepare_pegasus_inputs_dict( if decoder_attention_mask is None: decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id) if head_mask is None: - head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads) + head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device) if decoder_head_mask is None: - decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads) + decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device) return { "input_ids": input_ids, "decoder_input_ids": decoder_input_ids,