Fix train_step, test_step and tests for CLIP (#18684)

* Fix train_step and test_step, correctly enable CLIP fit test

* Stop using get_args on older Python versions

* Don't use get_origin either

* UnionType is actually even newer, don't use that either

* Apply the same fix to test_loss_computation

* Just realized I was accidentally skipping a bunch of tests!

* Fix test_loss_computation for models without separable labels

* Fix scalar losses in test_step and train_step

* Stop committing your breakpoints

* Fix Swin loss shape

* Fix Tapas loss shape

* Shape fixes for TAPAS, DeIT, HuBERT and ViTMAE

* Add loss computation to TFMobileBertForPreTraining

* make fixup and move copied from statement

* make fixup and move copied from statement

* Correct copied from

* Add labels and next_sentence_label inputs to TFMobileBERT

* Make sure total_loss is always defined

* Update tests/test_modeling_tf_common.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Fix copied from

* Ensure CTC models get labels in tests

* Ensure CTC models get labels in tests

* Fix tests for vit_mae

* Fix tests for vit_mae

* Fix tests for vit_mae

* Reduce batch size for wav2vec2 testing because it was causing OOM

* Skip some TAPAS tests that are failing

* Skip a failing HuBERT test

* make style

* Fix mobilebertforpretraining test

* Skip Wav2Vec2 tests that use huge amounts of mem

* Skip keras_fit for Wav2Vec2 as well

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Matt
2022-09-09 20:01:02 +01:00
committed by GitHub
parent f1a6df3210
commit 660e0b97bd
13 changed files with 294 additions and 162 deletions


@@ -17,6 +17,7 @@
import unittest

from transformers import MobileBertConfig, is_tf_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_tf, slow, tooslow

from ...test_configuration_common import ConfigTester
@@ -27,6 +28,7 @@ if is_tf_available():
    import tensorflow as tf

    from transformers import (
        TF_MODEL_FOR_PRETRAINING_MAPPING,
        TFMobileBertForMaskedLM,
        TFMobileBertForMultipleChoice,
        TFMobileBertForNextSentencePrediction,
@@ -58,6 +60,16 @@ class TFMobileBertModelTest(TFModelTesterMixin, unittest.TestCase):
    test_head_masking = False
    test_onnx = False

    # special case for ForPreTraining model, same as BERT tests
    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
        inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)

        if return_labels:
            if model_class in get_values(TF_MODEL_FOR_PRETRAINING_MAPPING):
                inputs_dict["next_sentence_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)

        return inputs_dict

    class TFMobileBertModelTester(object):
        def __init__(
            self,