New BartModel (#2745)
* Results same as fairseq * Wrote a ton of tests * Struggled with api signatures * added some docs
This commit is contained in:
@@ -142,10 +142,17 @@ class ModelTesterMixin:
|
||||
out_len = len(outputs)
|
||||
|
||||
if self.is_encoder_decoder:
|
||||
self.assertEqual(out_len % 2, 0)
|
||||
decoder_attentions = outputs[(out_len // 2) - 1]
|
||||
self.assertEqual(model.config.output_attentions, True)
|
||||
self.assertEqual(model.config.output_hidden_states, False)
|
||||
correct_outlen = (
|
||||
4 # decoder_features_or_logits, decoder_attentions, encoder_features, encoder_attentions
|
||||
)
|
||||
decoder_attention_idx = 1
|
||||
if "lm_labels" in inputs_dict or "decoder_lm_labels" in inputs_dict: # loss will come first
|
||||
correct_outlen += 1 # compute loss
|
||||
decoder_attention_idx += 1
|
||||
self.assertEqual(out_len, correct_outlen)
|
||||
|
||||
decoder_attentions = outputs[decoder_attention_idx]
|
||||
self.assertIsInstance(decoder_attentions, (list, tuple))
|
||||
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
|
||||
self.assertListEqual(
|
||||
list(decoder_attentions[0].shape[-3:]),
|
||||
@@ -562,15 +569,16 @@ class ModelTesterMixin:
|
||||
# self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))
|
||||
|
||||
def test_inputs_embeds(self):
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
if not self.is_encoder_decoder:
|
||||
input_ids = inputs_dict["input_ids"]
|
||||
del inputs_dict["input_ids"]
|
||||
else:
|
||||
encoder_input_ids = inputs_dict["encoder_input_ids"]
|
||||
decoder_input_ids = inputs_dict["decoder_input_ids"]
|
||||
decoder_input_ids = inputs_dict.get("decoder_input_ids", encoder_input_ids)
|
||||
del inputs_dict["encoder_input_ids"]
|
||||
del inputs_dict["decoder_input_ids"]
|
||||
inputs_dict.pop("decoder_input_ids", None)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
Reference in New Issue
Block a user