Flax Big Bird (#11967)
* add flax bert * bert -> bigbird * original_full ported * add debugger * init block sparse * fix copies ; gelu_fast -> gelu_new * block sparse port * fix block sparse * block sparse working * all ckpts working * fix-copies * make quality * init tests * temporary fix for FlaxBigBirdForMultipleChoice * skip test_attention_outputs * fix * gelu_fast -> gelu_new ; fix multiple choice model * remove nsp * fix sequence classifier * fix * make quality * make fix-copies * finish * Delete debugger.ipynb * Update src/transformers/models/big_bird/modeling_flax_big_bird.py * make style * finish * bye bye jit flax tests Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -23,7 +23,7 @@ import numpy as np
|
||||
import transformers
|
||||
from transformers import is_flax_available, is_torch_available
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import is_pt_flax_cross_test, require_flax
|
||||
from transformers.testing_utils import is_pt_flax_cross_test, require_flax, slow
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
@@ -273,6 +273,7 @@ class FlaxModelTesterMixin:
|
||||
for output_loaded, output in zip(outputs_loaded, outputs):
|
||||
self.assert_almost_equals(output_loaded, output, 1e-3)
|
||||
|
||||
@slow
|
||||
def test_jit_compilation(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user