[Flax] Add other BERT classes (#10977)

* add first code structures

* add all bert models

* add to init and docs

* correct docs

* make style
This commit is contained in:
Patrick von Platen
2021-03-31 09:45:58 +03:00
committed by GitHub
parent e031162a6b
commit e87505f3a1
7 changed files with 627 additions and 24 deletions

View File

@@ -23,7 +23,15 @@ from .test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor, random_
if is_flax_available():
from transformers.models.bert.modeling_flax_bert import FlaxBertForMaskedLM, FlaxBertModel
from transformers.models.bert.modeling_flax_bert import (
FlaxBertForMaskedLM,
FlaxBertForMultipleChoice,
FlaxBertForNextSentencePrediction,
FlaxBertForPreTraining,
FlaxBertForQuestionAnswering,
FlaxBertForTokenClassification,
FlaxBertModel,
)
class FlaxBertModelTester(unittest.TestCase):
@@ -48,6 +56,7 @@ class FlaxBertModelTester(unittest.TestCase):
type_vocab_size=16,
type_sequence_label_size=2,
initializer_range=0.02,
num_choices=4,
):
self.parent = parent
self.batch_size = batch_size
@@ -68,6 +77,7 @@ class FlaxBertModelTester(unittest.TestCase):
self.type_vocab_size = type_vocab_size
self.type_sequence_label_size = type_sequence_label_size
self.initializer_range = initializer_range
self.num_choices = num_choices
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -107,7 +117,20 @@ class FlaxBertModelTester(unittest.TestCase):
@require_flax
class FlaxBertModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (FlaxBertModel, FlaxBertForMaskedLM) if is_flax_available() else ()
all_model_classes = (
(
FlaxBertModel,
FlaxBertForPreTraining,
FlaxBertForMaskedLM,
FlaxBertForMultipleChoice,
FlaxBertForQuestionAnswering,
FlaxBertForNextSentencePrediction,
FlaxBertForTokenClassification,
FlaxBertForQuestionAnswering,
)
if is_flax_available()
else ()
)
def setUp(self):
self.model_tester = FlaxBertModelTester(self)

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import random
import tempfile
@@ -65,6 +66,18 @@ class FlaxModelTesterMixin:
model_tester = None
all_model_classes = ()
def _prepare_for_class(self, inputs_dict, model_class):
inputs_dict = copy.deepcopy(inputs_dict)
# hack for now until we have AutoModel classes
if "ForMultipleChoice" in model_class.__name__:
inputs_dict = {
k: jnp.broadcast_to(v[:, None], (v.shape[0], self.model_tester.num_choices, v.shape[-1]))
for k, v in inputs_dict.items()
}
return inputs_dict
def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
diff = np.abs((a - b)).max()
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
@@ -75,6 +88,7 @@ class FlaxModelTesterMixin:
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
pt_model_class = getattr(transformers, pt_model_class_name)
pt_model = pt_model_class(config).eval()
@@ -83,12 +97,12 @@ class FlaxModelTesterMixin:
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in inputs_dict.items()}
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
fx_outputs = fx_model(**inputs_dict)
fx_outputs = fx_model(**prepared_inputs_dict)
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(), 2e-3)
@@ -97,7 +111,7 @@ class FlaxModelTesterMixin:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)
fx_outputs_loaded = fx_model_loaded(**inputs_dict)
fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict)
self.assertEqual(
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
)
@@ -111,13 +125,14 @@ class FlaxModelTesterMixin:
with self.subTest(model_class.__name__):
model = model_class(config)
outputs = model(**inputs_dict)
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
outputs = model(**prepared_inputs_dict)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_loaded = model_class.from_pretrained(tmpdirname)
outputs_loaded = model_loaded(**inputs_dict)
outputs_loaded = model_loaded(**prepared_inputs_dict)
for output_loaded, output in zip(outputs_loaded, outputs):
self.assert_almost_equals(output_loaded, output, 5e-3)
@@ -126,6 +141,7 @@ class FlaxModelTesterMixin:
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
model = model_class(config)
@jax.jit
@@ -134,10 +150,10 @@ class FlaxModelTesterMixin:
with self.subTest("JIT Disabled"):
with jax.disable_jit():
outputs = model_jitted(**inputs_dict)
outputs = model_jitted(**prepared_inputs_dict)
with self.subTest("JIT Enabled"):
jitted_outputs = model_jitted(**inputs_dict)
jitted_outputs = model_jitted(**prepared_inputs_dict)
self.assertEqual(len(outputs), len(jitted_outputs))
for jitted_output, output in zip(jitted_outputs, outputs):