Fix (non-slow) tests on GPU (torch) (#3024)

* Fix tests on GPU (torch)

* Fix bart slow tests

Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
This commit is contained in:
Julien Chaumond
2020-02-26 11:59:25 -05:00
committed by GitHub
parent 9df74b8bc4
commit 9cda3620b6
4 changed files with 26 additions and 13 deletions

View File

@@ -86,7 +86,7 @@ def _prepare_bart_decoder_inputs(
causal_lm_mask = None
new_shape = (bsz, tgt_len, tgt_len)
# make it broadcastable so can just be added to the attention coefficients
decoder_attn_mask = _combine_masks(decoder_padding_mask, causal_lm_mask, new_shape)
decoder_attn_mask = _combine_masks(decoder_padding_mask, causal_lm_mask, new_shape).to(device=input_ids.device)
assert decoder_attn_mask is None or decoder_attn_mask.shape == (bsz, 1, tgt_len, tgt_len)
return decoder_input_ids, decoder_attn_mask