Generate: smaller TF serving test (#18840)
This commit is contained in:
@@ -75,7 +75,7 @@ if is_tf_available():
|
|||||||
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||||
BertConfig,
|
BertConfig,
|
||||||
TFAutoModel,
|
TFAutoModel,
|
||||||
TFAutoModelForSeq2SeqLM,
|
TFAutoModelForCausalLM,
|
||||||
TFAutoModelForSequenceClassification,
|
TFAutoModelForSequenceClassification,
|
||||||
TFBertModel,
|
TFBertModel,
|
||||||
TFSharedEmbeddings,
|
TFSharedEmbeddings,
|
||||||
@@ -2180,8 +2180,8 @@ class UtilsFunctionsTest(unittest.TestCase):
|
|||||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||||
|
|
||||||
def test_generate_tf_function_export(self):
|
def test_generate_tf_function_export(self):
|
||||||
test_model = TFAutoModelForSeq2SeqLM.from_pretrained("hf-internal-testing/tiny-random-t5")
|
test_model = TFAutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||||
max_length = 8
|
max_length = 2
|
||||||
|
|
||||||
class DummyModel(tf.Module):
|
class DummyModel(tf.Module):
|
||||||
def __init__(self, model):
|
def __init__(self, model):
|
||||||
@@ -2204,8 +2204,8 @@ class UtilsFunctionsTest(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
return {"sequences": outputs["sequences"]}
|
return {"sequences": outputs["sequences"]}
|
||||||
|
|
||||||
dummy_input_ids = [[2, 3, 4, 1, 0, 0, 0, 0], [102, 103, 104, 105, 1, 0, 0, 0]]
|
dummy_input_ids = [[2, 0], [102, 103]]
|
||||||
dummy_attention_masks = [[1, 1, 1, 1, 0, 0, 0, 0], [1, 1, 1, 1, 1, 0, 0, 0]]
|
dummy_attention_masks = [[1, 0], [1, 1]]
|
||||||
dummy_model = DummyModel(model=test_model)
|
dummy_model = DummyModel(model=test_model)
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
tf.saved_model.save(dummy_model, tmp_dir, signatures={"serving_default": dummy_model.serving})
|
tf.saved_model.save(dummy_model, tmp_dir, signatures={"serving_default": dummy_model.serving})
|
||||||
|
|||||||
Reference in New Issue
Block a user