[T5 failing CI] Fix generate test (#11770)

* fix_torch_device_generate_test

* remove @
This commit is contained in:
Patrick von Platen
2021-05-19 10:31:17 +01:00
committed by GitHub
parent 680d181ce8
commit 43891be19b
2 changed files with 15 additions and 8 deletions

View File

@@ -1084,9 +1084,13 @@ class GenerationTesterMixin:
continue
head_masking = {
"head_mask": torch.zeros(config.encoder_layers, config.encoder_attention_heads),
"decoder_head_mask": torch.zeros(config.decoder_layers, config.decoder_attention_heads),
"cross_attn_head_mask": torch.zeros(config.decoder_layers, config.decoder_attention_heads),
"head_mask": torch.zeros(config.encoder_layers, config.encoder_attention_heads, device=torch_device),
"decoder_head_mask": torch.zeros(
config.decoder_layers, config.decoder_attention_heads, device=torch_device
),
"cross_attn_head_mask": torch.zeros(
config.decoder_layers, config.decoder_attention_heads, device=torch_device
),
}
signature = inspect.signature(model.forward)