[tests] fix slow bart cnn test, faster marian tests (#7888)
This commit is contained in:
@@ -16,7 +16,7 @@
|
||||
|
||||
import unittest
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers import AutoConfig, AutoTokenizer, MarianConfig, MarianTokenizer, is_torch_available
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.hf_api import HfApi
|
||||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
|
||||
@@ -25,14 +25,7 @@ from transformers.testing_utils import require_sentencepiece, require_tokenizers
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModelWithLMHead,
|
||||
AutoTokenizer,
|
||||
MarianConfig,
|
||||
MarianMTModel,
|
||||
MarianTokenizer,
|
||||
)
|
||||
from transformers import AutoModelWithLMHead, MarianMTModel
|
||||
from transformers.convert_marian_to_pytorch import (
|
||||
ORG_NAME,
|
||||
convert_hf_name_to_opus_name,
|
||||
@@ -79,10 +72,16 @@ class MarianIntegrationTest(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls) -> None:
|
||||
cls.model_name = f"Helsinki-NLP/opus-mt-{cls.src}-{cls.tgt}"
|
||||
cls.tokenizer: MarianTokenizer = AutoTokenizer.from_pretrained(cls.model_name)
|
||||
cls.eos_token_id = cls.tokenizer.eos_token_id
|
||||
return cls
|
||||
|
||||
@cached_property
|
||||
def tokenizer(self) -> MarianTokenizer:
|
||||
return AutoTokenizer.from_pretrained(self.model_name)
|
||||
|
||||
@property
|
||||
def eos_token_id(self) -> int:
|
||||
return self.tokenizer.eos_token_id
|
||||
|
||||
@cached_property
|
||||
def model(self):
|
||||
model: MarianMTModel = AutoModelWithLMHead.from_pretrained(self.model_name).to(torch_device)
|
||||
|
||||
Reference in New Issue
Block a user