[TF] Fix creating a PR while pushing in TF framework (#21968)

* add create pr arg

* style

* add test

* ficup

* update test

* last nit fix typo

* add `is_pt_tf_cross_test` marker for the tsts
This commit is contained in:
Arthur
2023-03-07 17:32:08 +01:00
committed by GitHub
parent d128f2ffab
commit 2156662dea
2 changed files with 25 additions and 10 deletions

View File

@@ -85,6 +85,7 @@ if is_tf_available():
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
BertConfig,
PreTrainedModel,
PushToHubCallback,
RagRetriever,
TFAutoModel,
@@ -92,6 +93,7 @@ if is_tf_available():
TFBertForMaskedLM,
TFBertForSequenceClassification,
TFBertModel,
TFPreTrainedModel,
TFRagModel,
TFSharedEmbeddings,
)
@@ -2466,6 +2468,7 @@ class TFModelPushToHubTester(unittest.TestCase):
break
self.assertTrue(models_equal)
@is_pt_tf_cross_test
def test_push_to_hub_callback(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
@@ -2489,6 +2492,12 @@ class TFModelPushToHubTester(unittest.TestCase):
break
self.assertTrue(models_equal)
tf_push_to_hub_params = dict(inspect.signature(TFPreTrainedModel.push_to_hub).parameters)
tf_push_to_hub_params.pop("base_model_card_args")
pt_push_to_hub_params = dict(inspect.signature(PreTrainedModel.push_to_hub).parameters)
pt_push_to_hub_params.pop("deprecated_kwargs")
self.assertDictEaual(tf_push_to_hub_params, pt_push_to_hub_params)
def test_push_to_hub_in_organization(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37