Add TFData2VecVision for semantic segmentation (#17271)
* feat: initial implementation of data2vec segmentation model in TF. * chore: minor corrections to make the segmenter work. * chore: removed unncessary files. * chore: add tests and other modifications. * fix: loss computation for segmentation. * chore: remove unused variable. * chore: formatting. * added a dummy adaptive pooling layer. * removed unnecessary file. * potentially add identifiers to layer names. * fix: layer naming. * chore: removed unnecessary print. * Skipping unneeded test * chore: add logging to debug tolerance. * fix: segmentation tests for tfdata2vecvision * chore: make style. * fix: layer names, assertion to be resolved. * Bumping test tolerance a bit * chore: bump the tol in PT test. Co-authored-by: matt <rocketknight1@gmail.com>
This commit is contained in:
@@ -389,6 +389,10 @@ class Data2VecVisionModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
check_hidden_states_output(inputs_dict, config, model_class)
|
||||
|
||||
def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=2e-4, name="outputs", attributes=None):
|
||||
# We override with a slightly higher tol value, as semseg models tend to diverge a bit more
|
||||
super().check_pt_tf_outputs(tf_outputs, pt_outputs, model_class, tol, name, attributes)
|
||||
|
||||
def test_for_image_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)
|
||||
|
||||
Reference in New Issue
Block a user