Add TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING (#18469)
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -27,6 +27,7 @@ from ...test_modeling_tf_common import TFModelTesterMixin, floats_tensor, ids_te
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from transformers import TFSegformerForImageClassification, TFSegformerForSemanticSegmentation, TFSegformerModel
|
||||
@@ -336,6 +337,9 @@ class TFSegformerModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
def test_dataset_conversion(self):
|
||||
super().test_dataset_conversion()
|
||||
|
||||
def check_keras_fit_results(self, val_loss1, val_loss2, atol=2e-1, rtol=2e-1):
|
||||
self.assertTrue(np.allclose(val_loss1, val_loss2, atol=atol, rtol=rtol))
|
||||
|
||||
@unittest.skipIf(
|
||||
not is_tf_available() or len(tf.config.list_physical_devices("GPU")) == 0,
|
||||
reason="TF (<=2.8) does not support backprop for grouped convolutions on CPU.",
|
||||
|
||||
@@ -62,11 +62,13 @@ if is_tf_available():
|
||||
from transformers import (
|
||||
TF_MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
TF_MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING,
|
||||
TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING,
|
||||
TF_MODEL_FOR_MASKED_LM_MAPPING,
|
||||
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
||||
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
|
||||
TF_MODEL_FOR_PRETRAINING_MAPPING,
|
||||
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||
TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING,
|
||||
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
|
||||
@@ -170,6 +172,15 @@ class TFModelTesterMixin:
|
||||
inputs_dict["labels"] = tf.zeros(
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=tf.int32
|
||||
)
|
||||
elif model_class in get_values(TF_MODEL_FOR_MASKED_IMAGE_MODELING_MAPPING):
|
||||
num_patches = self.model_tester.image_size // self.model_tester.patch_size
|
||||
inputs_dict["bool_masked_pos"] = tf.zeros(
|
||||
(self.model_tester.batch_size, num_patches**2), dtype=tf.int32
|
||||
)
|
||||
elif model_class in get_values(TF_MODEL_FOR_SEMANTIC_SEGMENTATION_MAPPING):
|
||||
batch_size, num_channels, height, width = inputs_dict["pixel_values"].shape
|
||||
inputs_dict["labels"] = tf.zeros((self.model_tester.batch_size, height, width), dtype=tf.int32)
|
||||
|
||||
return inputs_dict
|
||||
|
||||
def test_initialization(self):
|
||||
@@ -1389,6 +1400,9 @@ class TFModelTesterMixin:
|
||||
|
||||
self.assertTrue(loss.shape.as_list() == expected_loss_size or loss.shape.as_list() == [1])
|
||||
|
||||
def check_keras_fit_results(self, val_loss1, val_loss2, atol=1e-2, rtol=1e-3):
|
||||
self.assertTrue(np.allclose(val_loss1, val_loss2, atol=atol, rtol=rtol))
|
||||
|
||||
def test_keras_fit(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
@@ -1468,7 +1482,7 @@ class TFModelTesterMixin:
|
||||
val_loss2 = history2.history["val_loss"][0]
|
||||
self.assertTrue(not isnan(val_loss2))
|
||||
accuracy2 = {key: val[0] for key, val in history2.history.items() if key.endswith("accuracy")}
|
||||
self.assertTrue(np.allclose(val_loss1, val_loss2, atol=1e-2, rtol=1e-3))
|
||||
self.check_keras_fit_results(val_loss1, val_loss2)
|
||||
self.assertEqual(history1.history.keys(), history2.history.keys())
|
||||
for key in history1.history.keys():
|
||||
if not key.startswith("val_"):
|
||||
@@ -1494,7 +1508,7 @@ class TFModelTesterMixin:
|
||||
val_loss3 = history3.history["val_loss"][0]
|
||||
self.assertTrue(not isnan(val_loss3))
|
||||
accuracy3 = {key: val[0] for key, val in history3.history.items() if key.endswith("accuracy")}
|
||||
self.assertTrue(np.allclose(val_loss1, val_loss3, atol=1e-2, rtol=1e-3))
|
||||
self.check_keras_fit_results(val_loss1, val_loss3)
|
||||
self.assertEqual(history1.history.keys(), history3.history.keys())
|
||||
if metrics:
|
||||
self.assertTrue(len(accuracy1) == len(accuracy3) > 0, "Missing metrics!")
|
||||
|
||||
Reference in New Issue
Block a user