[tests] fix slow bart cnn test, faster marian tests (#7888)
This commit is contained in:
@@ -594,7 +594,9 @@ class BartModelIntegrationTests(unittest.TestCase):
|
|||||||
"Bronx on Friday. If convicted, she faces up to four years in prison.",
|
"Bronx on Friday. If convicted, she faces up to four years in prison.",
|
||||||
]
|
]
|
||||||
|
|
||||||
generated_summaries = [tok.batch_decode(hypotheses_batch.tolist())]
|
generated_summaries = tok.batch_decode(
|
||||||
|
hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True
|
||||||
|
)
|
||||||
assert generated_summaries == EXPECTED
|
assert generated_summaries == EXPECTED
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@
|
|||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import is_torch_available
|
from transformers import AutoConfig, AutoTokenizer, MarianConfig, MarianTokenizer, is_torch_available
|
||||||
from transformers.file_utils import cached_property
|
from transformers.file_utils import cached_property
|
||||||
from transformers.hf_api import HfApi
|
from transformers.hf_api import HfApi
|
||||||
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
|
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
|
||||||
@@ -25,14 +25,7 @@ from transformers.testing_utils import require_sentencepiece, require_tokenizers
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import (
|
from transformers import AutoModelWithLMHead, MarianMTModel
|
||||||
AutoConfig,
|
|
||||||
AutoModelWithLMHead,
|
|
||||||
AutoTokenizer,
|
|
||||||
MarianConfig,
|
|
||||||
MarianMTModel,
|
|
||||||
MarianTokenizer,
|
|
||||||
)
|
|
||||||
from transformers.convert_marian_to_pytorch import (
|
from transformers.convert_marian_to_pytorch import (
|
||||||
ORG_NAME,
|
ORG_NAME,
|
||||||
convert_hf_name_to_opus_name,
|
convert_hf_name_to_opus_name,
|
||||||
@@ -79,10 +72,16 @@ class MarianIntegrationTest(unittest.TestCase):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls) -> None:
|
def setUpClass(cls) -> None:
|
||||||
cls.model_name = f"Helsinki-NLP/opus-mt-{cls.src}-{cls.tgt}"
|
cls.model_name = f"Helsinki-NLP/opus-mt-{cls.src}-{cls.tgt}"
|
||||||
cls.tokenizer: MarianTokenizer = AutoTokenizer.from_pretrained(cls.model_name)
|
|
||||||
cls.eos_token_id = cls.tokenizer.eos_token_id
|
|
||||||
return cls
|
return cls
|
||||||
|
|
||||||
|
@cached_property
|
||||||
|
def tokenizer(self) -> MarianTokenizer:
|
||||||
|
return AutoTokenizer.from_pretrained(self.model_name)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def eos_token_id(self) -> int:
|
||||||
|
return self.tokenizer.eos_token_id
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
def model(self):
|
def model(self):
|
||||||
model: MarianMTModel = AutoModelWithLMHead.from_pretrained(self.model_name).to(torch_device)
|
model: MarianMTModel = AutoModelWithLMHead.from_pretrained(self.model_name).to(torch_device)
|
||||||
|
|||||||
Reference in New Issue
Block a user