[Flax] Add other BERT classes (#10977)
* add first code structures * add all bert models * add to init and docs * correct docs * make style
This commit is contained in:
committed by
GitHub
parent
e031162a6b
commit
e87505f3a1
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import random
|
||||
import tempfile
|
||||
|
||||
@@ -65,6 +66,18 @@ class FlaxModelTesterMixin:
|
||||
model_tester = None
|
||||
all_model_classes = ()
|
||||
|
||||
def _prepare_for_class(self, inputs_dict, model_class):
|
||||
inputs_dict = copy.deepcopy(inputs_dict)
|
||||
|
||||
# hack for now until we have AutoModel classes
|
||||
if "ForMultipleChoice" in model_class.__name__:
|
||||
inputs_dict = {
|
||||
k: jnp.broadcast_to(v[:, None], (v.shape[0], self.model_tester.num_choices, v.shape[-1]))
|
||||
for k, v in inputs_dict.items()
|
||||
}
|
||||
|
||||
return inputs_dict
|
||||
|
||||
def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
|
||||
diff = np.abs((a - b)).max()
|
||||
self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
|
||||
@@ -75,6 +88,7 @@ class FlaxModelTesterMixin:
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
with self.subTest(model_class.__name__):
|
||||
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
|
||||
pt_model_class = getattr(transformers, pt_model_class_name)
|
||||
pt_model = pt_model_class(config).eval()
|
||||
@@ -83,12 +97,12 @@ class FlaxModelTesterMixin:
|
||||
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
||||
fx_model.params = fx_state
|
||||
|
||||
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in inputs_dict.items()}
|
||||
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||
|
||||
fx_outputs = fx_model(**inputs_dict)
|
||||
fx_outputs = fx_model(**prepared_inputs_dict)
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 2e-3)
|
||||
@@ -97,7 +111,7 @@ class FlaxModelTesterMixin:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
fx_outputs_loaded = fx_model_loaded(**inputs_dict)
|
||||
fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict)
|
||||
self.assertEqual(
|
||||
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
|
||||
)
|
||||
@@ -111,13 +125,14 @@ class FlaxModelTesterMixin:
|
||||
with self.subTest(model_class.__name__):
|
||||
model = model_class(config)
|
||||
|
||||
outputs = model(**inputs_dict)
|
||||
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
outputs = model(**prepared_inputs_dict)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model_loaded = model_class.from_pretrained(tmpdirname)
|
||||
|
||||
outputs_loaded = model_loaded(**inputs_dict)
|
||||
outputs_loaded = model_loaded(**prepared_inputs_dict)
|
||||
for output_loaded, output in zip(outputs_loaded, outputs):
|
||||
self.assert_almost_equals(output_loaded, output, 5e-3)
|
||||
|
||||
@@ -126,6 +141,7 @@ class FlaxModelTesterMixin:
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
with self.subTest(model_class.__name__):
|
||||
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
model = model_class(config)
|
||||
|
||||
@jax.jit
|
||||
@@ -134,10 +150,10 @@ class FlaxModelTesterMixin:
|
||||
|
||||
with self.subTest("JIT Disabled"):
|
||||
with jax.disable_jit():
|
||||
outputs = model_jitted(**inputs_dict)
|
||||
outputs = model_jitted(**prepared_inputs_dict)
|
||||
|
||||
with self.subTest("JIT Enabled"):
|
||||
jitted_outputs = model_jitted(**inputs_dict)
|
||||
jitted_outputs = model_jitted(**prepared_inputs_dict)
|
||||
|
||||
self.assertEqual(len(outputs), len(jitted_outputs))
|
||||
for jitted_output, output in zip(jitted_outputs, outputs):
|
||||
|
||||
Reference in New Issue
Block a user