added GPTNeoForTokenClassification (#22908)

* added GPTNeoForTokenClassification

* add to top-level init

* fixup

* test

* more fixup

* add to gpt_neo.mdx

* repo consistency

* dummy copy

* fix copies

* optax >= 0.1.5 assumes jax.Array exists - which it doesn't for jax <= 0.3.6

* merge with main made this superfluous

* added classifier_dropout

* remove legacy code

* removed fmt:on/off
removed expected_outputs

* doc style fix

* classifier_dropout is always in config

---------

Co-authored-by: Prof. Peter Schneider-Kamp <jps@ordbogen.com>
This commit is contained in:
peter-sk
2023-04-27 18:10:03 +02:00
committed by GitHub
parent 614e191c4d
commit d65b14ed67
9 changed files with 129 additions and 2 deletions

View File

@@ -74,6 +74,11 @@ The `generate()` method can be used to generate text using GPT Neo model.
[[autodoc]] GPTNeoForSequenceClassification
- forward
## GPTNeoForTokenClassification
[[autodoc]] GPTNeoForTokenClassification
- forward
## FlaxGPTNeoModel
[[autodoc]] FlaxGPTNeoModel