[TF] Fix creating a PR while pushing in TF framework (#21968)
* add create pr arg * style * add test * ficup * update test * last nit fix typo * add `is_pt_tf_cross_test` marker for the tsts
This commit is contained in:
@@ -85,6 +85,7 @@ if is_tf_available():
|
||||
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
|
||||
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
BertConfig,
|
||||
PreTrainedModel,
|
||||
PushToHubCallback,
|
||||
RagRetriever,
|
||||
TFAutoModel,
|
||||
@@ -92,6 +93,7 @@ if is_tf_available():
|
||||
TFBertForMaskedLM,
|
||||
TFBertForSequenceClassification,
|
||||
TFBertModel,
|
||||
TFPreTrainedModel,
|
||||
TFRagModel,
|
||||
TFSharedEmbeddings,
|
||||
)
|
||||
@@ -2466,6 +2468,7 @@ class TFModelPushToHubTester(unittest.TestCase):
|
||||
break
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
@is_pt_tf_cross_test
|
||||
def test_push_to_hub_callback(self):
|
||||
config = BertConfig(
|
||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||
@@ -2489,6 +2492,12 @@ class TFModelPushToHubTester(unittest.TestCase):
|
||||
break
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
tf_push_to_hub_params = dict(inspect.signature(TFPreTrainedModel.push_to_hub).parameters)
|
||||
tf_push_to_hub_params.pop("base_model_card_args")
|
||||
pt_push_to_hub_params = dict(inspect.signature(PreTrainedModel.push_to_hub).parameters)
|
||||
pt_push_to_hub_params.pop("deprecated_kwargs")
|
||||
self.assertDictEaual(tf_push_to_hub_params, pt_push_to_hub_params)
|
||||
|
||||
def test_push_to_hub_in_organization(self):
|
||||
config = BertConfig(
|
||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||
|
||||
Reference in New Issue
Block a user