Copied from for test files (#26713)

* copied statement for test files

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2023-10-11 14:12:09 +02:00
committed by GitHub
parent 9f40639292
commit 5334796d20
14 changed files with 127 additions and 45 deletions

View File

@@ -39,7 +39,7 @@ if is_torch_available():
)
# Copied from transformers.tests.llama.test_modelling_llama.LlamaModelTest with Llama->Persimmon
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester with Llama->Persimmon
class PersimmonModelTester:
def __init__(
self,
@@ -266,7 +266,6 @@ class PersimmonModelTester:
return config, inputs_dict
# Copied from transformers.tests.llama.test_modelling_llama.LlamaModelTest with Llama->Persimmon
@require_torch
class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (
@@ -288,23 +287,28 @@ class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
test_headmasking = False
test_pruning = False
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.setUp with Llama->Persimmon
def setUp(self):
self.model_tester = PersimmonModelTester(self)
self.config_tester = ConfigTester(self, config_class=PersimmonConfig, hidden_size=37)
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_config
def test_config(self):
self.config_tester.run_common_tests()
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_various_embeddings
def test_model_various_embeddings(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
for type in ["absolute", "relative_key", "relative_key_query"]:
config_and_inputs[0].position_embedding_type = type
self.model_tester.create_and_check_model(*config_and_inputs)
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model with Llama->Persimmon,llama->persimmon
def test_persimmon_sequence_classification_model(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3
@@ -317,6 +321,7 @@ class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model_for_single_label with Llama->Persimmon,llama->persimmon
def test_persimmon_sequence_classification_model_for_single_label(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3
@@ -330,6 +335,7 @@ class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model_for_multi_label with Llama->Persimmon,llama->persimmon
def test_persimmon_sequence_classification_model_for_multi_label(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.num_labels = 3
@@ -346,10 +352,12 @@ class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
@unittest.skip("Persimmon buffers include complex numbers, which breaks this test")
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_save_load_fast_init_from_base
def test_save_load_fast_init_from_base(self):
pass
@parameterized.expand([("linear",), ("dynamic",)])
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling with Llama->Persimmon
def test_model_rope_scaling(self, scaling_type):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
short_input = ids_tensor([1, 10], config.vocab_size)