[BART] FP16 testing fixes (#3266)

This commit is contained in:
Sam Shleifer
2020-03-13 19:48:26 -04:00
committed by GitHub
parent 8320feec09
commit 2bd79e23de
2 changed files with 16 additions and 4 deletions

View File

@@ -82,7 +82,7 @@ LARGE_NEGATIVE = -1e8
 def _prepare_bart_decoder_inputs(
-config, input_ids, decoder_input_ids=None, decoder_attn_mask=None,
+config, input_ids, decoder_input_ids=None, decoder_attn_mask=None, mask_dtype=None,
 ):
 """Prepare masks that ignore padding tokens decoder and a causal lm mask for the decoder if
 none are provided. This mimics the default behavior in fairseq. To override it pass in masks.
@@ -101,6 +101,8 @@ def _prepare_bart_decoder_inputs(
 new_shape = (bsz, tgt_len, tgt_len)
 # make it broadcastable so can just be added to the attention coefficients
 decoder_attn_mask = _combine_masks(decoder_padding_mask, causal_lm_mask, new_shape).to(device=input_ids.device)
+if mask_dtype is not None:
+decoder_attn_mask = decoder_attn_mask.to(mask_dtype)
 assert decoder_attn_mask is None or decoder_attn_mask.shape == (bsz, 1, tgt_len, tgt_len)
 return decoder_input_ids, decoder_attn_mask
@@ -838,7 +840,11 @@ class BartModel(PretrainedBartModel):
 # make masks if user doesn't supply
 if not self.decoder.generation_mode:
 decoder_input_ids, decoder_attention_mask = _prepare_bart_decoder_inputs(
-self.config, input_ids, decoder_input_ids=decoder_input_ids, decoder_attn_mask=decoder_attention_mask,
+self.config,
+input_ids,
+decoder_input_ids=decoder_input_ids,
+decoder_attn_mask=decoder_attention_mask,
+mask_dtype=self.shared.weight.dtype,
 )
 assert decoder_input_ids is not None
 if encoder_outputs is None:

View File

@@ -314,10 +314,16 @@ class BartHeadTests(unittest.TestCase):
 @unittest.skipIf(torch_device == "cpu", "Cant do half precision")
 def test_generate_fp16(self):
 config, input_ids, batch_size = self._get_config_and_data(output_past=True)
-input_ids = input_ids
+attention_mask = input_ids.ne(1).to(torch_device)
+model = BartForConditionalGeneration(config).eval().to(torch_device).half()
+model.generate(input_ids, attention_mask=attention_mask, do_sample=False, early_stopping=True)
+@unittest.skipIf(torch_device == "cpu", "Cant do half precision")
+def test_base_model_fp16(self):
+config, input_ids, batch_size = self._get_config_and_data(output_past=False)
 attention_mask = input_ids.ne(1).to(torch_device)
 lm_model = BartForConditionalGeneration(config).eval().to(torch_device).half()
-lm_model.generate(input_ids, attention_mask=attention_mask)
+lm_model(input_ids, attention_mask=attention_mask)
 def test_prepare_bart_decoder_inputs(self):
 config, *_ = self._get_config_and_data(output_past=False)