Fix error in mixed precision training of TFCvtModel (#22267)
* Make sure CVT can be trained using mixed precision * Add test for keras-fit with mixed-precision * Update tests/models/cvt/test_modeling_tf_cvt.py Co-authored-by: Matt <Rocketknight1@users.noreply.github.com> --------- Co-authored-by: gcuder <Gerald.Cuder@iacapps.com> Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
This commit is contained in:
@@ -186,6 +186,12 @@ class TFCvtModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
|
||||
def test_keras_fit(self):
|
||||
super().test_keras_fit()
|
||||
|
||||
def test_keras_fit_mixed_precision(self):
|
||||
policy = tf.keras.mixed_precision.Policy("mixed_float16")
|
||||
tf.keras.mixed_precision.set_global_policy(policy)
|
||||
super().test_keras_fit()
|
||||
tf.keras.mixed_precision.set_global_policy("float32")
|
||||
|
||||
def test_forward_signature(self):
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user