From b86a71ea381531bedc38aa23ad8e2f6667bc0f41 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Sun, 18 Oct 2020 20:18:08 -0400 Subject: [PATCH] [tests] fix slow bart cnn test, faster marian tests (#7888) --- tests/test_modeling_bart.py | 4 +++- tests/test_modeling_marian.py | 21 ++++++++++----------- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/tests/test_modeling_bart.py b/tests/test_modeling_bart.py index aed3495a6d..e3fb9050c1 100644 --- a/tests/test_modeling_bart.py +++ b/tests/test_modeling_bart.py @@ -594,7 +594,9 @@ class BartModelIntegrationTests(unittest.TestCase): "Bronx on Friday. If convicted, she faces up to four years in prison.", ] - generated_summaries = [tok.batch_decode(hypotheses_batch.tolist())] + generated_summaries = tok.batch_decode( + hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True + ) assert generated_summaries == EXPECTED diff --git a/tests/test_modeling_marian.py b/tests/test_modeling_marian.py index d139c1d487..46f0cab3df 100644 --- a/tests/test_modeling_marian.py +++ b/tests/test_modeling_marian.py @@ -16,7 +16,7 @@ import unittest -from transformers import is_torch_available +from transformers import AutoConfig, AutoTokenizer, MarianConfig, MarianTokenizer, is_torch_available from transformers.file_utils import cached_property from transformers.hf_api import HfApi from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device @@ -25,14 +25,7 @@ from transformers.testing_utils import require_sentencepiece, require_tokenizers if is_torch_available(): import torch - from transformers import ( - AutoConfig, - AutoModelWithLMHead, - AutoTokenizer, - MarianConfig, - MarianMTModel, - MarianTokenizer, - ) + from transformers import AutoModelWithLMHead, MarianMTModel from transformers.convert_marian_to_pytorch import ( ORG_NAME, convert_hf_name_to_opus_name, @@ -79,10 +72,16 @@ class MarianIntegrationTest(unittest.TestCase): @classmethod def setUpClass(cls) -> None: cls.model_name = f"Helsinki-NLP/opus-mt-{cls.src}-{cls.tgt}" - cls.tokenizer: MarianTokenizer = AutoTokenizer.from_pretrained(cls.model_name) - cls.eos_token_id = cls.tokenizer.eos_token_id return cls + @cached_property + def tokenizer(self) -> MarianTokenizer: + return AutoTokenizer.from_pretrained(self.model_name) + + @property + def eos_token_id(self) -> int: + return self.tokenizer.eos_token_id + @cached_property def model(self): model: MarianMTModel = AutoModelWithLMHead.from_pretrained(self.model_name).to(torch_device)