From 8af1970e45dc86688b70e1f3fa76d6cc5eca94e9 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Mon, 31 Aug 2020 16:10:43 -0400 Subject: [PATCH] Fix marian slow test (#6854) --- tests/test_modeling_marian.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/test_modeling_marian.py b/tests/test_modeling_marian.py index e0b9cc7c2a..a264111457 100644 --- a/tests/test_modeling_marian.py +++ b/tests/test_modeling_marian.py @@ -38,6 +38,7 @@ if is_torch_available(): convert_hf_name_to_opus_name, convert_opus_name_to_hf_name, ) + from transformers.modeling_bart import shift_tokens_right from transformers.pipelines import TranslationPipeline @@ -116,18 +117,21 @@ class TestMarian_EN_DE_More(MarianIntegrationTest): expected_ids = [38, 121, 14, 697, 38848, 0] model_inputs: dict = self.tokenizer.prepare_seq2seq_batch(src, tgt_texts=tgt).to(torch_device) + self.assertListEqual(expected_ids, model_inputs.input_ids[0].tolist()) desired_keys = { "input_ids", "attention_mask", - "decoder_input_ids", - "decoder_attention_mask", + "labels", } self.assertSetEqual(desired_keys, set(model_inputs.keys())) + model_inputs["decoder_input_ids"] = shift_tokens_right(model_inputs.labels, self.tokenizer.pad_token_id) + model_inputs["return_dict"] = True + model_inputs["use_cache"] = False with torch.no_grad(): - logits, *enc_features = self.model(**model_inputs) - max_indices = logits.argmax(-1) + outputs = self.model(**model_inputs) + max_indices = outputs.logits.argmax(-1) self.tokenizer.batch_decode(max_indices) def test_unk_support(self):