[Flax] Add FlaxMBart (#12236)

* Copy BART to MBart and rename some stuff

* Add copy statements pointing to FlaxBart

* Update/add some common files

* Update shift_tokens_rigth + fix imports

* Fix shift_tokens_right method according to MBart implementation

* Update shift_tokens_right in tests accordingly

* Fix the import issue and update docs file
* make style quality

* Do some minor changes according to patil-suraj suggestions

* Change the order of normalization layer and attention

* Add some copu statementes

* Update generate method and add integration test for mBart

* Make a few updates after a review

Besides, add `lang_code_to_id` to MBartTokenizeFast

* fix-copies; make style quality

* Apply suggestions from code review

* Apply suggestions from code review

* Apply suggestions from code review

* fix output type, style

* add copied from

* resolve conflicts

Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
Daniel Stancl
2021-07-07 08:50:38 +02:00
committed by GitHub
parent 2d42915abe
commit 61400e1ec7
9 changed files with 2336 additions and 1 deletions

View File

@@ -421,7 +421,7 @@ Flax), PyTorch, and/or TensorFlow.
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| XLNet | ✅ | ✅ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| mBART | ✅ | ✅ | ✅ | ✅ | |
| mBART | ✅ | ✅ | ✅ | ✅ | |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
| mT5 | ✅ | ✅ | ✅ | ✅ | ❌ |
+-----------------------------+----------------+----------------+-----------------+--------------------+--------------+

View File

@@ -240,3 +240,31 @@ TFMBartForConditionalGeneration
.. autoclass:: transformers.TFMBartForConditionalGeneration
:members: call
FlaxMBartModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxMBartModel
:members: __call__, encode, decode
FlaxMBartForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxMBartForConditionalGeneration
:members: __call__, encode, decode
FlaxMBartForSequenceClassification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxMBartForSequenceClassification
:members: __call__, encode, decode
FlaxMBartForQuestionAnswering
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.FlaxMBartForQuestionAnswering
:members: __call__, encode, decode