diff --git a/docs/source/en/model_doc/xlm-roberta.mdx b/docs/source/en/model_doc/xlm-roberta.mdx index 941feac9d3..fbd98284fb 100644 --- a/docs/source/en/model_doc/xlm-roberta.mdx +++ b/docs/source/en/model_doc/xlm-roberta.mdx @@ -146,6 +146,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h [[autodoc]] TFXLMRobertaModel - call +## TFXLMRobertaForCausalLM + +[[autodoc]] TFXLMRobertaForCausalLM + - call + ## TFXLMRobertaForMaskedLM [[autodoc]] TFXLMRobertaForMaskedLM @@ -176,6 +181,11 @@ A list of official Hugging Face and community (indicated by 🌎) resources to h [[autodoc]] FlaxXLMRobertaModel - __call__ +## FlaxXLMRobertaForCausalLM + +[[autodoc]] FlaxXLMRobertaForCausalLM + - __call__ + ## FlaxXLMRobertaForMaskedLM [[autodoc]] FlaxXLMRobertaForMaskedLM diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 8e91bacb26..62f10ace2f 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -3153,12 +3153,14 @@ else: _import_structure["models.xlm_roberta"].extend( [ "TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFXLMRobertaForCausalLM", "TFXLMRobertaForMaskedLM", "TFXLMRobertaForMultipleChoice", "TFXLMRobertaForQuestionAnswering", "TFXLMRobertaForSequenceClassification", "TFXLMRobertaForTokenClassification", "TFXLMRobertaModel", + "TFXLMRobertaPreTrainedModel", ] ) _import_structure["models.xlnet"].extend( @@ -3435,12 +3437,15 @@ else: ) _import_structure["models.xlm_roberta"].extend( [ + "FLAX_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST", "FlaxXLMRobertaForMaskedLM", "FlaxXLMRobertaForMultipleChoice", "FlaxXLMRobertaForQuestionAnswering", "FlaxXLMRobertaForSequenceClassification", "FlaxXLMRobertaForTokenClassification", "FlaxXLMRobertaModel", + "FlaxXLMRobertaForCausalLM", + "FlaxXLMRobertaPreTrainedModel", ] ) @@ -6022,12 +6027,14 @@ if TYPE_CHECKING: ) from .models.xlm_roberta import ( TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST, + TFXLMRobertaForCausalLM, TFXLMRobertaForMaskedLM, TFXLMRobertaForMultipleChoice, TFXLMRobertaForQuestionAnswering, TFXLMRobertaForSequenceClassification, TFXLMRobertaForTokenClassification, TFXLMRobertaModel, + TFXLMRobertaPreTrainedModel, ) from .models.xlnet import ( TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST, @@ -6240,12 +6247,15 @@ if TYPE_CHECKING: ) from .models.xglm import FlaxXGLMForCausalLM, FlaxXGLMModel, FlaxXGLMPreTrainedModel from .models.xlm_roberta import ( + FLAX_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST, + FlaxXLMRobertaForCausalLM, FlaxXLMRobertaForMaskedLM, FlaxXLMRobertaForMultipleChoice, FlaxXLMRobertaForQuestionAnswering, FlaxXLMRobertaForSequenceClassification, FlaxXLMRobertaForTokenClassification, FlaxXLMRobertaModel, + FlaxXLMRobertaPreTrainedModel, ) else: diff --git a/src/transformers/models/auto/modeling_flax_auto.py b/src/transformers/models/auto/modeling_flax_auto.py index 2335b31728..61d34f0f08 100644 --- a/src/transformers/models/auto/modeling_flax_auto.py +++ b/src/transformers/models/auto/modeling_flax_auto.py @@ -142,6 +142,7 @@ FLAX_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( ("roberta", "FlaxRobertaForCausalLM"), ("roberta-prelayernorm", "FlaxRobertaPreLayerNormForCausalLM"), ("xglm", "FlaxXGLMForCausalLM"), + ("xlm-roberta", "FlaxXLMRobertaForCausalLM"), ] ) diff --git a/src/transformers/models/auto/modeling_tf_auto.py b/src/transformers/models/auto/modeling_tf_auto.py index c77fba4f66..fcb8c5e5d5 100644 --- a/src/transformers/models/auto/modeling_tf_auto.py +++ b/src/transformers/models/auto/modeling_tf_auto.py @@ -180,6 +180,7 @@ TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES = OrderedDict( ("transfo-xl", "TFTransfoXLLMHeadModel"), ("xglm", "TFXGLMForCausalLM"), ("xlm", "TFXLMWithLMHeadModel"), + ("xlm-roberta", "TFXLMRobertaForCausalLM"), ("xlnet", "TFXLNetLMHeadModel"), ] ) diff --git a/src/transformers/models/xlm_roberta/__init__.py b/src/transformers/models/xlm_roberta/__init__.py index 2a2abf6f61..9946cf4947 100644 --- a/src/transformers/models/xlm_roberta/__init__.py +++ b/src/transformers/models/xlm_roberta/__init__.py @@ -79,12 +79,14 @@ except OptionalDependencyNotAvailable: else: _import_structure["modeling_tf_xlm_roberta"] = [ "TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFXLMRobertaForCausalLM", "TFXLMRobertaForMaskedLM", "TFXLMRobertaForMultipleChoice", "TFXLMRobertaForQuestionAnswering", "TFXLMRobertaForSequenceClassification", "TFXLMRobertaForTokenClassification", "TFXLMRobertaModel", + "TFXLMRobertaPreTrainedModel", ] try: @@ -94,12 +96,15 @@ except OptionalDependencyNotAvailable: pass else: _import_structure["modeling_flax_xlm_roberta"] = [ + "FLAX_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST", "FlaxXLMRobertaForMaskedLM", + "FlaxXLMRobertaForCausalLM", "FlaxXLMRobertaForMultipleChoice", "FlaxXLMRobertaForQuestionAnswering", "FlaxXLMRobertaForSequenceClassification", "FlaxXLMRobertaForTokenClassification", "FlaxXLMRobertaModel", + "FlaxXLMRobertaPreTrainedModel", ] if TYPE_CHECKING: @@ -151,12 +156,14 @@ if TYPE_CHECKING: else: from .modeling_tf_xlm_roberta import ( TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST, + TFXLMRobertaForCausalLM, TFXLMRobertaForMaskedLM, TFXLMRobertaForMultipleChoice, TFXLMRobertaForQuestionAnswering, TFXLMRobertaForSequenceClassification, TFXLMRobertaForTokenClassification, TFXLMRobertaModel, + TFXLMRobertaPreTrainedModel, ) try: @@ -166,12 +173,15 @@ if TYPE_CHECKING: pass else: from .modeling_flax_xlm_roberta import ( + FLAX_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST, + FlaxXLMRobertaForCausalLM, FlaxXLMRobertaForMaskedLM, FlaxXLMRobertaForMultipleChoice, FlaxXLMRobertaForQuestionAnswering, FlaxXLMRobertaForSequenceClassification, FlaxXLMRobertaForTokenClassification, FlaxXLMRobertaModel, + FlaxXLMRobertaPreTrainedModel, ) else: diff --git a/src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py index 19c25346b6..6b86ac47af 100644 --- a/src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py @@ -15,31 +15,78 @@ # limitations under the License. """Flax XLM-RoBERTa model.""" -from ...utils import add_start_docstrings, logging -from ..roberta.modeling_flax_roberta import ( - FlaxRobertaForMaskedLM, - FlaxRobertaForMultipleChoice, - FlaxRobertaForQuestionAnswering, - FlaxRobertaForSequenceClassification, - FlaxRobertaForTokenClassification, - FlaxRobertaModel, +from typing import Callable, Optional, Tuple + +import numpy as np + +import flax.linen as nn +import jax +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict, freeze, unfreeze +from flax.linen import combine_masks, make_causal_mask +from flax.linen import partitioning as nn_partitioning +from flax.linen.attention import dot_product_attention_weights +from flax.traverse_util import flatten_dict, unflatten_dict +from jax import lax + +from ...modeling_flax_outputs import ( + FlaxBaseModelOutputWithPastAndCrossAttentions, + FlaxBaseModelOutputWithPooling, + FlaxBaseModelOutputWithPoolingAndCrossAttentions, + FlaxCausalLMOutputWithCrossAttentions, + FlaxMaskedLMOutput, + FlaxMultipleChoiceModelOutput, + FlaxQuestionAnsweringModelOutput, + FlaxSequenceClassifierOutput, + FlaxTokenClassifierOutput, ) +from ...modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring, overwrite_call_docstring +from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging from .configuration_xlm_roberta import XLMRobertaConfig logger = logging.get_logger(__name__) -XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [ +_CHECKPOINT_FOR_DOC = "xlm-roberta-base" +_CONFIG_FOR_DOC = "XLMRobertaConfig" +_TOKENIZER_FOR_DOC = "XLMRobertaTokenizer" + +remat = nn_partitioning.remat + +FLAX_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [ "xlm-roberta-base", "xlm-roberta-large", - "xlm-roberta-large-finetuned-conll02-dutch", - "xlm-roberta-large-finetuned-conll02-spanish", - "xlm-roberta-large-finetuned-conll03-english", - "xlm-roberta-large-finetuned-conll03-german", # See all XLM-RoBERTa models at https://huggingface.co/models?filter=xlm-roberta ] + +# Copied from transformers.models.roberta.modeling_flax_roberta.create_position_ids_from_input_ids +def create_position_ids_from_input_ids(input_ids, padding_idx): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols + are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + input_ids: jnp.ndarray + padding_idx: int + + Returns: jnp.ndarray + """ + # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. + mask = (input_ids != padding_idx).astype("i4") + + if mask.ndim > 2: + mask = mask.reshape((-1, mask.shape[-1])) + incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask + incremental_indices = incremental_indices.reshape(input_ids.shape) + else: + incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask + + return incremental_indices.astype("i4") + padding_idx + + XLM_ROBERTA_START_DOCSTRING = r""" + This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the library implements for all its model (such as downloading, saving and converting weights from PyTorch models) @@ -60,92 +107,1408 @@ XLM_ROBERTA_START_DOCSTRING = r""" configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights. """ +XLM_ROBERTA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`numpy.ndarray` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. -@add_start_docstrings( - "The bare XLM-RoBERTa Model transformer outputting raw hidden-states without any specific head on top.", - XLM_ROBERTA_START_DOCSTRING, -) -class FlaxXLMRobertaModel(FlaxRobertaModel): + Indices can be obtained using [`BertTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`numpy.ndarray` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`numpy.ndarray` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + head_mask (`numpy.ndarray` of shape `({0})`, `optional): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings with Bert->XLMRoberta +class FlaxXLMRobertaEmbeddings(nn.Module): + """Construct the embeddings from word, position and token_type embeddings.""" + + config: XLMRobertaConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.word_embeddings = nn.Embed( + self.config.vocab_size, + self.config.hidden_size, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + dtype=self.dtype, + ) + self.position_embeddings = nn.Embed( + self.config.max_position_embeddings, + self.config.hidden_size, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + dtype=self.dtype, + ) + self.token_type_embeddings = nn.Embed( + self.config.type_vocab_size, + self.config.hidden_size, + embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range), + dtype=self.dtype, + ) + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True): + # Embed + inputs_embeds = self.word_embeddings(input_ids.astype("i4")) + position_embeds = self.position_embeddings(position_ids.astype("i4")) + token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4")) + + # Sum all embeddings + hidden_states = inputs_embeds + token_type_embeddings + position_embeds + + # Layer Norm + hidden_states = self.LayerNorm(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + return hidden_states + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->XLMRoberta +class FlaxXLMRobertaSelfAttention(nn.Module): + config: XLMRobertaConfig + causal: bool = False + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.head_dim = self.config.hidden_size // self.config.num_attention_heads + if self.config.hidden_size % self.config.num_attention_heads != 0: + raise ValueError( + "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads` " + " : {self.config.num_attention_heads}" + ) + + self.query = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + self.key = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + self.value = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + + if self.causal: + self.causal_mask = make_causal_mask( + jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool" + ) + + def _split_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim)) + + def _merge_heads(self, hidden_states): + return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,)) + + @nn.compact + # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache + def _concatenate_to_cache(self, key, value, query, attention_mask): + """ + This function takes projected key, value states from a single input token and concatenates the states to cached + states from previous steps. This function is slighly adapted from the official Flax repository: + https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252 + """ + # detect if we're initializing by absence of existing cache data. + is_initialized = self.has_variable("cache", "cached_key") + cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype) + cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype) + cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32)) + + if is_initialized: + *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape + # update key, value caches with our new 1d spatial slices + cur_index = cache_index.value + indices = (0,) * len(batch_dims) + (cur_index, 0, 0) + key = lax.dynamic_update_slice(cached_key.value, key, indices) + value = lax.dynamic_update_slice(cached_value.value, value, indices) + cached_key.value = key + cached_value.value = value + num_updated_cache_vectors = query.shape[1] + cache_index.value = cache_index.value + num_updated_cache_vectors + # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements. + pad_mask = jnp.broadcast_to( + jnp.arange(max_length) < cur_index + num_updated_cache_vectors, + tuple(batch_dims) + (1, num_updated_cache_vectors, max_length), + ) + attention_mask = combine_masks(pad_mask, attention_mask) + return key, value, attention_mask + + def __call__( + self, + hidden_states, + attention_mask, + layer_head_mask, + key_value_states: Optional[jnp.array] = None, + init_cache: bool = False, + deterministic=True, + output_attentions: bool = False, + ): + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + batch_size = hidden_states.shape[0] + + # get query proj + query_states = self.query(hidden_states) + # get key, value proj + if is_cross_attention: + # cross_attentions + key_states = self.key(key_value_states) + value_states = self.value(key_value_states) + else: + # self_attention + key_states = self.key(hidden_states) + value_states = self.value(hidden_states) + + query_states = self._split_heads(query_states) + key_states = self._split_heads(key_states) + value_states = self._split_heads(value_states) + + # handle cache prepare causal attention mask + if self.causal: + query_length, key_length = query_states.shape[1], key_states.shape[1] + if self.has_variable("cache", "cached_key"): + mask_shift = self.variables["cache"]["cache_index"] + max_decoder_length = self.variables["cache"]["cached_key"].shape[1] + causal_mask = lax.dynamic_slice( + self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length) + ) + else: + causal_mask = self.causal_mask[:, :, :query_length, :key_length] + causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:]) + + # combine masks if needed + if attention_mask is not None and self.causal: + attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape) + attention_mask = combine_masks(attention_mask, causal_mask) + elif self.causal: + attention_mask = causal_mask + elif attention_mask is not None: + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) + + # During fast autoregressive decoding, we feed one position at a time, + # and cache the keys and values step by step. + if self.causal and (self.has_variable("cache", "cached_key") or init_cache): + key_states, value_states, attention_mask = self._concatenate_to_cache( + key_states, value_states, query_states, attention_mask + ) + + # Convert the boolean attention mask to an attention bias. + if attention_mask is not None: + # attention mask in the form of attention bias + attention_bias = lax.select( + attention_mask > 0, + jnp.full(attention_mask.shape, 0.0).astype(self.dtype), + jnp.full(attention_mask.shape, -1e10).astype(self.dtype), + ) + else: + attention_bias = None + + dropout_rng = None + if not deterministic and self.config.attention_probs_dropout_prob > 0.0: + dropout_rng = self.make_rng("dropout") + + attn_weights = dot_product_attention_weights( + query_states, + key_states, + bias=attention_bias, + dropout_rng=dropout_rng, + dropout_rate=self.config.attention_probs_dropout_prob, + broadcast_dropout=True, + deterministic=deterministic, + dtype=self.dtype, + precision=None, + ) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = jnp.einsum("...hqk,h->...hqk", attn_weights, layer_head_mask) + + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) + attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,)) + + outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) + return outputs + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfOutput with Bert->XLMRoberta +class FlaxXLMRobertaSelfOutput(nn.Module): + config: XLMRobertaConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + + def __call__(self, hidden_states, input_tensor, deterministic: bool = True): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->XLMRoberta +class FlaxXLMRobertaAttention(nn.Module): + config: XLMRobertaConfig + causal: bool = False + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.self = FlaxXLMRobertaSelfAttention(self.config, causal=self.causal, dtype=self.dtype) + self.output = FlaxXLMRobertaSelfOutput(self.config, dtype=self.dtype) + + def __call__( + self, + hidden_states, + attention_mask, + layer_head_mask, + key_value_states=None, + init_cache=False, + deterministic=True, + output_attentions: bool = False, + ): + # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length) + # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable + # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length) + attn_outputs = self.self( + hidden_states, + attention_mask, + layer_head_mask=layer_head_mask, + key_value_states=key_value_states, + init_cache=init_cache, + deterministic=deterministic, + output_attentions=output_attentions, + ) + attn_output = attn_outputs[0] + hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_outputs[1],) + + return outputs + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->XLMRoberta +class FlaxXLMRobertaIntermediate(nn.Module): + config: XLMRobertaConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.intermediate_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.activation = ACT2FN[self.config.hidden_act] + + def __call__(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.activation(hidden_states) + return hidden_states + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->XLMRoberta +class FlaxXLMRobertaOutput(nn.Module): + config: XLMRobertaConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + + def __call__(self, hidden_states, attention_output, deterministic: bool = True): + hidden_states = self.dense(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.LayerNorm(hidden_states + attention_output) + return hidden_states + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayer with Bert->XLMRoberta +class FlaxXLMRobertaLayer(nn.Module): + config: XLMRobertaConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.attention = FlaxXLMRobertaAttention(self.config, causal=self.config.is_decoder, dtype=self.dtype) + self.intermediate = FlaxXLMRobertaIntermediate(self.config, dtype=self.dtype) + self.output = FlaxXLMRobertaOutput(self.config, dtype=self.dtype) + if self.config.add_cross_attention: + self.crossattention = FlaxXLMRobertaAttention(self.config, causal=False, dtype=self.dtype) + + def __call__( + self, + hidden_states, + attention_mask, + layer_head_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + output_attentions: bool = False, + ): + # Self Attention + attention_outputs = self.attention( + hidden_states, + attention_mask, + layer_head_mask=layer_head_mask, + init_cache=init_cache, + deterministic=deterministic, + output_attentions=output_attentions, + ) + attention_output = attention_outputs[0] + + # Cross-Attention Block + if encoder_hidden_states is not None: + cross_attention_outputs = self.crossattention( + attention_output, + attention_mask=encoder_attention_mask, + layer_head_mask=layer_head_mask, + key_value_states=encoder_hidden_states, + deterministic=deterministic, + output_attentions=output_attentions, + ) + attention_output = cross_attention_outputs[0] + + hidden_states = self.intermediate(attention_output) + hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attention_outputs[1],) + if encoder_hidden_states is not None: + outputs += (cross_attention_outputs[1],) + return outputs + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection with Bert->XLMRoberta +class FlaxXLMRobertaLayerCollection(nn.Module): + config: XLMRobertaConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False + + def setup(self): + if self.gradient_checkpointing: + FlaxXLMRobertaCheckpointLayer = remat(FlaxXLMRobertaLayer, static_argnums=(5, 6, 7)) + self.layers = [ + FlaxXLMRobertaCheckpointLayer(self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.num_hidden_layers) + ] + else: + self.layers = [ + FlaxXLMRobertaLayer(self.config, name=str(i), dtype=self.dtype) + for i in range(self.config.num_hidden_layers) + ] + + def __call__( + self, + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + all_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None + + # Check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.shape[0] != (len(self.layers)): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for " + f" {head_mask.shape[0]}." + ) + + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = layer( + hidden_states, + attention_mask, + head_mask[i] if head_mask is not None else None, + encoder_hidden_states, + encoder_attention_mask, + init_cache, + deterministic, + output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions += (layer_outputs[1],) + + if encoder_hidden_states is not None: + all_cross_attentions += (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + outputs = (hidden_states, all_hidden_states, all_attentions, all_cross_attentions) + + if not return_dict: + return tuple(v for v in outputs if v is not None) + + return FlaxBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEncoder with Bert->XLMRoberta +class FlaxXLMRobertaEncoder(nn.Module): + config: XLMRobertaConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + gradient_checkpointing: bool = False + + def setup(self): + self.layer = FlaxXLMRobertaLayerCollection( + self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) + + def __call__( + self, + hidden_states, + attention_mask, + head_mask, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + return self.layer( + hidden_states, + attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPooler with Bert->XLMRoberta +class FlaxXLMRobertaPooler(nn.Module): + config: XLMRobertaConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + dtype=self.dtype, + ) + + def __call__(self, hidden_states): + cls_hidden_state = hidden_states[:, 0] + cls_hidden_state = self.dense(cls_hidden_state) + return nn.tanh(cls_hidden_state) + + +# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaLMHead with Roberta->XLMRoberta +class FlaxXLMRobertaLMHead(nn.Module): + config: XLMRobertaConfig + dtype: jnp.dtype = jnp.float32 + bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype) + self.decoder = nn.Dense( + self.config.vocab_size, + dtype=self.dtype, + use_bias=False, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,)) + + def __call__(self, hidden_states, shared_embedding=None): + hidden_states = self.dense(hidden_states) + hidden_states = ACT2FN["gelu"](hidden_states) + hidden_states = self.layer_norm(hidden_states) + + if shared_embedding is not None: + hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states) + else: + hidden_states = self.decoder(hidden_states) + + bias = jnp.asarray(self.bias, self.dtype) + hidden_states += bias + return hidden_states + + +# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaClassificationHead with Roberta->XLMRoberta +class FlaxXLMRobertaClassificationHead(nn.Module): + config: XLMRobertaConfig + dtype: jnp.dtype = jnp.float32 + + def setup(self): + self.dense = nn.Dense( + self.config.hidden_size, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + classifier_dropout = ( + self.config.classifier_dropout + if self.config.classifier_dropout is not None + else self.config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(rate=classifier_dropout) + self.out_proj = nn.Dense( + self.config.num_labels, + dtype=self.dtype, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range), + ) + + def __call__(self, hidden_states, deterministic=True): + hidden_states = hidden_states[:, 0, :] # take token (equiv. to [CLS]) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.dense(hidden_states) + hidden_states = nn.tanh(hidden_states) + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + +# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaPreTrainedModel with Roberta->XLMRoberta, roberta->xlm-roberta, ROBERTA->XLM_ROBERTA +class FlaxXLMRobertaPreTrainedModel(FlaxPreTrainedModel): """ - This class overrides [`FlaxRobertaModel`]. Please check the superclass for the appropriate documentation alongside - usage examples. + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. """ config_class = XLMRobertaConfig + base_model_prefix = "xlm-roberta" + + module_class: nn.Module = None + + def __init__( + self, + config: XLMRobertaConfig, + input_shape: Tuple = (1, 1), + seed: int = 0, + dtype: jnp.dtype = jnp.float32, + _do_init: bool = True, + gradient_checkpointing: bool = False, + **kwargs + ): + module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs) + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) + + # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing + def enable_gradient_checkpointing(self): + self._module = self.module_class( + config=self.config, + dtype=self.dtype, + gradient_checkpointing=True, + ) + + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + # init input tensors + input_ids = jnp.zeros(input_shape, dtype="i4") + token_type_ids = jnp.ones_like(input_ids) + position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id) + attention_mask = jnp.ones_like(input_ids) + head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads)) + + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + + if self.config.add_cross_attention: + encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,)) + encoder_attention_mask = attention_mask + module_init_outputs = self.module.init( + rngs, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + encoder_hidden_states, + encoder_attention_mask, + return_dict=False, + ) + else: + module_init_outputs = self.module.init( + rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False + ) + + random_params = module_init_outputs["params"] + + if params is not None: + random_params = flatten_dict(unfreeze(random_params)) + params = flatten_dict(unfreeze(params)) + for missing_key in self._missing_keys: + params[missing_key] = random_params[missing_key] + self._missing_keys = set() + return freeze(unflatten_dict(params)) + else: + return random_params + + # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderPreTrainedModel.init_cache + def init_cache(self, batch_size, max_length): + r""" + Args: + batch_size (`int`): + batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache. + max_length (`int`): + maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized + cache. + """ + # init input variables to retrieve cache + input_ids = jnp.ones((batch_size, max_length), dtype="i4") + attention_mask = jnp.ones_like(input_ids, dtype="i4") + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) + + init_variables = self.module.init( + jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True + ) + return unfreeze(init_variables["cache"]) + + @add_start_docstrings_to_model_forward(XLM_ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + def __call__( + self, + input_ids, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + params: dict = None, + dropout_rng: jax.random.PRNGKey = None, + train: bool = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + past_key_values: dict = None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.return_dict + + # init input tensors if not passed + if token_type_ids is None: + token_type_ids = jnp.zeros_like(input_ids) + + if position_ids is None: + position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id) + + if attention_mask is None: + attention_mask = jnp.ones_like(input_ids) + + if head_mask is None: + head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads)) + + # Handle any PRNG if needed + rngs = {} + if dropout_rng is not None: + rngs["dropout"] = dropout_rng + + inputs = {"params": params or self.params} + + if self.config.add_cross_attention: + # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed + # down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be + # changed by FlaxXLMRobertaAttention module + if past_key_values: + inputs["cache"] = past_key_values + mutable = ["cache"] + else: + mutable = False + + outputs = self.module.apply( + inputs, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + token_type_ids=jnp.array(token_type_ids, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + head_mask=jnp.array(head_mask, dtype="i4"), + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + deterministic=not train, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + rngs=rngs, + mutable=mutable, + ) + + # add updated cache to model output + if past_key_values is not None and return_dict: + outputs, past_key_values = outputs + outputs["past_key_values"] = unfreeze(past_key_values["cache"]) + return outputs + elif past_key_values is not None and not return_dict: + outputs, past_key_values = outputs + outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:] + + else: + outputs = self.module.apply( + inputs, + jnp.array(input_ids, dtype="i4"), + jnp.array(attention_mask, dtype="i4"), + token_type_ids=jnp.array(token_type_ids, dtype="i4"), + position_ids=jnp.array(position_ids, dtype="i4"), + head_mask=jnp.array(head_mask, dtype="i4"), + deterministic=not train, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + rngs=rngs, + ) + + return outputs + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->XLMRoberta +class FlaxXLMRobertaModule(nn.Module): + config: XLMRobertaConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + add_pooling_layer: bool = True + gradient_checkpointing: bool = False + + def setup(self): + self.embeddings = FlaxXLMRobertaEmbeddings(self.config, dtype=self.dtype) + self.encoder = FlaxXLMRobertaEncoder( + self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) + self.pooler = FlaxXLMRobertaPooler(self.config, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids: Optional[jnp.ndarray] = None, + position_ids: Optional[jnp.ndarray] = None, + head_mask: Optional[jnp.ndarray] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # make sure `token_type_ids` is correctly initialized when not passed + if token_type_ids is None: + token_type_ids = jnp.zeros_like(input_ids) + + # make sure `position_ids` is correctly initialized when not passed + if position_ids is None: + position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape) + + hidden_states = self.embeddings( + input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic + ) + outputs = self.encoder( + hidden_states, + attention_mask, + head_mask=head_mask, + deterministic=deterministic, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] + pooled = self.pooler(hidden_states) if self.add_pooling_layer else None + + if not return_dict: + # if pooled is None, don't return it + if pooled is None: + return (hidden_states,) + outputs[1:] + return (hidden_states, pooled) + outputs[1:] + + return FlaxBaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=hidden_states, + pooler_output=pooled, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) @add_start_docstrings( - """XLM-RoBERTa Model with a `language modeling` head on top.""", + "The bare XLM RoBERTa Model transformer outputting raw hidden-states without any specific head on top.", XLM_ROBERTA_START_DOCSTRING, ) -class FlaxXLMRobertaForMaskedLM(FlaxRobertaForMaskedLM): - """ - This class overrides [`FlaxRobertaForMaskedLM`]. Please check the superclass for the appropriate documentation - alongside usage examples. - """ +class FlaxXLMRobertaModel(FlaxXLMRobertaPreTrainedModel): + module_class = FlaxXLMRobertaModule - config_class = XLMRobertaConfig + +append_call_sample_docstring( + FlaxXLMRobertaModel, _TOKENIZER_FOR_DOC, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPooling, _CONFIG_FOR_DOC +) + + +# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaForMaskedLMModule with Roberta->XLMRoberta +class FlaxXLMRobertaForMaskedLMModule(nn.Module): + config: XLMRobertaConfig + dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False + + def setup(self): + self.roberta = FlaxXLMRobertaModule( + config=self.config, + add_pooling_layer=False, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) + self.lm_head = FlaxXLMRobertaLMHead(config=self.config, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.roberta( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if self.config.tie_word_embeddings: + shared_embedding = self.roberta.variables["params"]["embeddings"]["word_embeddings"]["embedding"] + else: + shared_embedding = None + + # Compute the prediction scores + logits = self.lm_head(hidden_states, shared_embedding=shared_embedding) + + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxMaskedLMOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +@add_start_docstrings("""XLM RoBERTa Model with a `language modeling` head on top.""", XLM_ROBERTA_START_DOCSTRING) +class FlaxXLMRobertaForMaskedLM(FlaxXLMRobertaPreTrainedModel): + module_class = FlaxXLMRobertaForMaskedLMModule + + +append_call_sample_docstring( + FlaxXLMRobertaForMaskedLM, + _TOKENIZER_FOR_DOC, + _CHECKPOINT_FOR_DOC, + FlaxBaseModelOutputWithPooling, + _CONFIG_FOR_DOC, + mask="", +) + + +# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaForSequenceClassificationModule with Roberta->XLMRoberta +class FlaxXLMRobertaForSequenceClassificationModule(nn.Module): + config: XLMRobertaConfig + dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False + + def setup(self): + self.roberta = FlaxXLMRobertaModule( + config=self.config, + dtype=self.dtype, + add_pooling_layer=False, + gradient_checkpointing=self.gradient_checkpointing, + ) + self.classifier = FlaxXLMRobertaClassificationHead(config=self.config, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.roberta( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = outputs[0] + logits = self.classifier(sequence_output, deterministic=deterministic) + + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxSequenceClassifierOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) @add_start_docstrings( """ - XLM-RoBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the + XLM Roberta Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled output) e.g. for GLUE tasks. """, XLM_ROBERTA_START_DOCSTRING, ) -class FlaxXLMRobertaForSequenceClassification(FlaxRobertaForSequenceClassification): - """ - This class overrides [`FlaxRobertaForSequenceClassification`]. Please check the superclass for the appropriate - documentation alongside usage examples. - """ +class FlaxXLMRobertaForSequenceClassification(FlaxXLMRobertaPreTrainedModel): + module_class = FlaxXLMRobertaForSequenceClassificationModule - config_class = XLMRobertaConfig + +append_call_sample_docstring( + FlaxXLMRobertaForSequenceClassification, + _TOKENIZER_FOR_DOC, + _CHECKPOINT_FOR_DOC, + FlaxSequenceClassifierOutput, + _CONFIG_FOR_DOC, +) + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForMultipleChoiceModule with Bert->XLMRoberta, with self.bert->self.roberta +class FlaxXLMRobertaForMultipleChoiceModule(nn.Module): + config: XLMRobertaConfig + dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False + + def setup(self): + self.roberta = FlaxXLMRobertaModule( + config=self.config, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) + self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob) + self.classifier = nn.Dense(1, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + num_choices = input_ids.shape[1] + input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None + attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None + token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None + position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None + + # Model + outputs = self.roberta( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = outputs[1] + pooled_output = self.dropout(pooled_output, deterministic=deterministic) + logits = self.classifier(pooled_output) + + reshaped_logits = logits.reshape(-1, num_choices) + + if not return_dict: + return (reshaped_logits,) + outputs[2:] + + return FlaxMultipleChoiceModelOutput( + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) @add_start_docstrings( """ - XLM-RoBERTa Model with a multiple choice classification head on top (a linear layer on top of the pooled output and + XLM Roberta Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """, XLM_ROBERTA_START_DOCSTRING, ) -class FlaxXLMRobertaForMultipleChoice(FlaxRobertaForMultipleChoice): - """ - This class overrides [`FlaxRobertaForMultipleChoice`]. Please check the superclass for the appropriate - documentation alongside usage examples. - """ +class FlaxXLMRobertaForMultipleChoice(FlaxXLMRobertaPreTrainedModel): + module_class = FlaxXLMRobertaForMultipleChoiceModule - config_class = XLMRobertaConfig + +overwrite_call_docstring( + FlaxXLMRobertaForMultipleChoice, XLM_ROBERTA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") +) +append_call_sample_docstring( + FlaxXLMRobertaForMultipleChoice, + _TOKENIZER_FOR_DOC, + _CHECKPOINT_FOR_DOC, + FlaxMultipleChoiceModelOutput, + _CONFIG_FOR_DOC, +) + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForTokenClassificationModule with Bert->XLMRoberta, with self.bert->self.roberta +class FlaxXLMRobertaForTokenClassificationModule(nn.Module): + config: XLMRobertaConfig + dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False + + def setup(self): + self.roberta = FlaxXLMRobertaModule( + config=self.config, + dtype=self.dtype, + add_pooling_layer=False, + gradient_checkpointing=self.gradient_checkpointing, + ) + classifier_dropout = ( + self.config.classifier_dropout + if self.config.classifier_dropout is not None + else self.config.hidden_dropout_prob + ) + self.dropout = nn.Dropout(rate=classifier_dropout) + self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.roberta( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + hidden_states = self.dropout(hidden_states, deterministic=deterministic) + logits = self.classifier(hidden_states) + + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxTokenClassifierOutput( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) @add_start_docstrings( """ - XLM-RoBERTa Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. + XLM Roberta Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """, XLM_ROBERTA_START_DOCSTRING, ) -class FlaxXLMRobertaForTokenClassification(FlaxRobertaForTokenClassification): - """ - This class overrides [`FlaxRobertaForTokenClassification`]. Please check the superclass for the appropriate - documentation alongside usage examples. - """ +class FlaxXLMRobertaForTokenClassification(FlaxXLMRobertaPreTrainedModel): + module_class = FlaxXLMRobertaForTokenClassificationModule - config_class = XLMRobertaConfig + +append_call_sample_docstring( + FlaxXLMRobertaForTokenClassification, + _TOKENIZER_FOR_DOC, + _CHECKPOINT_FOR_DOC, + FlaxTokenClassifierOutput, + _CONFIG_FOR_DOC, +) + + +# Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForQuestionAnsweringModule with Bert->XLMRoberta, with self.bert->self.roberta +class FlaxXLMRobertaForQuestionAnsweringModule(nn.Module): + config: XLMRobertaConfig + dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False + + def setup(self): + self.roberta = FlaxXLMRobertaModule( + config=self.config, + dtype=self.dtype, + add_pooling_layer=False, + gradient_checkpointing=self.gradient_checkpointing, + ) + self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.roberta( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + + logits = self.qa_outputs(hidden_states) + start_logits, end_logits = logits.split(self.config.num_labels, axis=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + if not return_dict: + return (start_logits, end_logits) + outputs[1:] + + return FlaxQuestionAnsweringModelOutput( + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) @add_start_docstrings( """ - XLM-RoBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (a + XLM Roberta Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). """, XLM_ROBERTA_START_DOCSTRING, ) -class FlaxXLMRobertaForQuestionAnswering(FlaxRobertaForQuestionAnswering): - """ - This class overrides [`FlaxRobertaForQuestionAnswering`]. Please check the superclass for the appropriate - documentation alongside usage examples. - """ +class FlaxXLMRobertaForQuestionAnswering(FlaxXLMRobertaPreTrainedModel): + module_class = FlaxXLMRobertaForQuestionAnsweringModule - config_class = XLMRobertaConfig + +append_call_sample_docstring( + FlaxXLMRobertaForQuestionAnswering, + _TOKENIZER_FOR_DOC, + _CHECKPOINT_FOR_DOC, + FlaxQuestionAnsweringModelOutput, + _CONFIG_FOR_DOC, +) + + +# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaForCausalLMModule with Roberta->XLMRoberta +class FlaxXLMRobertaForCausalLMModule(nn.Module): + config: XLMRobertaConfig + dtype: jnp.dtype = jnp.float32 + gradient_checkpointing: bool = False + + def setup(self): + self.roberta = FlaxXLMRobertaModule( + config=self.config, + add_pooling_layer=False, + dtype=self.dtype, + gradient_checkpointing=self.gradient_checkpointing, + ) + self.lm_head = FlaxXLMRobertaLMHead(config=self.config, dtype=self.dtype) + + def __call__( + self, + input_ids, + attention_mask, + position_ids, + token_type_ids: Optional[jnp.ndarray] = None, + head_mask: Optional[jnp.ndarray] = None, + encoder_hidden_states: Optional[jnp.ndarray] = None, + encoder_attention_mask: Optional[jnp.ndarray] = None, + init_cache: bool = False, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + # Model + outputs = self.roberta( + input_ids, + attention_mask, + token_type_ids, + position_ids, + head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + init_cache=init_cache, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if self.config.tie_word_embeddings: + shared_embedding = self.roberta.variables["params"]["embeddings"]["word_embeddings"]["embedding"] + else: + shared_embedding = None + + # Compute the prediction scores + logits = self.lm_head(hidden_states, shared_embedding=shared_embedding) + + if not return_dict: + return (logits,) + outputs[1:] + + return FlaxCausalLMOutputWithCrossAttentions( + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + +@add_start_docstrings( + """ + XLM Roberta Model with a language modeling head on top (a linear layer on top of the hidden-states output) e.g for + autoregressive tasks. + """, + XLM_ROBERTA_START_DOCSTRING, +) +# Copied from transformers.models.roberta.modeling_flax_roberta.FlaxRobertaForCausalLM with Roberta->XLMRoberta +class FlaxXLMRobertaForCausalLM(FlaxXLMRobertaPreTrainedModel): + module_class = FlaxXLMRobertaForCausalLMModule + + def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None): + # initializing the cache + batch_size, seq_length = input_ids.shape + + past_key_values = self.init_cache(batch_size, max_length) + # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length. + # But since the decoder uses a causal mask, those positions are masked anyway. + # Thus, we can create a single static attention_mask here, which is more efficient for compilation + extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4") + if attention_mask is not None: + position_ids = attention_mask.cumsum(axis=-1) - 1 + extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0)) + else: + position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length)) + + return { + "past_key_values": past_key_values, + "attention_mask": extended_attention_mask, + "position_ids": position_ids, + } + + def update_inputs_for_generation(self, model_outputs, model_kwargs): + model_kwargs["past_key_values"] = model_outputs.past_key_values + model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1 + return model_kwargs + + +append_call_sample_docstring( + FlaxXLMRobertaForCausalLM, + _TOKENIZER_FOR_DOC, + _CHECKPOINT_FOR_DOC, + FlaxCausalLMOutputWithCrossAttentions, + _CONFIG_FOR_DOC, +) diff --git a/src/transformers/models/xlm_roberta/modeling_tf_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_tf_xlm_roberta.py index d7bdd92fc9..a780473650 100644 --- a/src/transformers/models/xlm_roberta/modeling_tf_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_tf_xlm_roberta.py @@ -15,26 +15,65 @@ # limitations under the License. """ TF 2.0 XLM-RoBERTa model.""" -from ...utils import add_start_docstrings, logging -from ..roberta.modeling_tf_roberta import ( - TFRobertaForCausalLM, - TFRobertaForMaskedLM, - TFRobertaForMultipleChoice, - TFRobertaForQuestionAnswering, - TFRobertaForSequenceClassification, - TFRobertaForTokenClassification, - TFRobertaModel, +import math +import warnings +from typing import Optional, Tuple, Union + +import numpy as np +import tensorflow as tf + +from ...activations_tf import get_tf_activation +from ...modeling_tf_outputs import ( + TFBaseModelOutputWithPastAndCrossAttentions, + TFBaseModelOutputWithPoolingAndCrossAttentions, + TFCausalLMOutputWithCrossAttentions, + TFMaskedLMOutput, + TFMultipleChoiceModelOutput, + TFQuestionAnsweringModelOutput, + TFSequenceClassifierOutput, + TFTokenClassifierOutput, +) +from ...modeling_tf_utils import ( + TFCausalLanguageModelingLoss, + TFMaskedLanguageModelingLoss, + TFModelInputType, + TFMultipleChoiceLoss, + TFPreTrainedModel, + TFQuestionAnsweringLoss, + TFSequenceClassificationLoss, + TFTokenClassificationLoss, + get_initializer, + keras_serializable, + unpack_inputs, +) +from ...tf_utils import shape_list, stable_softmax +from ...utils import ( + DUMMY_INPUTS, + MULTIPLE_CHOICE_DUMMY_INPUTS, + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + logging, ) from .configuration_xlm_roberta import XLMRobertaConfig logger = logging.get_logger(__name__) +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "xlm-roberta-base" +_CONFIG_FOR_DOC = "XLMRobertaConfig" +_TOKENIZER_FOR_DOC = "XLMRobertaTokenizer" + TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "xlm-roberta-base", + "xlm-roberta-large", + "joeddav/xlm-roberta-large-xnli", + "cardiffnlp/twitter-xlm-roberta-base-sentiment", # See all XLM-RoBERTa models at https://huggingface.co/models?filter=xlm-roberta ] - XLM_ROBERTA_START_DOCSTRING = r""" This model inherits from [`TFPreTrainedModel`]. Check the superclass documentation for the generic methods the @@ -77,105 +116,1606 @@ XLM_ROBERTA_START_DOCSTRING = r""" configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. """ +XLM_ROBERTA_INPUTS_DOCSTRING = r""" + Args: + input_ids (`Numpy array` or `tf.Tensor` of shape `({0})`): + Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`XLMRobertaTokenizer`]. + See [`PreTrainedTokenizer.__call__`] and [`PreTrainedTokenizer.encode`] for details. [What are input + IDs?](../glossary#input-ids) + attention_mask (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + token_type_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): + Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0, + 1]`: + - 0 corresponds to a *sentence A* token, + - 1 corresponds to a *sentence B* token. + [What are token type IDs?](../glossary#token-type-ids) + position_ids (`Numpy array` or `tf.Tensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. [What are position IDs?](../glossary#position-ids) + head_mask (`Numpy array` or `tf.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*): + Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`: + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + inputs_embeds (`tf.Tensor` of shape `({0}, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. This argument can be used only in eager mode, in graph mode the value in the + config will be used instead. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. This argument can be used only in eager mode, in graph mode the value in the config will be + used instead. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. This argument can be used in + eager mode, in graph mode the value will always be set to True. + training (`bool`, *optional*, defaults to `False`): + Whether or not to use the model in training mode (some modules like dropout modules have different + behaviors between training and evaluation). +""" -@add_start_docstrings( - "The bare XLM-RoBERTa Model transformer outputting raw hidden-states without any specific head on top.", - XLM_ROBERTA_START_DOCSTRING, -) -class TFXLMRobertaModel(TFRobertaModel): + +# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaEmbeddings with Roberta->XLMRoberta +class TFXLMRobertaEmbeddings(tf.keras.layers.Layer): """ - This class overrides [`TFRobertaModel`]. Please check the superclass for the appropriate documentation alongside - usage examples. + Same as BertEmbeddings with a tiny tweak for positional embeddings indexing. + """ + + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + + self.padding_idx = 1 + self.vocab_size = config.vocab_size + self.type_vocab_size = config.type_vocab_size + self.hidden_size = config.hidden_size + self.max_position_embeddings = config.max_position_embeddings + self.initializer_range = config.initializer_range + self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob) + + def build(self, input_shape: tf.TensorShape): + with tf.name_scope("word_embeddings"): + self.weight = self.add_weight( + name="weight", + shape=[self.vocab_size, self.hidden_size], + initializer=get_initializer(self.initializer_range), + ) + + with tf.name_scope("token_type_embeddings"): + self.token_type_embeddings = self.add_weight( + name="embeddings", + shape=[self.type_vocab_size, self.hidden_size], + initializer=get_initializer(self.initializer_range), + ) + + with tf.name_scope("position_embeddings"): + self.position_embeddings = self.add_weight( + name="embeddings", + shape=[self.max_position_embeddings, self.hidden_size], + initializer=get_initializer(self.initializer_range), + ) + + super().build(input_shape) + + def create_position_ids_from_input_ids(self, input_ids, past_key_values_length=0): + """ + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding + symbols are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + input_ids: tf.Tensor + Returns: tf.Tensor + """ + mask = tf.cast(tf.math.not_equal(input_ids, self.padding_idx), dtype=input_ids.dtype) + incremental_indices = (tf.math.cumsum(mask, axis=1) + past_key_values_length) * mask + + return incremental_indices + self.padding_idx + + def call( + self, + input_ids=None, + position_ids=None, + token_type_ids=None, + inputs_embeds=None, + past_key_values_length=0, + training=False, + ): + """ + Applies embedding based on inputs tensor. + + Returns: + final_embeddings (`tf.Tensor`): output embedding tensor. + """ + assert not (input_ids is None and inputs_embeds is None) + + if input_ids is not None: + # Note: tf.gather, on which the embedding layer is based, won't check positive out of bound + # indices on GPU, returning zeros instead. This is a dangerous silent behavior. + tf.debugging.assert_less( + input_ids, + tf.cast(self.vocab_size, dtype=input_ids.dtype), + message=( + "input_ids must be smaller than the embedding layer's input dimension (got" + f" {tf.math.reduce_max(input_ids)} >= {self.vocab_size})" + ), + ) + inputs_embeds = tf.gather(params=self.weight, indices=input_ids) + + input_shape = shape_list(inputs_embeds)[:-1] + + if token_type_ids is None: + token_type_ids = tf.fill(dims=input_shape, value=0) + + if position_ids is None: + if input_ids is not None: + # Create the position ids from the input token ids. Any padded tokens remain padded. + position_ids = self.create_position_ids_from_input_ids( + input_ids=input_ids, past_key_values_length=past_key_values_length + ) + else: + position_ids = tf.expand_dims( + tf.range(start=self.padding_idx + 1, limit=input_shape[-1] + self.padding_idx + 1), axis=0 + ) + + position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids) + token_type_embeds = tf.gather(params=self.token_type_embeddings, indices=token_type_ids) + final_embeddings = inputs_embeds + position_embeds + token_type_embeds + final_embeddings = self.LayerNorm(inputs=final_embeddings) + final_embeddings = self.dropout(inputs=final_embeddings, training=training) + + return final_embeddings + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->XLMRoberta +class TFXLMRobertaPooler(tf.keras.layers.Layer): + def __init__(self, config: XLMRobertaConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = tf.keras.layers.Dense( + units=config.hidden_size, + kernel_initializer=get_initializer(config.initializer_range), + activation="tanh", + name="dense", + ) + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[:, 0] + pooled_output = self.dense(inputs=first_token_tensor) + + return pooled_output + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfAttention with Bert->XLMRoberta +class TFXLMRobertaSelfAttention(tf.keras.layers.Layer): + def __init__(self, config: XLMRobertaConfig, **kwargs): + super().__init__(**kwargs) + + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number " + f"of attention heads ({config.num_attention_heads})" + ) + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.sqrt_att_head_size = math.sqrt(self.attention_head_size) + + self.query = tf.keras.layers.Dense( + units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query" + ) + self.key = tf.keras.layers.Dense( + units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="key" + ) + self.value = tf.keras.layers.Dense( + units=self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value" + ) + self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob) + + self.is_decoder = config.is_decoder + + def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor: + # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] + tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size)) + + # Transpose the tensor from [batch_size, seq_length, num_attention_heads, attention_head_size] to [batch_size, num_attention_heads, seq_length, attention_head_size] + return tf.transpose(tensor, perm=[0, 2, 1, 3]) + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor, + head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor, + encoder_attention_mask: tf.Tensor, + past_key_value: Tuple[tf.Tensor], + output_attentions: bool, + training: bool = False, + ) -> Tuple[tf.Tensor]: + batch_size = shape_list(hidden_states)[0] + mixed_query_layer = self.query(inputs=hidden_states) + + # If this is instantiated as a cross-attention module, the keys + # and values come from an encoder; the attention mask needs to be + # such that the encoder's padding tokens are not attended to. + is_cross_attention = encoder_hidden_states is not None + + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_layer = past_key_value[0] + value_layer = past_key_value[1] + attention_mask = encoder_attention_mask + elif is_cross_attention: + key_layer = self.transpose_for_scores(self.key(inputs=encoder_hidden_states), batch_size) + value_layer = self.transpose_for_scores(self.value(inputs=encoder_hidden_states), batch_size) + attention_mask = encoder_attention_mask + elif past_key_value is not None: + key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) + value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) + key_layer = tf.concat([past_key_value[0], key_layer], axis=2) + value_layer = tf.concat([past_key_value[1], value_layer], axis=2) + else: + key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size) + value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size) + + query_layer = self.transpose_for_scores(mixed_query_layer, batch_size) + + if self.is_decoder: + # if cross_attention save Tuple(tf.Tensor, tf.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(tf.Tensor, tf.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_layer, value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + # (batch size, num_heads, seq_len_q, seq_len_k) + attention_scores = tf.matmul(query_layer, key_layer, transpose_b=True) + dk = tf.cast(self.sqrt_att_head_size, dtype=attention_scores.dtype) + attention_scores = tf.divide(attention_scores, dk) + + if attention_mask is not None: + # Apply the attention mask is (precomputed for all layers in TFXLMRobertaModel call() function) + attention_scores = tf.add(attention_scores, attention_mask) + + # Normalize the attention scores to probabilities. + attention_probs = stable_softmax(logits=attention_scores, axis=-1) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(inputs=attention_probs, training=training) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = tf.multiply(attention_probs, head_mask) + + attention_output = tf.matmul(attention_probs, value_layer) + attention_output = tf.transpose(attention_output, perm=[0, 2, 1, 3]) + + # (batch_size, seq_len_q, all_head_size) + attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size)) + outputs = (attention_output, attention_probs) if output_attentions else (attention_output,) + + if self.is_decoder: + outputs = outputs + (past_key_value,) + return outputs + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertSelfOutput with Bert->XLMRoberta +class TFXLMRobertaSelfOutput(tf.keras.layers.Layer): + def __init__(self, config: XLMRobertaConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = tf.keras.layers.Dense( + units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob) + + def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.dropout(inputs=hidden_states, training=training) + hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) + + return hidden_states + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertAttention with Bert->XLMRoberta +class TFXLMRobertaAttention(tf.keras.layers.Layer): + def __init__(self, config: XLMRobertaConfig, **kwargs): + super().__init__(**kwargs) + + self.self_attention = TFXLMRobertaSelfAttention(config, name="self") + self.dense_output = TFXLMRobertaSelfOutput(config, name="output") + + def prune_heads(self, heads): + raise NotImplementedError + + def call( + self, + input_tensor: tf.Tensor, + attention_mask: tf.Tensor, + head_mask: tf.Tensor, + encoder_hidden_states: tf.Tensor, + encoder_attention_mask: tf.Tensor, + past_key_value: Tuple[tf.Tensor], + output_attentions: bool, + training: bool = False, + ) -> Tuple[tf.Tensor]: + self_outputs = self.self_attention( + hidden_states=input_tensor, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + training=training, + ) + attention_output = self.dense_output( + hidden_states=self_outputs[0], input_tensor=input_tensor, training=training + ) + # add attentions (possibly with past_key_value) if we output them + outputs = (attention_output,) + self_outputs[1:] + + return outputs + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->XLMRoberta +class TFXLMRobertaIntermediate(tf.keras.layers.Layer): + def __init__(self, config: XLMRobertaConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = tf.keras.layers.Dense( + units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = get_tf_activation(config.hidden_act) + else: + self.intermediate_act_fn = config.hidden_act + + def call(self, hidden_states: tf.Tensor) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + + return hidden_states + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->XLMRoberta +class TFXLMRobertaOutput(tf.keras.layers.Layer): + def __init__(self, config: XLMRobertaConfig, **kwargs): + super().__init__(**kwargs) + + self.dense = tf.keras.layers.Dense( + units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm") + self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob) + + def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor: + hidden_states = self.dense(inputs=hidden_states) + hidden_states = self.dropout(inputs=hidden_states, training=training) + hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor) + + return hidden_states + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertLayer with Bert->XLMRoberta +class TFXLMRobertaLayer(tf.keras.layers.Layer): + def __init__(self, config: XLMRobertaConfig, **kwargs): + super().__init__(**kwargs) + + self.attention = TFXLMRobertaAttention(config, name="attention") + self.is_decoder = config.is_decoder + self.add_cross_attention = config.add_cross_attention + if self.add_cross_attention: + if not self.is_decoder: + raise ValueError(f"{self} should be used as a decoder model if cross attention is added") + self.crossattention = TFXLMRobertaAttention(config, name="crossattention") + self.intermediate = TFXLMRobertaIntermediate(config, name="intermediate") + self.bert_output = TFXLMRobertaOutput(config, name="output") + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor, + head_mask: tf.Tensor, + encoder_hidden_states: Optional[tf.Tensor], + encoder_attention_mask: Optional[tf.Tensor], + past_key_value: Optional[Tuple[tf.Tensor]], + output_attentions: bool, + training: bool = False, + ) -> Tuple[tf.Tensor]: + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attention_outputs = self.attention( + input_tensor=hidden_states, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=None, + encoder_attention_mask=None, + past_key_value=self_attn_past_key_value, + output_attentions=output_attentions, + training=training, + ) + attention_output = self_attention_outputs[0] + + # if decoder, the last output is tuple of self-attn cache + if self.is_decoder: + outputs = self_attention_outputs[1:-1] + present_key_value = self_attention_outputs[-1] + else: + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + cross_attn_present_key_value = None + if self.is_decoder and encoder_hidden_states is not None: + if not hasattr(self, "crossattention"): + raise ValueError( + f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers" + " by setting `config.add_cross_attention=True`" + ) + + # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attention_outputs = self.crossattention( + input_tensor=attention_output, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + training=training, + ) + attention_output = cross_attention_outputs[0] + outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights + + # add cross-attn cache to positions 3,4 of present_key_value tuple + cross_attn_present_key_value = cross_attention_outputs[-1] + present_key_value = present_key_value + cross_attn_present_key_value + + intermediate_output = self.intermediate(hidden_states=attention_output) + layer_output = self.bert_output( + hidden_states=intermediate_output, input_tensor=attention_output, training=training + ) + outputs = (layer_output,) + outputs # add attentions if we output them + + # if decoder, return the attn key/values as the last output + if self.is_decoder: + outputs = outputs + (present_key_value,) + + return outputs + + +# Copied from transformers.models.bert.modeling_tf_bert.TFBertEncoder with Bert->XLMRoberta +class TFXLMRobertaEncoder(tf.keras.layers.Layer): + def __init__(self, config: XLMRobertaConfig, **kwargs): + super().__init__(**kwargs) + self.config = config + self.layer = [TFXLMRobertaLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)] + + def call( + self, + hidden_states: tf.Tensor, + attention_mask: tf.Tensor, + head_mask: tf.Tensor, + encoder_hidden_states: Optional[tf.Tensor], + encoder_attention_mask: Optional[tf.Tensor], + past_key_values: Optional[Tuple[Tuple[tf.Tensor]]], + use_cache: Optional[bool], + output_attentions: bool, + output_hidden_states: bool, + return_dict: bool, + training: bool = False, + ) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]: + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None + + next_decoder_cache = () if use_cache else None + for i, layer_module in enumerate(self.layer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + past_key_value = past_key_values[i] if past_key_values is not None else None + + layer_outputs = layer_module( + hidden_states=hidden_states, + attention_mask=attention_mask, + head_mask=head_mask[i], + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + training=training, + ) + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[-1],) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + if self.config.add_cross_attention and encoder_hidden_states is not None: + all_cross_attentions = all_cross_attentions + (layer_outputs[2],) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None + ) + + return TFBaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_decoder_cache, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + +@keras_serializable +# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaMainLayer with Roberta->XLMRoberta +class TFXLMRobertaMainLayer(tf.keras.layers.Layer): + config_class = XLMRobertaConfig + + def __init__(self, config, add_pooling_layer=True, **kwargs): + super().__init__(**kwargs) + + self.config = config + self.is_decoder = config.is_decoder + + self.num_hidden_layers = config.num_hidden_layers + self.initializer_range = config.initializer_range + self.output_attentions = config.output_attentions + self.output_hidden_states = config.output_hidden_states + self.return_dict = config.use_return_dict + self.encoder = TFXLMRobertaEncoder(config, name="encoder") + self.pooler = TFXLMRobertaPooler(config, name="pooler") if add_pooling_layer else None + # The embeddings must be the last declaration in order to follow the weights order + self.embeddings = TFXLMRobertaEmbeddings(config, name="embeddings") + + # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.get_input_embeddings + def get_input_embeddings(self) -> tf.keras.layers.Layer: + return self.embeddings + + # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.set_input_embeddings + def set_input_embeddings(self, value: tf.Variable): + self.embeddings.weight = value + self.embeddings.vocab_size = shape_list(value)[0] + + # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer._prune_heads + def _prune_heads(self, heads_to_prune): + """ + Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base + class PreTrainedModel + """ + raise NotImplementedError + + @unpack_inputs + # Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.call + def call( + self, + input_ids: Optional[TFModelInputType] = None, + attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, + position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, + head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, + encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None, + encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: bool = False, + ) -> Union[TFBaseModelOutputWithPoolingAndCrossAttentions, Tuple[tf.Tensor]]: + + if not self.config.is_decoder: + use_cache = False + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = shape_list(input_ids) + elif inputs_embeds is not None: + input_shape = shape_list(inputs_embeds)[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + batch_size, seq_length = input_shape + + if past_key_values is None: + past_key_values_length = 0 + past_key_values = [None] * len(self.encoder.layer) + else: + past_key_values_length = shape_list(past_key_values[0][0])[-2] + + if attention_mask is None: + attention_mask = tf.fill(dims=(batch_size, seq_length + past_key_values_length), value=1) + + if token_type_ids is None: + token_type_ids = tf.fill(dims=input_shape, value=0) + + embedding_output = self.embeddings( + input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds, + past_key_values_length=past_key_values_length, + training=training, + ) + + # We create a 3D attention mask from a 2D tensor mask. + # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask_shape = shape_list(attention_mask) + + mask_seq_length = seq_length + past_key_values_length + # Copied from `modeling_tf_t5.py` + # Provided a padding mask of dimensions [batch_size, mask_seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] + if self.is_decoder: + seq_ids = tf.range(mask_seq_length) + causal_mask = tf.less_equal( + tf.tile(seq_ids[None, None, :], (batch_size, mask_seq_length, 1)), + seq_ids[None, :, None], + ) + causal_mask = tf.cast(causal_mask, dtype=attention_mask.dtype) + extended_attention_mask = causal_mask * attention_mask[:, None, :] + attention_mask_shape = shape_list(extended_attention_mask) + extended_attention_mask = tf.reshape( + extended_attention_mask, (attention_mask_shape[0], 1, attention_mask_shape[1], attention_mask_shape[2]) + ) + if past_key_values[0] is not None: + # attention_mask needs to be sliced to the shape `[batch_size, 1, from_seq_length - cached_seq_length, to_seq_length] + extended_attention_mask = extended_attention_mask[:, :, -seq_length:, :] + else: + extended_attention_mask = tf.reshape( + attention_mask, (attention_mask_shape[0], 1, 1, attention_mask_shape[1]) + ) + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = tf.cast(extended_attention_mask, dtype=embedding_output.dtype) + one_cst = tf.constant(1.0, dtype=embedding_output.dtype) + ten_thousand_cst = tf.constant(-10000.0, dtype=embedding_output.dtype) + extended_attention_mask = tf.multiply(tf.subtract(one_cst, extended_attention_mask), ten_thousand_cst) + + # Copied from `modeling_tf_t5.py` with -1e9 -> -10000 + if self.is_decoder and encoder_attention_mask is not None: + # If a 2D ou 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, mask_seq_length, mask_seq_length] + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=extended_attention_mask.dtype) + num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask)) + if num_dims_encoder_attention_mask == 3: + encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] + if num_dims_encoder_attention_mask == 2: + encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] + + # T5 has a mask that can compare sequence ids, we can simulate this here with this transposition + # Cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270 + # encoder_extended_attention_mask = tf.math.equal(encoder_extended_attention_mask, + # tf.transpose(encoder_extended_attention_mask, perm=(-1, -2))) + + encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0 + else: + encoder_extended_attention_mask = None + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + if head_mask is not None: + raise NotImplementedError + else: + head_mask = [None] * self.config.num_hidden_layers + + encoder_outputs = self.encoder( + hidden_states=embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + sequence_output = encoder_outputs[0] + pooled_output = self.pooler(hidden_states=sequence_output) if self.pooler is not None else None + + if not return_dict: + return ( + sequence_output, + pooled_output, + ) + encoder_outputs[1:] + + return TFBaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=sequence_output, + pooler_output=pooled_output, + past_key_values=encoder_outputs.past_key_values, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + cross_attentions=encoder_outputs.cross_attentions, + ) + + +# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaPreTrainedModel with Roberta->XLMRoberta +class TFXLMRobertaPreTrainedModel(TFPreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. """ config_class = XLMRobertaConfig + base_model_prefix = "roberta" + + @property + # Copied from transformers.models.bert.modeling_tf_bert.TFBertPreTrainedModel.dummy_inputs + def dummy_inputs(self): + """ + Dummy inputs to build the network. + + Returns: + `Dict[str, tf.Tensor]`: The dummy inputs. + """ + dummy = {"input_ids": tf.constant(DUMMY_INPUTS, dtype=tf.int32)} + # Add `encoder_hidden_states` to make the cross-attention layers' weights initialized + if self.config.add_cross_attention: + batch_size, seq_len = tf.constant(DUMMY_INPUTS).shape + shape = (batch_size, seq_len) + (self.config.hidden_size,) + h = tf.random.uniform(shape=shape) + dummy["encoder_hidden_states"] = h + + return dummy + + @tf.function( + input_signature=[ + { + "input_ids": tf.TensorSpec((None, None), tf.int32, name="input_ids"), + "attention_mask": tf.TensorSpec((None, None), tf.int32, name="attention_mask"), + } + ] + ) + def serving(self, inputs): + output = self.call(inputs) + + return self.serving_output(output) + + +@add_start_docstrings( + "The bare XLM RoBERTa Model transformer outputting raw hidden-states without any specific head on top.", + XLM_ROBERTA_START_DOCSTRING, +) +# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaModel with Roberta->XLMRoberta, ROBERTA->XLM_ROBERTA +class TFXLMRobertaModel(TFXLMRobertaPreTrainedModel): + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.roberta = TFXLMRobertaMainLayer(config, name="roberta") + + @unpack_inputs + @add_start_docstrings_to_model_forward(XLM_ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFBaseModelOutputWithPoolingAndCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: Optional[TFModelInputType] = None, + attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, + position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, + head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, + encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None, + encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + training: Optional[bool] = False, + ) -> Union[Tuple, TFBaseModelOutputWithPoolingAndCrossAttentions]: + r""" + encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers`) + contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*, defaults to `True`): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). Set to `False` during training, `True` during generation + """ + outputs = self.roberta( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + return outputs + + # Copied from transformers.models.bert.modeling_tf_bert.TFBertModel.serving_output + def serving_output( + self, output: TFBaseModelOutputWithPoolingAndCrossAttentions + ) -> TFBaseModelOutputWithPoolingAndCrossAttentions: + output_cache = self.config.use_cache and self.config.is_decoder + pkv = tf.convert_to_tensor(output.past_key_values) if output_cache else None + hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None + attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if output.cross_attentions is not None else None + if not (self.config.output_attentions and self.config.add_cross_attention): + cross_attns = None + + return TFBaseModelOutputWithPoolingAndCrossAttentions( + last_hidden_state=output.last_hidden_state, + pooler_output=output.pooler_output, + past_key_values=pkv, + hidden_states=hs, + attentions=attns, + cross_attentions=cross_attns, + ) + + +# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaLMHead with Roberta->XLMRoberta +class TFXLMRobertaLMHead(tf.keras.layers.Layer): + """XLMRoberta Head for masked language modeling.""" + + def __init__(self, config, input_embeddings, **kwargs): + super().__init__(**kwargs) + + self.vocab_size = config.vocab_size + self.hidden_size = config.hidden_size + self.dense = tf.keras.layers.Dense( + config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense" + ) + self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm") + self.act = get_tf_activation("gelu") + + # The output weights are the same as the input embeddings, but there is + # an output-only bias for each token. + self.decoder = input_embeddings + + def build(self, input_shape): + self.bias = self.add_weight(shape=(self.vocab_size,), initializer="zeros", trainable=True, name="bias") + + super().build(input_shape) + + def get_output_embeddings(self): + return self.decoder + + def set_output_embeddings(self, value): + self.decoder.weight = value + self.decoder.vocab_size = shape_list(value)[0] + + def get_bias(self): + return {"bias": self.bias} + + def set_bias(self, value): + self.bias = value["bias"] + self.vocab_size = shape_list(value["bias"])[0] + + def call(self, hidden_states): + hidden_states = self.dense(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.layer_norm(hidden_states) + + # project back to size of vocabulary with bias + seq_length = shape_list(tensor=hidden_states)[1] + hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.hidden_size]) + hidden_states = tf.matmul(a=hidden_states, b=self.decoder.weight, transpose_b=True) + hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.vocab_size]) + hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias) + + return hidden_states + + +@add_start_docstrings("""XLM RoBERTa Model with a `language modeling` head on top.""", XLM_ROBERTA_START_DOCSTRING) +# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaForMaskedLM with Roberta->XLMRoberta, ROBERTA->XLM_ROBERTA +class TFXLMRobertaForMaskedLM(TFXLMRobertaPreTrainedModel, TFMaskedLanguageModelingLoss): + # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model + _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head.decoder.weight"] + + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.roberta = TFXLMRobertaMainLayer(config, add_pooling_layer=False, name="roberta") + self.lm_head = TFXLMRobertaLMHead(config, self.roberta.embeddings, name="lm_head") + + def get_lm_head(self): + return self.lm_head + + def get_prefix_bias_name(self): + warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) + return self.name + "/" + self.lm_head.name + + @unpack_inputs + @add_start_docstrings_to_model_forward(XLM_ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFMaskedLMOutput, + config_class=_CONFIG_FOR_DOC, + mask="", + expected_output="' Paris'", + expected_loss=0.1, + ) + def call( + self, + input_ids: Optional[TFModelInputType] = None, + attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, + position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, + head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[Union[np.ndarray, tf.Tensor]] = None, + training: Optional[bool] = False, + ) -> Union[TFMaskedLMOutput, Tuple[tf.Tensor]]: + r""" + labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., + config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the + loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` + """ + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + sequence_output = outputs[0] + prediction_scores = self.lm_head(sequence_output) + + loss = None if labels is None else self.hf_compute_loss(labels, prediction_scores) + + if not return_dict: + output = (prediction_scores,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFMaskedLMOutput( + loss=loss, + logits=prediction_scores, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + # Copied from transformers.models.bert.modeling_tf_bert.TFBertForMaskedLM.serving_output + def serving_output(self, output: TFMaskedLMOutput) -> TFMaskedLMOutput: + hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None + attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None + + return TFMaskedLMOutput(logits=output.logits, hidden_states=hs, attentions=attns) @add_start_docstrings( "XLM-RoBERTa Model with a `language modeling` head on top for CLM fine-tuning.", XLM_ROBERTA_START_DOCSTRING, ) -class XLMRobertaForCausalLM(TFRobertaForCausalLM): - """ - This class overrides [`TFRobertaForCausalLM`]. Please check the superclass for the appropriate documentation - alongside usage examples. - """ +# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaForCausalLM with Roberta->XLMRoberta, ROBERTA->XLM_ROBERTA +class TFXLMRobertaForCausalLM(TFXLMRobertaPreTrainedModel, TFCausalLanguageModelingLoss): + # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model + _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head.decoder.weight"] - config_class = XLMRobertaConfig + def __init__(self, config: XLMRobertaConfig, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + if not config.is_decoder: + logger.warning("If you want to use `TFXLMRobertaLMHeadModel` as a standalone, add `is_decoder=True.`") + + self.roberta = TFXLMRobertaMainLayer(config, add_pooling_layer=False, name="roberta") + self.lm_head = TFXLMRobertaLMHead(config, input_embeddings=self.roberta.embeddings, name="lm_head") + + def get_lm_head(self): + return self.lm_head + + def get_prefix_bias_name(self): + warnings.warn("The method get_prefix_bias_name is deprecated. Please use `get_bias` instead.", FutureWarning) + return self.name + "/" + self.lm_head.name + + # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.prepare_inputs_for_generation + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs): + input_shape = input_ids.shape + # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly + if attention_mask is None: + attention_mask = tf.ones(input_shape) + + # cut decoder_input_ids if past is used + if past_key_values is not None: + input_ids = input_ids[:, -1:] + + return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} + + @unpack_inputs + @add_start_docstrings_to_model_forward(XLM_ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFCausalLMOutputWithCrossAttentions, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: Optional[TFModelInputType] = None, + attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, + position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, + head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, + encoder_hidden_states: Optional[Union[np.ndarray, tf.Tensor]] = None, + encoder_attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + past_key_values: Optional[Tuple[Tuple[Union[np.ndarray, tf.Tensor]]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[Union[np.ndarray, tf.Tensor]] = None, + training: Optional[bool] = False, + ) -> Union[TFCausalLMOutputWithCrossAttentions, Tuple[tf.Tensor]]: + r""" + encoder_hidden_states (`tf.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if + the model is configured as a decoder. + encoder_attention_mask (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in + the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + past_key_values (`Tuple[Tuple[tf.Tensor]]` of length `config.n_layers`) + contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + use_cache (`bool`, *optional*, defaults to `True`): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). Set to `False` during training, `True` during generation + labels (`tf.Tensor` or `np.ndarray` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the cross entropy classification loss. Indices should be in `[0, ..., + config.vocab_size - 1]`. + """ + outputs = self.roberta( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + + sequence_output = outputs[0] + logits = self.lm_head(hidden_states=sequence_output, training=training) + loss = None + + if labels is not None: + # shift labels to the left and cut last logit token + shifted_logits = logits[:, :-1] + labels = labels[:, 1:] + loss = self.hf_compute_loss(labels=labels, logits=shifted_logits) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFCausalLMOutputWithCrossAttentions( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + # Copied from transformers.models.bert.modeling_tf_bert.TFBertLMHeadModel.serving_output + def serving_output(self, output: TFCausalLMOutputWithCrossAttentions) -> TFCausalLMOutputWithCrossAttentions: + output_cache = self.config.use_cache and self.config.is_decoder + pkv = tf.convert_to_tensor(output.past_key_values) if output_cache else None + hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None + attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None + cross_attns = tf.convert_to_tensor(output.cross_attentions) if output.cross_attentions is not None else None + if not (self.config.output_attentions and self.config.add_cross_attention): + cross_attns = None + + return TFCausalLMOutputWithCrossAttentions( + logits=output.logits, past_key_values=pkv, hidden_states=hs, attentions=attns, cross_attentions=cross_attns + ) -@add_start_docstrings( - """XLM-RoBERTa Model with a `language modeling` head on top.""", - XLM_ROBERTA_START_DOCSTRING, -) -class TFXLMRobertaForMaskedLM(TFRobertaForMaskedLM): - """ - This class overrides [`TFRobertaForMaskedLM`]. Please check the superclass for the appropriate documentation - alongside usage examples. - """ +# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaClassificationHead with Roberta->XLMRoberta +class TFXLMRobertaClassificationHead(tf.keras.layers.Layer): + """Head for sentence-level classification tasks.""" - config_class = XLMRobertaConfig + def __init__(self, config, **kwargs): + super().__init__(**kwargs) + self.dense = tf.keras.layers.Dense( + config.hidden_size, + kernel_initializer=get_initializer(config.initializer_range), + activation="tanh", + name="dense", + ) + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = tf.keras.layers.Dropout(classifier_dropout) + self.out_proj = tf.keras.layers.Dense( + config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="out_proj" + ) + + def call(self, features, training=False): + x = features[:, 0, :] # take token (equiv. to [CLS]) + x = self.dropout(x, training=training) + x = self.dense(x) + x = self.dropout(x, training=training) + x = self.out_proj(x) + return x @add_start_docstrings( """ - XLM-RoBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the + XLM RoBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled output) e.g. for GLUE tasks. """, XLM_ROBERTA_START_DOCSTRING, ) -class TFXLMRobertaForSequenceClassification(TFRobertaForSequenceClassification): - """ - This class overrides [`TFRobertaForSequenceClassification`]. Please check the superclass for the appropriate - documentation alongside usage examples. - """ +# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaForSequenceClassification with Roberta->XLMRoberta, ROBERTA->XLM_ROBERTA +class TFXLMRobertaForSequenceClassification(TFXLMRobertaPreTrainedModel, TFSequenceClassificationLoss): + # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model + _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head"] - config_class = XLMRobertaConfig + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.num_labels = config.num_labels + + self.roberta = TFXLMRobertaMainLayer(config, add_pooling_layer=False, name="roberta") + self.classifier = TFXLMRobertaClassificationHead(config, name="classifier") + + @unpack_inputs + @add_start_docstrings_to_model_forward(XLM_ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint="cardiffnlp/twitter-roberta-base-emotion", + output_type=TFSequenceClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output="'optimism'", + expected_loss=0.08, + ) + def call( + self, + input_ids: Optional[TFModelInputType] = None, + attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, + position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, + head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[Union[np.ndarray, tf.Tensor]] = None, + training: Optional[bool] = False, + ) -> Union[TFSequenceClassifierOutput, Tuple[tf.Tensor]]: + r""" + labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + logits = self.classifier(sequence_output, training=training) + + loss = None if labels is None else self.hf_compute_loss(labels, logits) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFSequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + # Copied from transformers.models.bert.modeling_tf_bert.TFBertForSequenceClassification.serving_output + def serving_output(self, output: TFSequenceClassifierOutput) -> TFSequenceClassifierOutput: + hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None + attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None + + return TFSequenceClassifierOutput(logits=output.logits, hidden_states=hs, attentions=attns) @add_start_docstrings( """ - XLM-RoBERTa Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. + XLM Roberta Model with a multiple choice classification head on top (a linear layer on top of the pooled output and + a softmax) e.g. for RocStories/SWAG tasks. + """, + XLM_ROBERTA_START_DOCSTRING, +) +# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaForMultipleChoice with Roberta->XLMRoberta, ROBERTA->XLM_ROBERTA +class TFXLMRobertaForMultipleChoice(TFXLMRobertaPreTrainedModel, TFMultipleChoiceLoss): + # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model + _keys_to_ignore_on_load_unexpected = [r"lm_head"] + _keys_to_ignore_on_load_missing = [r"dropout"] + + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + + self.roberta = TFXLMRobertaMainLayer(config, name="roberta") + self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob) + self.classifier = tf.keras.layers.Dense( + 1, kernel_initializer=get_initializer(config.initializer_range), name="classifier" + ) + + @property + def dummy_inputs(self): + """ + Dummy inputs to build the network. + + Returns: + tf.Tensor with dummy inputs + """ + return {"input_ids": tf.constant(MULTIPLE_CHOICE_DUMMY_INPUTS, dtype=tf.int32)} + + @unpack_inputs + @add_start_docstrings_to_model_forward( + XLM_ROBERTA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length") + ) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TFMultipleChoiceModelOutput, + config_class=_CONFIG_FOR_DOC, + ) + def call( + self, + input_ids: Optional[TFModelInputType] = None, + attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, + position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, + head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[Union[np.ndarray, tf.Tensor]] = None, + training: Optional[bool] = False, + ) -> Union[TFMultipleChoiceModelOutput, Tuple[tf.Tensor]]: + r""" + labels (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for computing the multiple choice classification loss. Indices should be in `[0, ..., num_choices]` + where `num_choices` is the size of the second dimension of the input tensors. (See `input_ids` above) + """ + + if input_ids is not None: + num_choices = shape_list(input_ids)[1] + seq_length = shape_list(input_ids)[2] + else: + num_choices = shape_list(inputs_embeds)[1] + seq_length = shape_list(inputs_embeds)[2] + + flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None + flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None + flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None + flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None + outputs = self.roberta( + flat_input_ids, + flat_attention_mask, + flat_token_type_ids, + flat_position_ids, + head_mask, + inputs_embeds, + output_attentions, + output_hidden_states, + return_dict=return_dict, + training=training, + ) + pooled_output = outputs[1] + pooled_output = self.dropout(pooled_output, training=training) + logits = self.classifier(pooled_output) + reshaped_logits = tf.reshape(logits, (-1, num_choices)) + + loss = None if labels is None else self.hf_compute_loss(labels, reshaped_logits) + + if not return_dict: + output = (reshaped_logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFMultipleChoiceModelOutput( + loss=loss, + logits=reshaped_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + @tf.function( + input_signature=[ + { + "input_ids": tf.TensorSpec((None, None, None), tf.int32, name="input_ids"), + "attention_mask": tf.TensorSpec((None, None, None), tf.int32, name="attention_mask"), + } + ] + ) + def serving(self, inputs): + output = self.call(inputs) + + return self.serving_output(output) + + # Copied from transformers.models.bert.modeling_tf_bert.TFBertForMultipleChoice.serving_output + def serving_output(self, output: TFMultipleChoiceModelOutput) -> TFMultipleChoiceModelOutput: + hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None + attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None + + return TFMultipleChoiceModelOutput(logits=output.logits, hidden_states=hs, attentions=attns) + + +@add_start_docstrings( + """ + XLM RoBERTa Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """, XLM_ROBERTA_START_DOCSTRING, ) -class TFXLMRobertaForTokenClassification(TFRobertaForTokenClassification): - """ - This class overrides [`TFRobertaForTokenClassification`]. Please check the superclass for the appropriate - documentation alongside usage examples. - """ +# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaForTokenClassification with Roberta->XLMRoberta, ROBERTA->XLM_ROBERTA +class TFXLMRobertaForTokenClassification(TFXLMRobertaPreTrainedModel, TFTokenClassificationLoss): + # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model + _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head"] + _keys_to_ignore_on_load_missing = [r"dropout"] - config_class = XLMRobertaConfig + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.num_labels = config.num_labels + + self.roberta = TFXLMRobertaMainLayer(config, add_pooling_layer=False, name="roberta") + classifier_dropout = ( + config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob + ) + self.dropout = tf.keras.layers.Dropout(classifier_dropout) + self.classifier = tf.keras.layers.Dense( + config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="classifier" + ) + + @unpack_inputs + @add_start_docstrings_to_model_forward(XLM_ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint="ydshieh/roberta-large-ner-english", + output_type=TFTokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + expected_output="['O', 'ORG', 'ORG', 'O', 'O', 'O', 'O', 'O', 'LOC', 'O', 'LOC', 'LOC']", + expected_loss=0.01, + ) + def call( + self, + input_ids: Optional[TFModelInputType] = None, + attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, + position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, + head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + labels: Optional[Union[np.ndarray, tf.Tensor]] = None, + training: Optional[bool] = False, + ) -> Union[TFTokenClassifierOutput, Tuple[tf.Tensor]]: + r""" + labels (`tf.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`. + """ + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + + sequence_output = self.dropout(sequence_output, training=training) + logits = self.classifier(sequence_output) + + loss = None if labels is None else self.hf_compute_loss(labels, logits) + + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFTokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + # Copied from transformers.models.bert.modeling_tf_bert.TFBertForTokenClassification.serving_output + def serving_output(self, output: TFTokenClassifierOutput) -> TFTokenClassifierOutput: + hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None + attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None + + return TFTokenClassifierOutput(logits=output.logits, hidden_states=hs, attentions=attns) @add_start_docstrings( """ -XLM-RoBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear -layers on top of the hidden-states output to compute `span start logits` and `span end logits`). -""", - XLM_ROBERTA_START_DOCSTRING, -) -class TFXLMRobertaForQuestionAnswering(TFRobertaForQuestionAnswering): - """ - This class overrides [`TFRobertaForQuestionAnsweringSimple`]. Please check the superclass for the appropriate - documentation alongside usage examples. - """ - - config_class = XLMRobertaConfig - - -@add_start_docstrings( - """ - Roberta Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a - softmax) e.g. for RocStories/SWAG tasks. + XLM RoBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD (a + linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`). """, XLM_ROBERTA_START_DOCSTRING, ) -class TFXLMRobertaForMultipleChoice(TFRobertaForMultipleChoice): - """ - This class overrides [`TFRobertaForMultipleChoice`]. Please check the superclass for the appropriate documentation - alongside usage examples. - """ +# Copied from transformers.models.roberta.modeling_tf_roberta.TFRobertaForQuestionAnswering with Roberta->XLMRoberta, ROBERTA->XLM_ROBERTA +class TFXLMRobertaForQuestionAnswering(TFXLMRobertaPreTrainedModel, TFQuestionAnsweringLoss): + # names with a '.' represents the authorized unexpected/missing layers when a TF model is loaded from a PT model + _keys_to_ignore_on_load_unexpected = [r"pooler", r"lm_head"] - config_class = XLMRobertaConfig + def __init__(self, config, *inputs, **kwargs): + super().__init__(config, *inputs, **kwargs) + self.num_labels = config.num_labels + + self.roberta = TFXLMRobertaMainLayer(config, add_pooling_layer=False, name="roberta") + self.qa_outputs = tf.keras.layers.Dense( + config.num_labels, kernel_initializer=get_initializer(config.initializer_range), name="qa_outputs" + ) + + @unpack_inputs + @add_start_docstrings_to_model_forward(XLM_ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @add_code_sample_docstrings( + processor_class=_TOKENIZER_FOR_DOC, + checkpoint="ydshieh/roberta-base-squad2", + output_type=TFQuestionAnsweringModelOutput, + config_class=_CONFIG_FOR_DOC, + expected_output="' puppet'", + expected_loss=0.86, + ) + def call( + self, + input_ids: Optional[TFModelInputType] = None, + attention_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + token_type_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, + position_ids: Optional[Union[np.ndarray, tf.Tensor]] = None, + head_mask: Optional[Union[np.ndarray, tf.Tensor]] = None, + inputs_embeds: Optional[Union[np.ndarray, tf.Tensor]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + start_positions: Optional[Union[np.ndarray, tf.Tensor]] = None, + end_positions: Optional[Union[np.ndarray, tf.Tensor]] = None, + training: Optional[bool] = False, + ) -> Union[TFQuestionAnsweringModelOutput, Tuple[tf.Tensor]]: + r""" + start_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + end_positions (`tf.Tensor` of shape `(batch_size,)`, *optional*): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence + are not taken into account for computing the loss. + """ + outputs = self.roberta( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + training=training, + ) + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = tf.split(logits, 2, axis=-1) + start_logits = tf.squeeze(start_logits, axis=-1) + end_logits = tf.squeeze(end_logits, axis=-1) + + loss = None + if start_positions is not None and end_positions is not None: + labels = {"start_position": start_positions} + labels["end_position"] = end_positions + loss = self.hf_compute_loss(labels, (start_logits, end_logits)) + + if not return_dict: + output = (start_logits, end_logits) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TFQuestionAnsweringModelOutput( + loss=loss, + start_logits=start_logits, + end_logits=end_logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + # Copied from transformers.models.bert.modeling_tf_bert.TFBertForQuestionAnswering.serving_output + def serving_output(self, output: TFQuestionAnsweringModelOutput) -> TFQuestionAnsweringModelOutput: + hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None + attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None + + return TFQuestionAnsweringModelOutput( + start_logits=output.start_logits, end_logits=output.end_logits, hidden_states=hs, attentions=attns + ) diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index d8df4ece98..774b446a96 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -286,7 +286,7 @@ class XLMRobertaSelfAttention(nn.Module): return outputs -# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfOutput +# Copied from transformers.models.roberta.modeling_roberta.RobertaSelfOutput with Roberta->XLMRoberta class XLMRobertaSelfOutput(nn.Module): def __init__(self, config): super().__init__() @@ -351,7 +351,7 @@ class XLMRobertaAttention(nn.Module): return outputs -# Copied from transformers.models.roberta.modeling_roberta.RobertaIntermediate +# Copied from transformers.models.roberta.modeling_roberta.RobertaIntermediate with Roberta->XLMRoberta class XLMRobertaIntermediate(nn.Module): def __init__(self, config): super().__init__() @@ -367,7 +367,7 @@ class XLMRobertaIntermediate(nn.Module): return hidden_states -# Copied from transformers.models.roberta.modeling_roberta.RobertaOutput +# Copied from transformers.models.roberta.modeling_roberta.RobertaOutput with Roberta->XLMRoberta class XLMRobertaOutput(nn.Module): def __init__(self, config): super().__init__() @@ -567,7 +567,7 @@ class XLMRobertaEncoder(nn.Module): ) -# Copied from transformers.models.roberta.modeling_roberta.RobertaPooler +# Copied from transformers.models.roberta.modeling_roberta.RobertaPooler with Roberta->XLMRoberta class XLMRobertaPooler(nn.Module): def __init__(self, config): super().__init__() @@ -1455,7 +1455,7 @@ class XLMRobertaForTokenClassification(XLMRobertaPreTrainedModel): ) -# Copied from transformers.models.roberta.modeling_roberta.RobertaClassificationHead +# Copied from transformers.models.roberta.modeling_roberta.RobertaClassificationHead with Roberta->XLMRoberta class XLMRobertaClassificationHead(nn.Module): """Head for sentence-level classification tasks.""" diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py index d7271900b2..8c093e3c45 100644 --- a/src/transformers/utils/dummy_flax_objects.py +++ b/src/transformers/utils/dummy_flax_objects.py @@ -1152,6 +1152,16 @@ class FlaxXGLMPreTrainedModel(metaclass=DummyObject): requires_backends(self, ["flax"]) +FLAX_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = None + + +class FlaxXLMRobertaForCausalLM(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + class FlaxXLMRobertaForMaskedLM(metaclass=DummyObject): _backends = ["flax"] @@ -1192,3 +1202,10 @@ class FlaxXLMRobertaModel(metaclass=DummyObject): def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) + + +class FlaxXLMRobertaPreTrainedModel(metaclass=DummyObject): + _backends = ["flax"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index 624e08b88e..63d59d6ed8 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -2646,6 +2646,13 @@ class TFXLMWithLMHeadModel(metaclass=DummyObject): TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = None +class TFXLMRobertaForCausalLM(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + class TFXLMRobertaForMaskedLM(metaclass=DummyObject): _backends = ["tf"] @@ -2688,6 +2695,13 @@ class TFXLMRobertaModel(metaclass=DummyObject): requires_backends(self, ["tf"]) +class TFXLMRobertaPreTrainedModel(metaclass=DummyObject): + _backends = ["tf"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST = None