Better TF docstring types (#23477)
* Rework TF type hints to use | None instead of Optional[] for tf.Tensor * Rework TF type hints to use | None instead of Optional[] for tf.Tensor * Don't forget the imports * Add the imports to tests too * make fixup * Refactor tests that depended on get_type_hints * Better test refactor * Fix an old hidden bug in the test_keras_fit input creation code * Fix for the Deit tests
This commit is contained in:
@@ -386,9 +386,9 @@ class TF{{cookiecutter.camelcase_modelname}}Layer(tf.keras.layers.Layer):
|
||||
hidden_states: tf.Tensor,
|
||||
attention_mask: tf.Tensor,
|
||||
head_mask: tf.Tensor,
|
||||
encoder_hidden_states: Optional[tf.Tensor],
|
||||
encoder_attention_mask: Optional[tf.Tensor],
|
||||
past_key_value: Optional[Tuple[tf.Tensor]],
|
||||
encoder_hidden_states: tf.Tensor | None,
|
||||
encoder_attention_mask: tf.Tensor | None,
|
||||
past_key_value: Tuple[tf.Tensor] | None,
|
||||
output_attentions: bool,
|
||||
training: bool = False,
|
||||
) -> Tuple[tf.Tensor]:
|
||||
@@ -465,9 +465,9 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
|
||||
hidden_states: tf.Tensor,
|
||||
attention_mask: tf.Tensor,
|
||||
head_mask: tf.Tensor,
|
||||
encoder_hidden_states: Optional[tf.Tensor],
|
||||
encoder_attention_mask: Optional[tf.Tensor],
|
||||
past_key_values: Optional[Tuple[Tuple[tf.Tensor]]],
|
||||
encoder_hidden_states: tf.Tensor | None,
|
||||
encoder_attention_mask: tf.Tensor | None,
|
||||
past_key_values: Tuple[Tuple[tf.Tensor]] | None,
|
||||
use_cache: Optional[bool],
|
||||
output_attentions: bool,
|
||||
output_hidden_states: bool,
|
||||
@@ -639,14 +639,14 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
|
||||
@unpack_inputs
|
||||
def call(
|
||||
self,
|
||||
input_ids: Optional[TFModelInputType] = None,
|
||||
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
input_ids: TFModelInputType | None = None,
|
||||
attention_mask: np.ndarray | tf.Tensor | None = None,
|
||||
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
||||
position_ids: np.ndarray | tf.Tensor | None = None,
|
||||
head_mask: np.ndarray | tf.Tensor | None = None,
|
||||
inputs_embeds: np.ndarray | tf.Tensor | None = None,
|
||||
encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
|
||||
encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
|
||||
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
@@ -937,14 +937,14 @@ class TF{{cookiecutter.camelcase_modelname}}Model(TF{{cookiecutter.camelcase_mod
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
input_ids: Optional[TFModelInputType] = None,
|
||||
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
input_ids: TFModelInputType | None = None,
|
||||
attention_mask: np.ndarray | tf.Tensor | None = None,
|
||||
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
||||
position_ids: np.ndarray | tf.Tensor | None = None,
|
||||
head_mask: np.ndarray | tf.Tensor | None = None,
|
||||
inputs_embeds: np.ndarray | tf.Tensor | None = None,
|
||||
encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
|
||||
encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
|
||||
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
@@ -1038,16 +1038,16 @@ class TF{{cookiecutter.camelcase_modelname}}ForMaskedLM(TF{{cookiecutter.camelca
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
input_ids: Optional[TFModelInputType] = None,
|
||||
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
input_ids: TFModelInputType | None = None,
|
||||
attention_mask: np.ndarray | tf.Tensor | None = None,
|
||||
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
||||
position_ids: np.ndarray | tf.Tensor | None = None,
|
||||
head_mask: np.ndarray | tf.Tensor | None = None,
|
||||
inputs_embeds: np.ndarray | tf.Tensor | None = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
labels: np.ndarray | tf.Tensor | None = None,
|
||||
training: Optional[bool] = False,
|
||||
) -> Union[TFMaskedLMOutput, Tuple[tf.Tensor]]:
|
||||
r"""
|
||||
@@ -1129,20 +1129,20 @@ class TF{{cookiecutter.camelcase_modelname}}ForCausalLM(TF{{cookiecutter.camelca
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
input_ids: Optional[TFModelInputType] = None,
|
||||
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
input_ids: TFModelInputType | None = None,
|
||||
attention_mask: np.ndarray | tf.Tensor | None = None,
|
||||
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
||||
position_ids: np.ndarray | tf.Tensor | None = None,
|
||||
head_mask: np.ndarray | tf.Tensor | None = None,
|
||||
inputs_embeds: np.ndarray | tf.Tensor | None = None,
|
||||
encoder_hidden_states: np.ndarray | tf.Tensor | None = None,
|
||||
encoder_attention_mask: np.ndarray | tf.Tensor | None = None,
|
||||
past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
labels: np.ndarray | tf.Tensor | None = None,
|
||||
training: Optional[bool] = False,
|
||||
) -> Union[TFCausalLMOutputWithCrossAttentions, Tuple[tf.Tensor]]:
|
||||
r"""
|
||||
@@ -1274,16 +1274,16 @@ class TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification(TF{{cookie
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
input_ids: Optional[TFModelInputType] = None,
|
||||
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
input_ids: TFModelInputType | None = None,
|
||||
attention_mask: np.ndarray | tf.Tensor | None = None,
|
||||
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
||||
position_ids: np.ndarray | tf.Tensor | None = None,
|
||||
head_mask: np.ndarray | tf.Tensor | None = None,
|
||||
inputs_embeds: np.ndarray | tf.Tensor | None = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
labels: np.ndarray | tf.Tensor | None = None,
|
||||
training: Optional[bool] = False,
|
||||
) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]:
|
||||
r"""
|
||||
@@ -1362,16 +1362,16 @@ class TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice(TF{{cookiecutter.c
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
input_ids: Optional[TFModelInputType] = None,
|
||||
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
input_ids: TFModelInputType | None = None,
|
||||
attention_mask: np.ndarray | tf.Tensor | None = None,
|
||||
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
||||
position_ids: np.ndarray | tf.Tensor | None = None,
|
||||
head_mask: np.ndarray | tf.Tensor | None = None,
|
||||
inputs_embeds: np.ndarray | tf.Tensor | None = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
labels: np.ndarray | tf.Tensor | None = None,
|
||||
training: Optional[bool] = False,
|
||||
) -> Union[TFMultipleChoiceModelOutput, Tuple[tf.Tensor]]:
|
||||
r"""
|
||||
@@ -1487,16 +1487,16 @@ class TF{{cookiecutter.camelcase_modelname}}ForTokenClassification(TF{{cookiecut
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
input_ids: Optional[TFModelInputType] = None,
|
||||
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
input_ids: TFModelInputType | None = None,
|
||||
attention_mask: np.ndarray | tf.Tensor | None = None,
|
||||
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
||||
position_ids: np.ndarray | tf.Tensor | None = None,
|
||||
head_mask: np.ndarray | tf.Tensor | None = None,
|
||||
inputs_embeds: np.ndarray | tf.Tensor | None = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
labels: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
labels: np.ndarray | tf.Tensor | None = None,
|
||||
training: Optional[bool] = False,
|
||||
) -> Union[TFTokenClassifierOutput, Tuple[tf.Tensor]]:
|
||||
r"""
|
||||
@@ -1566,17 +1566,17 @@ class TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering(TF{{cookiecutte
|
||||
)
|
||||
def call(
|
||||
self,
|
||||
input_ids: Optional[TFModelInputType] = None,
|
||||
attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
input_ids: TFModelInputType | None = None,
|
||||
attention_mask: np.ndarray | tf.Tensor | None = None,
|
||||
token_type_ids: np.ndarray | tf.Tensor | None = None,
|
||||
position_ids: np.ndarray | tf.Tensor | None = None,
|
||||
head_mask: np.ndarray | tf.Tensor | None = None,
|
||||
inputs_embeds: np.ndarray | tf.Tensor | None = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
start_positions: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
end_positions: Optional[Union[np.ndarray, tf.Tensor]] = None,
|
||||
start_positions: np.ndarray | tf.Tensor | None = None,
|
||||
end_positions: np.ndarray | tf.Tensor | None = None,
|
||||
training: Optional[bool] = False,
|
||||
) -> Union[TFQuestionAnsweringModelOutput, Tuple[tf.Tensor]]:
|
||||
r"""
|
||||
@@ -1777,12 +1777,12 @@ class TF{{cookiecutter.camelcase_modelname}}Attention(tf.keras.layers.Layer):
|
||||
def call(
|
||||
self,
|
||||
hidden_states: tf.Tensor,
|
||||
key_value_states: Optional[tf.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[Tuple[tf.Tensor]]] = None,
|
||||
attention_mask: Optional[tf.Tensor] = None,
|
||||
layer_head_mask: Optional[tf.Tensor] = None,
|
||||
key_value_states: tf.Tensor | None = None,
|
||||
past_key_value: Tuple[Tuple[tf.Tensor]] | None = None,
|
||||
attention_mask: tf.Tensor | None = None,
|
||||
layer_head_mask: tf.Tensor | None = None,
|
||||
training=False,
|
||||
) -> Tuple[tf.Tensor, Optional[tf.Tensor]]:
|
||||
) -> Tuple[tf.Tensor, tf.Tensor | None]:
|
||||
"""Input shape: Batch x Time x Channel"""
|
||||
|
||||
# if key_value_states are provided this layer is used as a cross-attention layer
|
||||
@@ -1962,12 +1962,12 @@ class TF{{cookiecutter.camelcase_modelname}}DecoderLayer(tf.keras.layers.Layer):
|
||||
def call(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask: Optional[tf.Tensor] = None,
|
||||
encoder_hidden_states: Optional[tf.Tensor] = None,
|
||||
encoder_attention_mask: Optional[tf.Tensor] = None,
|
||||
layer_head_mask: Optional[tf.Tensor] = None,
|
||||
cross_attn_layer_head_mask: Optional[tf.Tensor] = None,
|
||||
past_key_value: Optional[Tuple[tf.Tensor]] = None,
|
||||
attention_mask: tf.Tensor | None = None,
|
||||
encoder_hidden_states: tf.Tensor | None = None,
|
||||
encoder_attention_mask: tf.Tensor | None = None,
|
||||
layer_head_mask: tf.Tensor | None = None,
|
||||
cross_attn_layer_head_mask: tf.Tensor | None = None,
|
||||
past_key_value: Tuple[tf.Tensor] | None = None,
|
||||
training=False,
|
||||
) -> Tuple[tf.Tensor, tf.Tensor, Tuple[Tuple[tf.Tensor]]]:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user