fix test (#9669)
This commit is contained in:
committed by
GitHub
parent
357fb1c5d8
commit
12c1b5b8f4
@@ -61,9 +61,9 @@ def prepare_bart_inputs_dict(
|
|||||||
if decoder_attention_mask is None:
|
if decoder_attention_mask is None:
|
||||||
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
||||||
if head_mask is None:
|
if head_mask is None:
|
||||||
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads)
|
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
|
||||||
if decoder_head_mask is None:
|
if decoder_head_mask is None:
|
||||||
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads)
|
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||||
return {
|
return {
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
|
|||||||
@@ -48,9 +48,9 @@ def prepare_blenderbot_inputs_dict(
|
|||||||
if decoder_attention_mask is None:
|
if decoder_attention_mask is None:
|
||||||
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
||||||
if head_mask is None:
|
if head_mask is None:
|
||||||
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads)
|
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
|
||||||
if decoder_head_mask is None:
|
if decoder_head_mask is None:
|
||||||
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads)
|
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||||
return {
|
return {
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
|
|||||||
@@ -56,9 +56,9 @@ def prepare_blenderbot_small_inputs_dict(
|
|||||||
if decoder_attention_mask is None:
|
if decoder_attention_mask is None:
|
||||||
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
||||||
if head_mask is None:
|
if head_mask is None:
|
||||||
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads)
|
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
|
||||||
if decoder_head_mask is None:
|
if decoder_head_mask is None:
|
||||||
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads)
|
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||||
return {
|
return {
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
|
|||||||
@@ -1073,7 +1073,7 @@ class ModelTesterMixin:
|
|||||||
|
|
||||||
# some params shouldn't be scattered by nn.DataParallel
|
# some params shouldn't be scattered by nn.DataParallel
|
||||||
# so just remove them if they are present.
|
# so just remove them if they are present.
|
||||||
blacklist_non_batched_params = ["head_mask"]
|
blacklist_non_batched_params = ["head_mask", "decoder_head_mask"]
|
||||||
for k in blacklist_non_batched_params:
|
for k in blacklist_non_batched_params:
|
||||||
inputs_dict.pop(k, None)
|
inputs_dict.pop(k, None)
|
||||||
|
|
||||||
|
|||||||
@@ -62,9 +62,9 @@ def prepare_marian_inputs_dict(
|
|||||||
if decoder_attention_mask is None:
|
if decoder_attention_mask is None:
|
||||||
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
||||||
if head_mask is None:
|
if head_mask is None:
|
||||||
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads)
|
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
|
||||||
if decoder_head_mask is None:
|
if decoder_head_mask is None:
|
||||||
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads)
|
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||||
return {
|
return {
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
|
|||||||
@@ -57,9 +57,9 @@ def prepare_mbart_inputs_dict(
|
|||||||
if decoder_attention_mask is None:
|
if decoder_attention_mask is None:
|
||||||
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
||||||
if head_mask is None:
|
if head_mask is None:
|
||||||
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads)
|
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
|
||||||
if decoder_head_mask is None:
|
if decoder_head_mask is None:
|
||||||
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads)
|
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||||
return {
|
return {
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
|
|||||||
@@ -49,9 +49,9 @@ def prepare_pegasus_inputs_dict(
|
|||||||
if decoder_attention_mask is None:
|
if decoder_attention_mask is None:
|
||||||
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
|
||||||
if head_mask is None:
|
if head_mask is None:
|
||||||
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads)
|
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
|
||||||
if decoder_head_mask is None:
|
if decoder_head_mask is None:
|
||||||
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads)
|
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
|
||||||
return {
|
return {
|
||||||
"input_ids": input_ids,
|
"input_ids": input_ids,
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
|
|||||||
Reference in New Issue
Block a user