This commit is contained in:
Patrick von Platen
2021-01-19 09:06:24 +01:00
committed by GitHub
parent 357fb1c5d8
commit 12c1b5b8f4
7 changed files with 13 additions and 13 deletions

View File

@@ -61,9 +61,9 @@ def prepare_bart_inputs_dict(
if decoder_attention_mask is None: if decoder_attention_mask is None:
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id) decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
if head_mask is None: if head_mask is None:
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads) head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
if decoder_head_mask is None: if decoder_head_mask is None:
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads) decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
return { return {
"input_ids": input_ids, "input_ids": input_ids,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,

View File

@@ -48,9 +48,9 @@ def prepare_blenderbot_inputs_dict(
if decoder_attention_mask is None: if decoder_attention_mask is None:
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id) decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
if head_mask is None: if head_mask is None:
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads) head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
if decoder_head_mask is None: if decoder_head_mask is None:
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads) decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
return { return {
"input_ids": input_ids, "input_ids": input_ids,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,

View File

@@ -56,9 +56,9 @@ def prepare_blenderbot_small_inputs_dict(
if decoder_attention_mask is None: if decoder_attention_mask is None:
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id) decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
if head_mask is None: if head_mask is None:
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads) head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
if decoder_head_mask is None: if decoder_head_mask is None:
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads) decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
return { return {
"input_ids": input_ids, "input_ids": input_ids,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,

View File

@@ -1073,7 +1073,7 @@ class ModelTesterMixin:
# some params shouldn't be scattered by nn.DataParallel # some params shouldn't be scattered by nn.DataParallel
# so just remove them if they are present. # so just remove them if they are present.
blacklist_non_batched_params = ["head_mask"] blacklist_non_batched_params = ["head_mask", "decoder_head_mask"]
for k in blacklist_non_batched_params: for k in blacklist_non_batched_params:
inputs_dict.pop(k, None) inputs_dict.pop(k, None)

View File

@@ -62,9 +62,9 @@ def prepare_marian_inputs_dict(
if decoder_attention_mask is None: if decoder_attention_mask is None:
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id) decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
if head_mask is None: if head_mask is None:
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads) head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
if decoder_head_mask is None: if decoder_head_mask is None:
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads) decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
return { return {
"input_ids": input_ids, "input_ids": input_ids,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,

View File

@@ -57,9 +57,9 @@ def prepare_mbart_inputs_dict(
if decoder_attention_mask is None: if decoder_attention_mask is None:
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id) decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
if head_mask is None: if head_mask is None:
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads) head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
if decoder_head_mask is None: if decoder_head_mask is None:
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads) decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
return { return {
"input_ids": input_ids, "input_ids": input_ids,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,

View File

@@ -49,9 +49,9 @@ def prepare_pegasus_inputs_dict(
if decoder_attention_mask is None: if decoder_attention_mask is None:
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id) decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
if head_mask is None: if head_mask is None:
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads) head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
if decoder_head_mask is None: if decoder_head_mask is None:
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads) decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
return { return {
"input_ids": input_ids, "input_ids": input_ids,
"decoder_input_ids": decoder_input_ids, "decoder_input_ids": decoder_input_ids,