Flax Big Bird (#11967)

* add flax bert

* bert -> bigbird

* original_full ported

* add debugger

* init block sparse

* fix copies ; gelu_fast -> gelu_new

* block sparse port

* fix block sparse

* block sparse working

* all ckpts working

* fix-copies

* make quality

* init tests

* temporary fix for FlaxBigBirdForMultipleChoice

* skip test_attention_outputs

* fix

* gelu_fast -> gelu_new ; fix multiple choice model

* remove nsp

* fix sequence classifier

* fix

* make quality

* make fix-copies

* finish

* Delete debugger.ipynb

* Update src/transformers/models/big_bird/modeling_flax_big_bird.py

* make style

* finish

* bye bye jit flax tests

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Vasudev Gupta
2021-06-15 00:31:03 +05:30
committed by GitHub
parent a156da9a23
commit d9c0d08f9a
17 changed files with 2434 additions and 17 deletions

View File

@@ -23,7 +23,7 @@ import numpy as np
import transformers
from transformers import is_flax_available, is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import is_pt_flax_cross_test, require_flax
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow
if is_flax_available():
@@ -273,6 +273,7 @@ class FlaxModelTesterMixin:
for output_loaded, output in zip(outputs_loaded, outputs):
self.assert_almost_equals(output_loaded, output, 1e-3)
@slow
def test_jit_compilation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()