diff --git a/tests/models/codegen/test_modeling_codegen.py b/tests/models/codegen/test_modeling_codegen.py index 1b4cdca6c3..fdf2d89192 100644 --- a/tests/models/codegen/test_modeling_codegen.py +++ b/tests/models/codegen/test_modeling_codegen.py @@ -18,7 +18,8 @@ import datetime import unittest from transformers import CodeGenConfig, is_torch_available -from transformers.testing_utils import require_torch, slow, torch_device +from transformers.file_utils import cached_property +from transformers.testing_utils import is_flaky, require_torch, slow, torch_device from ...generation.test_utils import GenerationTesterMixin from ...test_configuration_common import ConfigTester @@ -462,11 +463,19 @@ class CodeGenModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi @require_torch class CodeGenModelLanguageGenerationTest(unittest.TestCase): + @cached_property + def cached_tokenizer(self): + return AutoTokenizer.from_pretrained("Salesforce/codegen-350M-mono") + + @cached_property + def cached_model(self): + return CodeGenForCausalLM.from_pretrained("Salesforce/codegen-350M-mono") + @slow def test_lm_generate_codegen(self): - tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-350M-mono") + tokenizer = self.cached_tokenizer for checkpointing in [True, False]: - model = CodeGenForCausalLM.from_pretrained("Salesforce/codegen-350M-mono") + model = self.cached_model if checkpointing: model.gradient_checkpointing_enable() @@ -484,8 +493,8 @@ class CodeGenModelLanguageGenerationTest(unittest.TestCase): @slow def test_codegen_sample(self): - tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-350M-mono") - model = CodeGenForCausalLM.from_pretrained("Salesforce/codegen-350M-mono") + tokenizer = self.cached_tokenizer + model = self.cached_model model.to(torch_device) torch.manual_seed(0) @@ -515,10 +524,11 @@ class CodeGenModelLanguageGenerationTest(unittest.TestCase): all([output_seq_strs[idx] != output_seq_tt_strs[idx] for idx in range(len(output_seq_tt_strs))]) ) # token_type_ids should change output + @is_flaky(max_attempts=3, description="measure of timing is somehow flaky.") @slow def test_codegen_sample_max_time(self): - tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-350M-mono") - model = CodeGenForCausalLM.from_pretrained("Salesforce/codegen-350M-mono") + tokenizer = self.cached_tokenizer + model = self.cached_model model.to(torch_device) torch.manual_seed(0)