Fix GPTSanJapaneseModel (#21731)

* fix

* skip test_model_parallelism

* skip test_model_parallelism

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2023-02-22 11:09:04 +01:00
committed by GitHub
parent aff87da15b
commit 09127c5713
2 changed files with 13 additions and 1 deletions

View File

@@ -924,7 +924,7 @@ class GPTSanJapaneseModel(GPTSanJapanesePreTrainedModel):
`MoEModelOutputWithPastAndCrossAttentions` or `tuple` if `return_dict` returns `MoEModelOutputWithPastAndCrossAttentions` or `tuple` if `return_dict` returns
MoEModelOutputWithPastAndCrossAttentions insted of tuple MoEModelOutputWithPastAndCrossAttentions insted of tuple
""" """
return_dict = return_dict if return_dict is not None else self.config.return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
device = self.position_embeddings.weight.device device = self.position_embeddings.weight.device
if input_ids is None: if input_ids is None:
input_ids = torch.zeros([1, 1]).int().to(device) # dummy for input_ids was None input_ids = torch.zeros([1, 1]).int().to(device) # dummy for input_ids was None

View File

@@ -151,6 +151,12 @@ class GPTSanJapaneseTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs) self.model_tester.create_and_check_model(*config_and_inputs)
@unittest.skip(
reason="skip for now as the computed `max_memory` by `model_split_percents` in the test method will be changed inside `from_pretrained`"
)
def test_model_parallelism(self):
super().test_model_parallelism()
@require_torch @require_torch
class GPTSanJapaneseForConditionalGenerationTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): class GPTSanJapaneseForConditionalGenerationTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
@@ -175,6 +181,12 @@ class GPTSanJapaneseForConditionalGenerationTest(ModelTesterMixin, GenerationTes
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs) self.model_tester.create_and_check_model(*config_and_inputs)
@unittest.skip(
reason="skip for now as the computed `max_memory` by `model_split_percents` in the test method will be changed inside `from_pretrained`"
)
def test_model_parallelism(self):
super().test_model_parallelism()
@slow @slow
def test_logits(self): def test_logits(self):
model = GPTSanJapaneseForConditionalGeneration.from_pretrained("Tanrei/GPTSAN-japanese") model = GPTSanJapaneseForConditionalGeneration.from_pretrained("Tanrei/GPTSAN-japanese")