Copied from for test files (#26713)
* copied statement for test files --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -39,7 +39,7 @@ if is_torch_available():
|
||||
)
|
||||
|
||||
|
||||
# Copied from transformers.tests.llama.test_modelling_llama.LlamaModelTest with Llama->Persimmon
|
||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTester with Llama->Persimmon
|
||||
class PersimmonModelTester:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -266,7 +266,6 @@ class PersimmonModelTester:
|
||||
return config, inputs_dict
|
||||
|
||||
|
||||
# Copied from transformers.tests.llama.test_modelling_llama.LlamaModelTest with Llama->Persimmon
|
||||
@require_torch
|
||||
class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (
|
||||
@@ -288,23 +287,28 @@ class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
test_headmasking = False
|
||||
test_pruning = False
|
||||
|
||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.setUp with Llama->Persimmon
|
||||
def setUp(self):
|
||||
self.model_tester = PersimmonModelTester(self)
|
||||
self.config_tester = ConfigTester(self, config_class=PersimmonConfig, hidden_size=37)
|
||||
|
||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_config
|
||||
def test_config(self):
|
||||
self.config_tester.run_common_tests()
|
||||
|
||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_various_embeddings
|
||||
def test_model_various_embeddings(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
for type in ["absolute", "relative_key", "relative_key_query"]:
|
||||
config_and_inputs[0].position_embedding_type = type
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model with Llama->Persimmon,llama->persimmon
|
||||
def test_persimmon_sequence_classification_model(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.num_labels = 3
|
||||
@@ -317,6 +321,7 @@ class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||
|
||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model_for_single_label with Llama->Persimmon,llama->persimmon
|
||||
def test_persimmon_sequence_classification_model_for_single_label(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.num_labels = 3
|
||||
@@ -330,6 +335,7 @@ class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
|
||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||
|
||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_llama_sequence_classification_model_for_multi_label with Llama->Persimmon,llama->persimmon
|
||||
def test_persimmon_sequence_classification_model_for_multi_label(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.num_labels = 3
|
||||
@@ -346,10 +352,12 @@ class PersimmonModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
|
||||
self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
|
||||
|
||||
@unittest.skip("Persimmon buffers include complex numbers, which breaks this test")
|
||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_save_load_fast_init_from_base
|
||||
def test_save_load_fast_init_from_base(self):
|
||||
pass
|
||||
|
||||
@parameterized.expand([("linear",), ("dynamic",)])
|
||||
# Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling with Llama->Persimmon
|
||||
def test_model_rope_scaling(self, scaling_type):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
short_input = ids_tensor([1, 10], config.vocab_size)
|
||||
|
||||
Reference in New Issue
Block a user