TF: CTRL with native embedding layers (#23456)

This commit is contained in:
Joao Gante
2023-06-14 14:39:02 +01:00
committed by GitHub
parent eac8dede83
commit 4626df5077
2 changed files with 70 additions and 55 deletions

View File

@@ -225,6 +225,7 @@ class TFCTRLModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase
for model_class in self.all_model_classes:
model = model_class(config)
model.build() # may be needed for the get_bias() call below
assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
if model_class in list_lm_models: