Fix GPTSanJapaneseModel (#21731)

* fix

* skip test_model_parallelism

* skip test_model_parallelism

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2023-02-22 11:09:04 +01:00
committed by GitHub
parent aff87da15b
commit 09127c5713
2 changed files with 13 additions and 1 deletions

View File

@@ -151,6 +151,12 @@ class GPTSanJapaneseTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
@unittest.skip(
reason="skip for now as the computed `max_memory` by `model_split_percents` in the test method will be changed inside `from_pretrained`"
)
def test_model_parallelism(self):
super().test_model_parallelism()
@require_torch
class GPTSanJapaneseForConditionalGenerationTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
@@ -175,6 +181,12 @@ class GPTSanJapaneseForConditionalGenerationTest(ModelTesterMixin, GenerationTes
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
@unittest.skip(
reason="skip for now as the computed `max_memory` by `model_split_percents` in the test method will be changed inside `from_pretrained`"
)
def test_model_parallelism(self):
super().test_model_parallelism()
@slow
def test_logits(self):
model = GPTSanJapaneseForConditionalGeneration.from_pretrained("Tanrei/GPTSAN-japanese")