Fix inverted conditional in TF common test! (#22540)
* Fix inverted conditional in TF common test! * Make the same change in the PT tests file * Make sure hidden states for GPT2 have the same output shape in PT/TF * Minor fix to PT implementation of token classification loss * Skip loss equivalence test for TFHubert because it keeps overflowing to inf * Compute LM loss for TF the (weird) way it's computed in PT * Skip loss equivalence test for Wav2Vec2 for the same reason as Hubert * Fix - don't try to access the hidden states property when output is a tuple
This commit is contained in:
@@ -17,13 +17,15 @@
|
||||
import copy
|
||||
import inspect
|
||||
import math
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from transformers import is_tf_available
|
||||
from transformers.testing_utils import require_soundfile, require_tf, slow
|
||||
from transformers.testing_utils import is_pt_tf_cross_test, require_soundfile, require_tf, slow
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_tf_common import TFModelTesterMixin, ids_tensor
|
||||
@@ -333,6 +335,62 @@ class TFHubertModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
||||
# TODO: (Amy) - check whether skipping CTC model resolves this issue and possible resolutions for CTC
|
||||
pass
|
||||
|
||||
@is_pt_tf_cross_test
|
||||
def test_pt_tf_model_equivalence(self, allow_missing_keys=False):
|
||||
# We override the base test here to skip loss calculation for Hubert models because the loss is massive with
|
||||
# the default labels and frequently overflows to inf or exceeds numerical tolerances between TF/PT
|
||||
import torch
|
||||
|
||||
import transformers
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
# Output all for aggressive testing
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = self.has_attentions
|
||||
|
||||
# Make sure no sequence has all zeros as attention mask, otherwise some tests fail due to the inconsistency
|
||||
# of the usage `1e-4`, `1e-9`, `1e-30`, `-inf`.
|
||||
# TODO: Use a uniform value for all models, make sure all tests pass without this processing, and remove it.
|
||||
self._make_attention_mask_non_null(inputs_dict)
|
||||
|
||||
pt_model_class_name = model_class.__name__[2:] # Skip the "TF" at the beginning
|
||||
pt_model_class = getattr(transformers, pt_model_class_name)
|
||||
|
||||
tf_model = model_class(config)
|
||||
pt_model = pt_model_class(config)
|
||||
|
||||
tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
# Check we can load pt model in tf and vice-versa with model => model functions
|
||||
tf_model = transformers.load_pytorch_model_in_tf2_model(
|
||||
tf_model, pt_model, tf_inputs=tf_inputs_dict, allow_missing_keys=allow_missing_keys
|
||||
)
|
||||
pt_model = transformers.load_tf2_model_in_pytorch_model(
|
||||
pt_model, tf_model, allow_missing_keys=allow_missing_keys
|
||||
)
|
||||
|
||||
# Original test: check without `labels`
|
||||
self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
|
||||
|
||||
# Check we can load pt model in tf and vice-versa with checkpoint => model functions
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
|
||||
torch.save(pt_model.state_dict(), pt_checkpoint_path)
|
||||
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(
|
||||
tf_model, pt_checkpoint_path, allow_missing_keys=allow_missing_keys
|
||||
)
|
||||
|
||||
tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
|
||||
tf_model.save_weights(tf_checkpoint_path)
|
||||
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(
|
||||
pt_model, tf_checkpoint_path, allow_missing_keys=allow_missing_keys
|
||||
)
|
||||
|
||||
# Original test: check without `labels`
|
||||
self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFHubertRobustModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
@@ -458,6 +516,62 @@ class TFHubertRobustModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
# TODO: (Amy) - check whether skipping CTC model resolves this issue and possible resolutions for CTC
|
||||
pass
|
||||
|
||||
@is_pt_tf_cross_test
|
||||
def test_pt_tf_model_equivalence(self, allow_missing_keys=False):
|
||||
# We override the base test here to skip loss calculation for Hubert models because the loss is massive with
|
||||
# the default labels and frequently overflows to inf or exceeds numerical tolerances between TF/PT
|
||||
import torch
|
||||
|
||||
import transformers
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
# Output all for aggressive testing
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = self.has_attentions
|
||||
|
||||
# Make sure no sequence has all zeros as attention mask, otherwise some tests fail due to the inconsistency
|
||||
# of the usage `1e-4`, `1e-9`, `1e-30`, `-inf`.
|
||||
# TODO: Use a uniform value for all models, make sure all tests pass without this processing, and remove it.
|
||||
self._make_attention_mask_non_null(inputs_dict)
|
||||
|
||||
pt_model_class_name = model_class.__name__[2:] # Skip the "TF" at the beginning
|
||||
pt_model_class = getattr(transformers, pt_model_class_name)
|
||||
|
||||
tf_model = model_class(config)
|
||||
pt_model = pt_model_class(config)
|
||||
|
||||
tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
# Check we can load pt model in tf and vice-versa with model => model functions
|
||||
tf_model = transformers.load_pytorch_model_in_tf2_model(
|
||||
tf_model, pt_model, tf_inputs=tf_inputs_dict, allow_missing_keys=allow_missing_keys
|
||||
)
|
||||
pt_model = transformers.load_tf2_model_in_pytorch_model(
|
||||
pt_model, tf_model, allow_missing_keys=allow_missing_keys
|
||||
)
|
||||
|
||||
# Original test: check without `labels`
|
||||
self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
|
||||
|
||||
# Check we can load pt model in tf and vice-versa with checkpoint => model functions
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
|
||||
torch.save(pt_model.state_dict(), pt_checkpoint_path)
|
||||
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(
|
||||
tf_model, pt_checkpoint_path, allow_missing_keys=allow_missing_keys
|
||||
)
|
||||
|
||||
tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
|
||||
tf_model.save_weights(tf_checkpoint_path)
|
||||
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(
|
||||
pt_model, tf_checkpoint_path, allow_missing_keys=allow_missing_keys
|
||||
)
|
||||
|
||||
# Original test: check without `labels`
|
||||
self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFHubertUtilsTest(unittest.TestCase):
|
||||
|
||||
@@ -19,6 +19,8 @@ import glob
|
||||
import inspect
|
||||
import math
|
||||
import multiprocessing
|
||||
import os
|
||||
import tempfile
|
||||
import traceback
|
||||
import unittest
|
||||
|
||||
@@ -31,6 +33,7 @@ from transformers import Wav2Vec2Config, is_tf_available
|
||||
from transformers.testing_utils import (
|
||||
CaptureLogger,
|
||||
is_flaky,
|
||||
is_pt_tf_cross_test,
|
||||
require_librosa,
|
||||
require_pyctcdecode,
|
||||
require_tf,
|
||||
@@ -397,6 +400,62 @@ class TFWav2Vec2ModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.Test
|
||||
# TODO: (Amy) - check whether skipping CTC model resolves this issue and possible resolutions for CTC
|
||||
pass
|
||||
|
||||
@is_pt_tf_cross_test
|
||||
def test_pt_tf_model_equivalence(self, allow_missing_keys=False):
|
||||
# We override the base test here to skip loss calculation for Wav2Vec2 models because the loss is massive with
|
||||
# the default labels and frequently overflows to inf or exceeds numerical tolerances between TF/PT
|
||||
import torch
|
||||
|
||||
import transformers
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
# Output all for aggressive testing
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = self.has_attentions
|
||||
|
||||
# Make sure no sequence has all zeros as attention mask, otherwise some tests fail due to the inconsistency
|
||||
# of the usage `1e-4`, `1e-9`, `1e-30`, `-inf`.
|
||||
# TODO: Use a uniform value for all models, make sure all tests pass without this processing, and remove it.
|
||||
self._make_attention_mask_non_null(inputs_dict)
|
||||
|
||||
pt_model_class_name = model_class.__name__[2:] # Skip the "TF" at the beginning
|
||||
pt_model_class = getattr(transformers, pt_model_class_name)
|
||||
|
||||
tf_model = model_class(config)
|
||||
pt_model = pt_model_class(config)
|
||||
|
||||
tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
# Check we can load pt model in tf and vice-versa with model => model functions
|
||||
tf_model = transformers.load_pytorch_model_in_tf2_model(
|
||||
tf_model, pt_model, tf_inputs=tf_inputs_dict, allow_missing_keys=allow_missing_keys
|
||||
)
|
||||
pt_model = transformers.load_tf2_model_in_pytorch_model(
|
||||
pt_model, tf_model, allow_missing_keys=allow_missing_keys
|
||||
)
|
||||
|
||||
# Original test: check without `labels`
|
||||
self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
|
||||
|
||||
# Check we can load pt model in tf and vice-versa with checkpoint => model functions
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
|
||||
torch.save(pt_model.state_dict(), pt_checkpoint_path)
|
||||
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(
|
||||
tf_model, pt_checkpoint_path, allow_missing_keys=allow_missing_keys
|
||||
)
|
||||
|
||||
tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
|
||||
tf_model.save_weights(tf_checkpoint_path)
|
||||
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(
|
||||
pt_model, tf_checkpoint_path, allow_missing_keys=allow_missing_keys
|
||||
)
|
||||
|
||||
# Original test: check without `labels`
|
||||
self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFWav2Vec2RobustModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
@@ -524,6 +583,62 @@ class TFWav2Vec2RobustModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
# TODO: (Amy) - check whether skipping CTC model resolves this issue and possible resolutions for CTC
|
||||
pass
|
||||
|
||||
@is_pt_tf_cross_test
|
||||
def test_pt_tf_model_equivalence(self, allow_missing_keys=False):
|
||||
# We override the base test here to skip loss calculation for Wav2Vec2 models because the loss is massive with
|
||||
# the default labels and frequently overflows to inf or exceeds numerical tolerances between TF/PT
|
||||
import torch
|
||||
|
||||
import transformers
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
# Output all for aggressive testing
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = self.has_attentions
|
||||
|
||||
# Make sure no sequence has all zeros as attention mask, otherwise some tests fail due to the inconsistency
|
||||
# of the usage `1e-4`, `1e-9`, `1e-30`, `-inf`.
|
||||
# TODO: Use a uniform value for all models, make sure all tests pass without this processing, and remove it.
|
||||
self._make_attention_mask_non_null(inputs_dict)
|
||||
|
||||
pt_model_class_name = model_class.__name__[2:] # Skip the "TF" at the beginning
|
||||
pt_model_class = getattr(transformers, pt_model_class_name)
|
||||
|
||||
tf_model = model_class(config)
|
||||
pt_model = pt_model_class(config)
|
||||
|
||||
tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
|
||||
# Check we can load pt model in tf and vice-versa with model => model functions
|
||||
tf_model = transformers.load_pytorch_model_in_tf2_model(
|
||||
tf_model, pt_model, tf_inputs=tf_inputs_dict, allow_missing_keys=allow_missing_keys
|
||||
)
|
||||
pt_model = transformers.load_tf2_model_in_pytorch_model(
|
||||
pt_model, tf_model, allow_missing_keys=allow_missing_keys
|
||||
)
|
||||
|
||||
# Original test: check without `labels`
|
||||
self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
|
||||
|
||||
# Check we can load pt model in tf and vice-versa with checkpoint => model functions
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
|
||||
torch.save(pt_model.state_dict(), pt_checkpoint_path)
|
||||
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(
|
||||
tf_model, pt_checkpoint_path, allow_missing_keys=allow_missing_keys
|
||||
)
|
||||
|
||||
tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
|
||||
tf_model.save_weights(tf_checkpoint_path)
|
||||
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(
|
||||
pt_model, tf_checkpoint_path, allow_missing_keys=allow_missing_keys
|
||||
)
|
||||
|
||||
# Original test: check without `labels`
|
||||
self.check_pt_tf_models(tf_model, pt_model, tf_inputs_dict)
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFWav2Vec2UtilsTest(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user