[Flax] Add FlaxBlenderbot (#13633)

* Init Flax implementation for Blenderbot

* Add a majority of stuff except for tests

* make style quality

* Add tests and fix some bugs

* Add tests

* Clean source code and fix some bugs

* Fix copies and docs

* Fix jax device condition for tests

* Fix layer norm in the encoder

* Fix a few typos in the test file

* make fix-copies

* make fix-copies

* fix layer norm

* Fix Flax params dtype (#13090)

* Fix PR reference (#13098)

* make fix-copies

* Update tests/test_modeling_flax_blenderbot.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
Daniel Stancl
2021-11-30 13:06:54 +01:00
committed by GitHub
parent 254fef67cf
commit faacd74729
13 changed files with 2026 additions and 30 deletions

View File

@@ -477,6 +477,13 @@ else:
if is_tf_available():
import tensorflow as tf
if is_flax_available():
import jax
jax_device = jax.default_backend()
else:
jax_device = None
def require_torch_gpu(test_case):
"""Decorator marking a test that requires CUDA and PyTorch."""