From 09127c5713e7c1665731836c70d791a628deb9a6 Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Wed, 22 Feb 2023 11:09:04 +0100 Subject: [PATCH] Fix `GPTSanJapaneseModel` (#21731) * fix * skip test_model_parallelism * skip test_model_parallelism --------- Co-authored-by: ydshieh --- .../gptsan_japanese/modeling_gptsan_japanese.py | 2 +- .../gptsan_japanese/test_modeling_gptsan_japanese.py | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py b/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py index 302719b15f..b29c8f566b 100644 --- a/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py +++ b/src/transformers/models/gptsan_japanese/modeling_gptsan_japanese.py @@ -924,7 +924,7 @@ class GPTSanJapaneseModel(GPTSanJapanesePreTrainedModel): `MoEModelOutputWithPastAndCrossAttentions` or `tuple` if `return_dict` returns MoEModelOutputWithPastAndCrossAttentions insted of tuple """ - return_dict = return_dict if return_dict is not None else self.config.return_dict + return_dict = return_dict if return_dict is not None else self.config.use_return_dict device = self.position_embeddings.weight.device if input_ids is None: input_ids = torch.zeros([1, 1]).int().to(device) # dummy for input_ids was None diff --git a/tests/models/gptsan_japanese/test_modeling_gptsan_japanese.py b/tests/models/gptsan_japanese/test_modeling_gptsan_japanese.py index 228b2715b3..d0c8a090ec 100644 --- a/tests/models/gptsan_japanese/test_modeling_gptsan_japanese.py +++ b/tests/models/gptsan_japanese/test_modeling_gptsan_japanese.py @@ -151,6 +151,12 @@ class GPTSanJapaneseTest(ModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) + @unittest.skip( + reason="skip for now as the computed `max_memory` by `model_split_percents` in the test method will be changed inside `from_pretrained`" + ) + def test_model_parallelism(self): + super().test_model_parallelism() + @require_torch class GPTSanJapaneseForConditionalGenerationTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): @@ -175,6 +181,12 @@ class GPTSanJapaneseForConditionalGenerationTest(ModelTesterMixin, GenerationTes config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) + @unittest.skip( + reason="skip for now as the computed `max_memory` by `model_split_percents` in the test method will be changed inside `from_pretrained`" + ) + def test_model_parallelism(self): + super().test_model_parallelism() + @slow def test_logits(self): model = GPTSanJapaneseForConditionalGeneration.from_pretrained("Tanrei/GPTSAN-japanese")