Add tf_keras imports to prepare for Keras 3 (#28588)
* Port core files + ESM (because ESM code is odd) * Search-replace in modelling code * Fix up transfo_xl as well * Fix other core files + tests (still need to add correct import to tests) * Fix cookiecutter * make fixup, fix imports in some more core files * Auto-add imports to tests * Cleanup, add imports to sagemaker tests * Use correct exception for importing tf_keras * Fixes in modeling_tf_utils * make fixup * Correct version parsing code * Ensure the pipeline tests correctly revert to float32 after each test * Ensure the pipeline tests correctly revert to float32 after each test * More tf.keras -> keras * Add dtype cast * Better imports of tf_keras * Add a cast for tf.assign, just in case * Fix callback imports
This commit is contained in:
@@ -43,6 +43,7 @@ if is_tf_available():
|
||||
TFMinLengthLogitsProcessor,
|
||||
tf_top_k_top_p_filtering,
|
||||
)
|
||||
from transformers.modeling_tf_utils import keras
|
||||
|
||||
if is_tensorflow_text_available():
|
||||
import tensorflow_text as text
|
||||
@@ -254,7 +255,7 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
|
||||
# file needed to load the TF tokenizer
|
||||
hf_hub_download(repo_id="google/flan-t5-small", filename="spiece.model", local_dir=tmp_dir)
|
||||
|
||||
class CompleteSentenceTransformer(tf.keras.layers.Layer):
|
||||
class CompleteSentenceTransformer(keras.layers.Layer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.tokenizer = text.SentencepieceTokenizer(
|
||||
@@ -271,9 +272,9 @@ class TFGenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTests
|
||||
return self.tokenizer.detokenize(outputs)
|
||||
|
||||
complete_model = CompleteSentenceTransformer()
|
||||
inputs = tf.keras.layers.Input(shape=(1,), dtype=tf.string, name="inputs")
|
||||
inputs = keras.layers.Input(shape=(1,), dtype=tf.string, name="inputs")
|
||||
outputs = complete_model(inputs)
|
||||
keras_model = tf.keras.Model(inputs, outputs)
|
||||
keras_model = keras.Model(inputs, outputs)
|
||||
keras_model.save(tmp_dir)
|
||||
|
||||
def test_eos_token_id_int_and_list_top_k_top_sampling(self):
|
||||
|
||||
Reference in New Issue
Block a user