Fix GPTSanJapaneseModel (#21731)
* fix
* skip test_model_parallelism
* skip test_model_parallelism

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -924,7 +924,7 @@ class GPTSanJapaneseModel(GPTSanJapanesePreTrainedModel):
|
|||||||
`MoEModelOutputWithPastAndCrossAttentions` or `tuple` if `return_dict` returns
|
`MoEModelOutputWithPastAndCrossAttentions` or `tuple` if `return_dict` returns
|
||||||
MoEModelOutputWithPastAndCrossAttentions instead of tuple
|
MoEModelOutputWithPastAndCrossAttentions instead of tuple
|
||||||
"""
|
"""
|
||||||
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
device = self.position_embeddings.weight.device
|
device = self.position_embeddings.weight.device
|
||||||
if input_ids is None:
|
if input_ids is None:
|
||||||
input_ids = torch.zeros([1, 1]).int().to(device) # dummy for input_ids was None
|
input_ids = torch.zeros([1, 1]).int().to(device) # dummy for input_ids was None
|
||||||
|
|||||||
@@ -151,6 +151,12 @@ class GPTSanJapaneseTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||||
|
|
||||||
|
@unittest.skip(
|
||||||
|
reason="skip for now as the computed `max_memory` by `model_split_percents` in the test method will be changed inside `from_pretrained`"
|
||||||
|
)
|
||||||
|
def test_model_parallelism(self):
|
||||||
|
super().test_model_parallelism()
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class GPTSanJapaneseForConditionalGenerationTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
class GPTSanJapaneseForConditionalGenerationTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
|
||||||
@@ -175,6 +181,12 @@ class GPTSanJapaneseForConditionalGenerationTest(ModelTesterMixin, GenerationTes
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||||
|
|
||||||
|
@unittest.skip(
|
||||||
|
reason="skip for now as the computed `max_memory` by `model_split_percents` in the test method will be changed inside `from_pretrained`"
|
||||||
|
)
|
||||||
|
def test_model_parallelism(self):
|
||||||
|
super().test_model_parallelism()
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_logits(self):
|
def test_logits(self):
|
||||||
model = GPTSanJapaneseForConditionalGeneration.from_pretrained("Tanrei/GPTSAN-japanese")
|
model = GPTSanJapaneseForConditionalGeneration.from_pretrained("Tanrei/GPTSAN-japanese")
|
||||||
|
|||||||
Reference in New Issue
Block a user