Better TF docstring types (#23477)

* Rework TF type hints to use | None instead of Optional[] for tf.Tensor

* Rework TF type hints to use | None instead of Optional[] for tf.Tensor

* Don't forget the imports

* Add the imports to tests too

* make fixup

* Refactor tests that depended on get_type_hints

* Better test refactor

* Fix an old hidden bug in the test_keras_fit input creation code

* Fix for the Deit tests
This commit is contained in:
Matt
2023-05-24 13:52:52 +01:00
committed by GitHub
parent 767e6b5314
commit f8b2574416
139 changed files with 2907 additions and 2621 deletions

View File

@@ -14,6 +14,8 @@
# limitations under the License.
from __future__ import annotations
import unittest
import numpy as np

View File

@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import os
import tempfile
import unittest

View File

@@ -14,6 +14,8 @@
# limitations under the License.
from __future__ import annotations
import unittest
from transformers import AlbertConfig, is_tf_available

View File

@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import copy
import tempfile
import unittest

View File

@@ -14,6 +14,8 @@
# limitations under the License.
from __future__ import annotations
import unittest
from transformers import is_tf_available, is_torch_available

View File

@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import copy
import tempfile
import unittest

View File

@@ -14,6 +14,8 @@
# limitations under the License.
from __future__ import annotations
import unittest
from transformers import BertConfig, is_tf_available

View File

@@ -14,6 +14,8 @@
# limitations under the License.
from __future__ import annotations
import unittest
from transformers import BlenderbotConfig, BlenderbotTokenizer, is_tf_available

View File

@@ -14,6 +14,8 @@
# limitations under the License.
from __future__ import annotations
import unittest
from transformers import BlenderbotSmallConfig, BlenderbotSmallTokenizer, is_tf_available

View File

@@ -15,6 +15,8 @@
""" Testing suite for the TensorFlow Blip model. """
from __future__ import annotations
import inspect
import tempfile
import unittest

View File

@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the TensorFlow Blip model. """
from __future__ import annotations
import unittest
import numpy as np

View File

@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import unittest
from transformers import is_tf_available

View File

@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import unittest
from transformers import is_tf_available

View File

@@ -15,6 +15,8 @@
""" Testing suite for the TensorFlow CLIP model. """
from __future__ import annotations
import inspect
import os
import tempfile

View File

@@ -12,6 +12,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import os
import tempfile
import unittest

View File

@@ -14,6 +14,8 @@
# limitations under the License.
""" Testing suite for the TensorFlow ConvNext model. """
from __future__ import annotations
import inspect
import unittest
from typing import List, Tuple

View File

@@ -14,6 +14,8 @@
# limitations under the License.
from __future__ import annotations
import unittest
from transformers import CTRLConfig, is_tf_available

View File

@@ -1,6 +1,8 @@
""" Testing suite for the Tensorflow CvT model. """
from __future__ import annotations
import inspect
import unittest
from math import floor

View File

@@ -14,6 +14,8 @@
# limitations under the License.
""" Testing suite for the TensorFlow Data2VecVision model. """
from __future__ import annotations
import collections.abc
import inspect
import unittest

View File

@@ -14,6 +14,8 @@
# limitations under the License.
from __future__ import annotations
import unittest
from transformers import DebertaConfig, is_tf_available

View File

@@ -14,6 +14,8 @@
# limitations under the License.
from __future__ import annotations
import unittest
from transformers import DebertaV2Config, is_tf_available

View File

@@ -15,6 +15,8 @@
""" Testing suite for the TensorFlow DeiT model. """
from __future__ import annotations
import inspect
import unittest
@@ -242,7 +244,7 @@ class TFDeiTModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase
inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
if return_labels:
if model_class.__name__ == "DeiTForImageClassificationWithTeacher":
if "labels" in inputs_dict and "labels" not in inspect.signature(model_class.call).parameters:
del inputs_dict["labels"]
return inputs_dict

View File

@@ -14,6 +14,8 @@
# limitations under the License.
from __future__ import annotations
import unittest
from transformers import DistilBertConfig, is_tf_available

View File

@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import unittest
from transformers import is_tf_available

View File

@@ -14,6 +14,8 @@
# limitations under the License.
from __future__ import annotations
import unittest
from transformers import ElectraConfig, is_tf_available

View File

@@ -14,6 +14,8 @@
# limitations under the License.
from __future__ import annotations
import copy
import os
import tempfile

View File

@@ -14,6 +14,8 @@
# limitations under the License.
from __future__ import annotations
import unittest
from transformers import EsmConfig, is_tf_available

View File

@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import unittest
from transformers import is_tf_available

View File

@@ -14,6 +14,8 @@
# limitations under the License.
from __future__ import annotations
import unittest
from transformers import FunnelConfig, is_tf_available

View File

@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import unittest
from transformers import GPT2Config, is_tf_available

View File

@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import unittest
from transformers import AutoTokenizer, GPTJConfig, is_tf_available

View File

@@ -15,6 +15,8 @@
""" Testing suite for the TensorFlow GroupViT model. """
from __future__ import annotations
import inspect
import os
import random

View File

@@ -14,6 +14,8 @@
# limitations under the License.
from __future__ import annotations
import copy
import inspect
import math

View File

@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import unittest
import numpy as np

View File

@@ -14,6 +14,8 @@
# limitations under the License.
""" Testing suite for the TensorFlow LayoutLMv3 model. """
from __future__ import annotations
import copy
import inspect
import unittest

View File

@@ -14,6 +14,8 @@
# limitations under the License.
from __future__ import annotations
import unittest
from transformers import LEDConfig, is_tf_available

View File

@@ -14,6 +14,8 @@
# limitations under the License.
from __future__ import annotations
import unittest
from transformers import is_tf_available

View File

@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import os
import tempfile
import unittest

View File

@@ -14,6 +14,8 @@
# limitations under the License.
from __future__ import annotations
import tempfile
import unittest
import warnings

View File

@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import tempfile
import unittest

View File

@@ -14,6 +14,8 @@
# limitations under the License.
from __future__ import annotations
import unittest
from transformers import MobileBertConfig, is_tf_available

View File

@@ -15,6 +15,8 @@
""" Testing suite for the TensorFlow MobileViT model. """
from __future__ import annotations
import inspect
import unittest

View File

@@ -14,6 +14,8 @@
# limitations under the License.
from __future__ import annotations
import unittest
from transformers import MPNetConfig, is_tf_available

View File

@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import unittest
from transformers import is_tf_available

View File

@@ -14,6 +14,8 @@
# limitations under the License.
from __future__ import annotations
import unittest
from transformers import OpenAIGPTConfig, is_tf_available

View File

@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import unittest
import numpy as np

View File

@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import tempfile
import unittest

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import json
import os
import shutil

View File

@@ -14,6 +14,8 @@
# limitations under the License.
""" Testing suite for the TensorFlow RegNet model. """
from __future__ import annotations
import inspect
import unittest
from typing import List, Tuple

View File

@@ -14,6 +14,8 @@
# limitations under the License.
from __future__ import annotations
import unittest
from transformers import RemBertConfig, is_tf_available

View File

@@ -15,6 +15,8 @@
""" Testing suite for the Tensorflow ResNet model. """
from __future__ import annotations
import inspect
import unittest

View File

@@ -14,6 +14,8 @@
# limitations under the License.
from __future__ import annotations
import unittest
from transformers import RobertaConfig, is_tf_available

View File

@@ -14,6 +14,8 @@
# limitations under the License.
from __future__ import annotations
import unittest
from transformers import RobertaPreLayerNormConfig, is_tf_available

View File

@@ -14,6 +14,8 @@
# limitations under the License.
from __future__ import annotations
import unittest
from transformers import RoFormerConfig, is_tf_available

View File

@@ -15,6 +15,8 @@
""" Testing suite for the TensorFlow SAM model. """
from __future__ import annotations
import inspect
import unittest

View File

@@ -14,6 +14,8 @@
# limitations under the License.
""" Testing suite for the TensorFlow SegFormer model. """
from __future__ import annotations
import inspect
import unittest
from typing import List, Tuple

View File

@@ -14,6 +14,8 @@
# limitations under the License.
""" Testing suite for the TensorFlow Speech2Text model. """
from __future__ import annotations
import inspect
import unittest

View File

@@ -15,6 +15,8 @@
""" Testing suite for the TF 2.0 Swin model. """
from __future__ import annotations
import inspect
import unittest

View File

@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import unittest
from transformers import T5Config, is_tf_available

View File

@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import copy
import unittest

View File

@@ -14,6 +14,8 @@
# limitations under the License.
from __future__ import annotations
import random
import unittest

View File

@@ -15,6 +15,8 @@
""" Testing suite for the TensorFlow VisionEncoderDecoder model. """
from __future__ import annotations
import copy
import os
import tempfile

View File

@@ -15,6 +15,8 @@
""" Testing suite for the PyTorch VisionTextDualEncoder model. """
from __future__ import annotations
import collections
import tempfile
import unittest

View File

@@ -15,6 +15,8 @@
""" Testing suite for the TensorFlow ViT model. """
from __future__ import annotations
import inspect
import unittest

View File

@@ -15,6 +15,8 @@
""" Testing suite for the TensorFlow ViTMAE model. """
from __future__ import annotations
import copy
import inspect
import json

View File

@@ -14,6 +14,8 @@
# limitations under the License.
from __future__ import annotations
import copy
import glob
import inspect

View File

@@ -14,6 +14,8 @@
# limitations under the License.
""" Testing suite for the TensorFlow Whisper model. """
from __future__ import annotations
import inspect
import tempfile
import traceback

View File

@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import unittest
from transformers import XGLMConfig, XGLMTokenizer, is_tf_available

View File

@@ -14,6 +14,8 @@
# limitations under the License.
from __future__ import annotations
import unittest
from transformers import is_tf_available

View File

@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import unittest
from transformers import is_tf_available

View File

@@ -14,6 +14,8 @@
# limitations under the License.
from __future__ import annotations
import inspect
import random
import unittest

View File

@@ -14,6 +14,8 @@
# limitations under the License.
from __future__ import annotations
import copy
import inspect
import json
@@ -22,10 +24,9 @@ import random
import tempfile
import unittest
import unittest.mock as mock
from dataclasses import fields
from importlib import import_module
from math import isnan
from typing import List, Tuple, get_type_hints
from typing import List, Tuple
from datasets import Dataset
from huggingface_hub import HfFolder, Repository, delete_repo
@@ -140,26 +141,6 @@ def _config_zero_init(config):
return configs_no_init
def _return_type_has_loss(model):
return_type = get_type_hints(model.call)
if "return" not in return_type:
return False
return_type = return_type["return"]
if hasattr(return_type, "__args__"): # Awkward check for union because UnionType only turns up in 3.10
for type_annotation in return_type.__args__:
if inspect.isclass(type_annotation) and issubclass(type_annotation, ModelOutput):
field_names = [field.name for field in fields(type_annotation)]
if "loss" in field_names:
return True
return False
elif isinstance(return_type, tuple):
return False
elif isinstance(return_type, ModelOutput):
class_fields = fields(return_type)
return "loss" in class_fields
return False
@require_tf
class TFModelTesterMixin:
model_tester = None
@@ -1464,8 +1445,6 @@ class TFModelTesterMixin:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
if not getattr(model, "hf_compute_loss", None) and not _return_type_has_loss(model):
continue
# The number of elements in the loss should be the same as the number of elements in the label
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
added_label_names = sorted(prepared_for_class.keys() - inputs_dict.keys(), reverse=True)
@@ -1480,7 +1459,11 @@ class TFModelTesterMixin:
input_name = possible_input_names.intersection(set(prepared_for_class)).pop()
model_input = prepared_for_class.pop(input_name)
loss = model(model_input, **prepared_for_class)[0]
outputs = model(model_input, **prepared_for_class)
if not isinstance(outputs, ModelOutput) or not hasattr(outputs, "loss"):
continue
loss = outputs.loss
self.assertTrue(loss.shape.as_list() == expected_loss_size or loss.shape.as_list() == [1])
# Test that model correctly compute the loss when we mask some positions
@@ -1540,18 +1523,16 @@ class TFModelTesterMixin:
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
if not getattr(model, "hf_compute_loss", False) and not _return_type_has_loss(model):
continue
# Test that model correctly compute the loss with kwargs
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
# Is there a better way to remove these decoder inputs?
# We also remove "return_loss" as this is covered by the train_step when using fit()
prepared_for_class = {
key: val
for key, val in prepared_for_class.items()
if key
not in ("head_mask", "decoder_head_mask", "cross_attn_head_mask", "decoder_input_ids", "return_loss")
if key not in ("head_mask", "decoder_head_mask", "cross_attn_head_mask", "return_loss")
}
if "labels" in prepared_for_class and "decoder_input_ids" in prepared_for_class:
del prepared_for_class["decoder_input_ids"]
accuracy_classes = [
"ForPreTraining",
@@ -1575,8 +1556,10 @@ class TFModelTesterMixin:
sample_weight = tf.convert_to_tensor([0.5] * self.model_tester.batch_size, dtype=tf.float32)
else:
sample_weight = None
model(model.dummy_inputs) # Build the model so we can get some constant weights
# Build the model so we can get some constant weights and check outputs
outputs = model(prepared_for_class)
if getattr(outputs, "loss", None) is None:
continue
model_weights = model.get_weights()
# Run eagerly to save some expensive compilation times
@@ -1648,7 +1631,6 @@ class TFModelTesterMixin:
# Pass in all samples as a batch to match other `fit` calls
weighted_dataset = weighted_dataset.batch(len(dataset))
dataset = dataset.batch(len(dataset))
# Reinitialize to fix batchnorm again
model.set_weights(model_weights)

View File

@@ -14,6 +14,8 @@
# limitations under the License.
from __future__ import annotations
import copy
import os
import tempfile