This commit is contained in:
Patrick von Platen
2021-01-19 09:06:24 +01:00
committed by GitHub
parent 357fb1c5d8
commit 12c1b5b8f4
7 changed files with 13 additions and 13 deletions

View File

@@ -57,9 +57,9 @@ def prepare_mbart_inputs_dict(
if decoder_attention_mask is None:
decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id)
if head_mask is None:
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads)
head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device)
if decoder_head_mask is None:
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads)
decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device)
return {
"input_ids": input_ids,
"decoder_input_ids": decoder_input_ids,