Better TF docstring types (#23477)
* Rework TF type hints to use | None instead of Optional[] for tf.Tensor * Rework TF type hints to use | None instead of Optional[] for tf.Tensor * Don't forget the imports * Add the imports to tests too * make fixup * Refactor tests that depended on get_type_hints * Better test refactor * Fix an old hidden bug in the test_keras_fit input creation code * Fix for the Deit tests
This commit is contained in:
@@ -14,6 +14,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
import json
|
||||
@@ -22,10 +24,9 @@ import random
|
||||
import tempfile
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
from dataclasses import fields
|
||||
from importlib import import_module
|
||||
from math import isnan
|
||||
from typing import List, Tuple, get_type_hints
|
||||
from typing import List, Tuple
|
||||
|
||||
from datasets import Dataset
|
||||
from huggingface_hub import HfFolder, Repository, delete_repo
|
||||
@@ -140,26 +141,6 @@ def _config_zero_init(config):
|
||||
return configs_no_init
|
||||
|
||||
|
||||
def _return_type_has_loss(model):
|
||||
return_type = get_type_hints(model.call)
|
||||
if "return" not in return_type:
|
||||
return False
|
||||
return_type = return_type["return"]
|
||||
if hasattr(return_type, "__args__"): # Awkward check for union because UnionType only turns up in 3.10
|
||||
for type_annotation in return_type.__args__:
|
||||
if inspect.isclass(type_annotation) and issubclass(type_annotation, ModelOutput):
|
||||
field_names = [field.name for field in fields(type_annotation)]
|
||||
if "loss" in field_names:
|
||||
return True
|
||||
return False
|
||||
elif isinstance(return_type, tuple):
|
||||
return False
|
||||
elif isinstance(return_type, ModelOutput):
|
||||
class_fields = fields(return_type)
|
||||
return "loss" in class_fields
|
||||
return False
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFModelTesterMixin:
|
||||
model_tester = None
|
||||
@@ -1464,8 +1445,6 @@ class TFModelTesterMixin:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
if not getattr(model, "hf_compute_loss", None) and not _return_type_has_loss(model):
|
||||
continue
|
||||
# The number of elements in the loss should be the same as the number of elements in the label
|
||||
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
|
||||
added_label_names = sorted(prepared_for_class.keys() - inputs_dict.keys(), reverse=True)
|
||||
@@ -1480,7 +1459,11 @@ class TFModelTesterMixin:
|
||||
input_name = possible_input_names.intersection(set(prepared_for_class)).pop()
|
||||
model_input = prepared_for_class.pop(input_name)
|
||||
|
||||
loss = model(model_input, **prepared_for_class)[0]
|
||||
outputs = model(model_input, **prepared_for_class)
|
||||
if not isinstance(outputs, ModelOutput) or not hasattr(outputs, "loss"):
|
||||
continue
|
||||
|
||||
loss = outputs.loss
|
||||
self.assertTrue(loss.shape.as_list() == expected_loss_size or loss.shape.as_list() == [1])
|
||||
|
||||
# Test that model correctly compute the loss when we mask some positions
|
||||
@@ -1540,18 +1523,16 @@ class TFModelTesterMixin:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
if not getattr(model, "hf_compute_loss", False) and not _return_type_has_loss(model):
|
||||
continue
|
||||
# Test that model correctly compute the loss with kwargs
|
||||
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
|
||||
# Is there a better way to remove these decoder inputs?
|
||||
# We also remove "return_loss" as this is covered by the train_step when using fit()
|
||||
prepared_for_class = {
|
||||
key: val
|
||||
for key, val in prepared_for_class.items()
|
||||
if key
|
||||
not in ("head_mask", "decoder_head_mask", "cross_attn_head_mask", "decoder_input_ids", "return_loss")
|
||||
if key not in ("head_mask", "decoder_head_mask", "cross_attn_head_mask", "return_loss")
|
||||
}
|
||||
if "labels" in prepared_for_class and "decoder_input_ids" in prepared_for_class:
|
||||
del prepared_for_class["decoder_input_ids"]
|
||||
|
||||
accuracy_classes = [
|
||||
"ForPreTraining",
|
||||
@@ -1575,8 +1556,10 @@ class TFModelTesterMixin:
|
||||
sample_weight = tf.convert_to_tensor([0.5] * self.model_tester.batch_size, dtype=tf.float32)
|
||||
else:
|
||||
sample_weight = None
|
||||
|
||||
model(model.dummy_inputs) # Build the model so we can get some constant weights
|
||||
# Build the model so we can get some constant weights and check outputs
|
||||
outputs = model(prepared_for_class)
|
||||
if getattr(outputs, "loss", None) is None:
|
||||
continue
|
||||
model_weights = model.get_weights()
|
||||
|
||||
# Run eagerly to save some expensive compilation times
|
||||
@@ -1648,7 +1631,6 @@ class TFModelTesterMixin:
|
||||
# Pass in all samples as a batch to match other `fit` calls
|
||||
weighted_dataset = weighted_dataset.batch(len(dataset))
|
||||
dataset = dataset.batch(len(dataset))
|
||||
|
||||
# Reinitialize to fix batchnorm again
|
||||
model.set_weights(model_weights)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user