Add Gated-SiLU to T5 (#17420)

* Add gated-silu to t5 architecture to support UL2

* Fix error message

* formatting

* formatting again

* refactor

* fix classnames in _init_weights

* remove is_gated

* add test

* fix test

* Try without the test?

* Add back the test.

* Improve error message.

Co-authored-by: Daniel Hesslow <daniel@lighton.ai>
This commit is contained in:
DanielHesslow
2022-06-03 10:56:37 +02:00
committed by GitHub
parent 1c220ced8e
commit 607acd4fbd
5 changed files with 50 additions and 38 deletions

View File

@@ -539,6 +539,12 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
config.feed_forward_proj = "gated-gelu"
self.model_tester.create_and_check_model(config, *config_and_inputs[1:])
def test_config_and_model_silu_gated(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
config = config_and_inputs[0]
config.feed_forward_proj = "gated-silu"
self.model_tester.create_and_check_model(*config_and_inputs)
def test_with_lm_head(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_with_lm_head(*config_and_inputs)