Updating the TensorFlow models to work as expected with tokenizers v3.0.0 (#3684)
* Updating modeling tf files; adding tests * Merge `encode_plus` and `batch_encode_plus`
This commit is contained in:
@@ -24,6 +24,7 @@ import tensorflow as tf
|
||||
from .configuration_ctrl import CTRLConfig
|
||||
from .file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
||||
from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, keras_serializable, shape_list
|
||||
from .tokenization_utils import BatchEncoding
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -230,7 +231,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
||||
head_mask = inputs[5] if len(inputs) > 5 else head_mask
|
||||
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
|
||||
assert len(inputs) <= 7, "Too many inputs."
|
||||
elif isinstance(inputs, dict):
|
||||
elif isinstance(inputs, (dict, BatchEncoding)):
|
||||
input_ids = inputs.get("input_ids")
|
||||
past = inputs.get("past", past)
|
||||
attention_mask = inputs.get("attention_mask", attention_mask)
|
||||
|
||||
Reference in New Issue
Block a user