From af1a10bff4e4025881cb5d5534c603853fb8122c Mon Sep 17 00:00:00 2001 From: Jayendra Date: Fri, 28 May 2021 16:16:56 +0530 Subject: [PATCH] [Flax] Return Attention from BERT, ELECTRA, RoBERTa and GPT2 (#11918) * Added logic to return attention from flax-bert model and added test cases to check that * Added new line at the end of file to test_modeling_flax_common.py * fixing code style * Fixing Roberta and Electra models too by copying from bert * Added temporary hack to not run test_attention_outputs for FlaxGPT2 * Returning attention weights from GPT2 and changed the tests accordingly. * last fixes * bump flax dependency Co-authored-by: jayendra Co-authored-by: Patrick von Platen --- setup.py | 2 +- src/transformers/dependency_versions_table.py | 2 +- .../models/bert/modeling_flax_bert.py | 23 ++++----- .../models/electra/modeling_flax_electra.py | 23 ++++----- .../models/gpt2/modeling_flax_gpt2.py | 21 ++++---- .../models/roberta/modeling_flax_roberta.py | 23 ++++----- tests/test_modeling_flax_common.py | 48 ++++++++++++++++++- 7 files changed, 89 insertions(+), 53 deletions(-) diff --git a/setup.py b/setup.py index 498107ac0c..475343f88e 100644 --- a/setup.py +++ b/setup.py @@ -97,7 +97,7 @@ _deps = [ "fastapi", "filelock", "flake8>=3.8.3", - "flax>=0.3.2", + "flax>=0.3.4", "fugashi>=1.0", "huggingface-hub==0.0.8", "importlib_metadata", diff --git a/src/transformers/dependency_versions_table.py b/src/transformers/dependency_versions_table.py index 811f9d66cb..55bbcb670f 100644 --- a/src/transformers/dependency_versions_table.py +++ b/src/transformers/dependency_versions_table.py @@ -14,7 +14,7 @@ deps = { "fastapi": "fastapi", "filelock": "filelock", "flake8": "flake8>=3.8.3", - "flax": "flax>=0.3.2", + "flax": "flax>=0.3.4", "fugashi": "fugashi>=1.0", "huggingface-hub": "huggingface-hub==0.0.8", "importlib_metadata": "importlib_metadata", diff --git a/src/transformers/models/bert/modeling_flax_bert.py b/src/transformers/models/bert/modeling_flax_bert.py index 
82ce4ee870..aa2bcd0f8f 100644 --- a/src/transformers/models/bert/modeling_flax_bert.py +++ b/src/transformers/models/bert/modeling_flax_bert.py @@ -23,7 +23,7 @@ import jax import jax.numpy as jnp import jaxlib.xla_extension as jax_xla from flax.core.frozen_dict import FrozenDict -from flax.linen import dot_product_attention +from flax.linen.attention import dot_product_attention_weights from jax import lax from ...file_utils import ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward @@ -241,10 +241,9 @@ class FlaxBertSelfAttention(nn.Module): if not deterministic and self.config.attention_probs_dropout_prob > 0.0: dropout_rng = self.make_rng("dropout") - attn_output = dot_product_attention( + attn_weights = dot_product_attention_weights( query_states, key_states, - value_states, bias=attention_bias, dropout_rng=dropout_rng, dropout_rate=self.config.attention_probs_dropout_prob, @@ -254,11 +253,10 @@ class FlaxBertSelfAttention(nn.Module): precision=None, ) - outputs = (attn_output.reshape(attn_output.shape[:2] + (-1,)),) - - # TODO: at the moment it's not possible to retrieve attn_weights from - # dot_product_attention, but should be in the future -> add functionality then + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) + attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,)) + outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) return outputs @@ -303,7 +301,7 @@ class FlaxBertAttention(nn.Module): outputs = (hidden_states,) if output_attentions: - outputs += attn_outputs[1] + outputs += (attn_outputs[1],) return outputs @@ -396,7 +394,9 @@ class FlaxBertLayerCollection(nn.Module): if output_hidden_states: all_hidden_states += (hidden_states,) - layer_outputs = layer(hidden_states, attention_mask, deterministic=deterministic) + layer_outputs = layer( + hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions + ) hidden_states = 
layer_outputs[0] @@ -582,11 +582,6 @@ class FlaxBertPreTrainedModel(FlaxPreTrainedModel): ) return_dict = return_dict if return_dict is not None else self.config.return_dict - if output_attentions: - raise NotImplementedError( - "Currently attention scores cannot be returned. Please set `output_attentions` to False for now." - ) - # init input tensors if not passed if token_type_ids is None: token_type_ids = jnp.zeros_like(input_ids) diff --git a/src/transformers/models/electra/modeling_flax_electra.py b/src/transformers/models/electra/modeling_flax_electra.py index 9d94433016..ea093770fd 100644 --- a/src/transformers/models/electra/modeling_flax_electra.py +++ b/src/transformers/models/electra/modeling_flax_electra.py @@ -23,7 +23,7 @@ import jax import jax.numpy as jnp import jaxlib.xla_extension as jax_xla from flax.core.frozen_dict import FrozenDict -from flax.linen import dot_product_attention +from flax.linen.attention import dot_product_attention_weights from jax import lax from jax.random import PRNGKey @@ -238,10 +238,9 @@ class FlaxElectraSelfAttention(nn.Module): if not deterministic and self.config.attention_probs_dropout_prob > 0.0: dropout_rng = self.make_rng("dropout") - attn_output = dot_product_attention( + attn_weights = dot_product_attention_weights( query_states, key_states, - value_states, bias=attention_bias, dropout_rng=dropout_rng, dropout_rate=self.config.attention_probs_dropout_prob, @@ -251,11 +250,10 @@ class FlaxElectraSelfAttention(nn.Module): precision=None, ) - outputs = (attn_output.reshape(attn_output.shape[:2] + (-1,)),) - - # TODO: at the moment it's not possible to retrieve attn_weights from - # dot_product_attention, but should be in the future -> add functionality then + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) + attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,)) + outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) return outputs @@ -302,7 +300,7 @@ 
class FlaxElectraAttention(nn.Module): outputs = (hidden_states,) if output_attentions: - outputs += attn_outputs[1] + outputs += (attn_outputs[1],) return outputs @@ -399,7 +397,9 @@ class FlaxElectraLayerCollection(nn.Module): if output_hidden_states: all_hidden_states += (hidden_states,) - layer_outputs = layer(hidden_states, attention_mask, deterministic=deterministic) + layer_outputs = layer( + hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions + ) hidden_states = layer_outputs[0] @@ -534,11 +534,6 @@ class FlaxElectraPreTrainedModel(FlaxPreTrainedModel): ) return_dict = return_dict if return_dict is not None else self.config.return_dict - if output_attentions: - raise NotImplementedError( - "Currently attention scores cannot be returned. Please set `output_attentions` to False for now." - ) - # init input tensors if not passed if token_type_ids is None: token_type_ids = jnp.ones_like(input_ids) diff --git a/src/transformers/models/gpt2/modeling_flax_gpt2.py b/src/transformers/models/gpt2/modeling_flax_gpt2.py index 19bac78c8a..5440d47c06 100644 --- a/src/transformers/models/gpt2/modeling_flax_gpt2.py +++ b/src/transformers/models/gpt2/modeling_flax_gpt2.py @@ -19,7 +19,8 @@ import flax.linen as nn import jax import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict, unfreeze -from flax.linen import combine_masks, dot_product_attention, make_causal_mask +from flax.linen import combine_masks, make_causal_mask +from flax.linen.attention import dot_product_attention_weights from jax import lax from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward @@ -215,10 +216,9 @@ class FlaxGPT2Attention(nn.Module): ) # usual dot product attention - attn_output = dot_product_attention( + attn_weights = dot_product_attention_weights( query, key, - value, bias=attention_bias, dropout_rng=dropout_rng, dropout_rate=self.config.attn_pdrop, @@ -227,14 +227,13 @@ class 
FlaxGPT2Attention(nn.Module): precision=None, ) + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value) attn_output = self._merge_heads(attn_output) attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output, deterministic=deterministic) - # TODO: at the moment it's not possible to retrieve attn_weights from - # dot_product_attention, but should be in the future -> add functionality then - - return (attn_output,) + outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) + return outputs class FlaxGPT2MLP(nn.Module): @@ -447,7 +446,13 @@ class FlaxGPT2BlockCollection(nn.Module): if output_hidden_states: all_hidden_states += (hidden_states,) - layer_outputs = block(hidden_states, attention_mask, deterministic=deterministic, init_cache=init_cache) + layer_outputs = block( + hidden_states, + attention_mask, + deterministic=deterministic, + init_cache=init_cache, + output_attentions=output_attentions, + ) hidden_states = layer_outputs[0] if output_attentions: diff --git a/src/transformers/models/roberta/modeling_flax_roberta.py b/src/transformers/models/roberta/modeling_flax_roberta.py index 9613a69988..128ccd3e29 100644 --- a/src/transformers/models/roberta/modeling_flax_roberta.py +++ b/src/transformers/models/roberta/modeling_flax_roberta.py @@ -20,7 +20,7 @@ import flax.linen as nn import jax import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict -from flax.linen import dot_product_attention +from flax.linen.attention import dot_product_attention_weights from jax import lax from jax.random import PRNGKey @@ -227,10 +227,9 @@ class FlaxRobertaSelfAttention(nn.Module): if not deterministic and self.config.attention_probs_dropout_prob > 0.0: dropout_rng = self.make_rng("dropout") - attn_output = dot_product_attention( + attn_weights = dot_product_attention_weights( query_states, key_states, - value_states, bias=attention_bias, dropout_rng=dropout_rng, 
dropout_rate=self.config.attention_probs_dropout_prob, @@ -240,11 +239,10 @@ class FlaxRobertaSelfAttention(nn.Module): precision=None, ) - outputs = (attn_output.reshape(attn_output.shape[:2] + (-1,)),) - - # TODO: at the moment it's not possible to retrieve attn_weights from - # dot_product_attention, but should be in the future -> add functionality then + attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states) + attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,)) + outputs = (attn_output, attn_weights) if output_attentions else (attn_output,) return outputs @@ -291,7 +289,7 @@ class FlaxRobertaAttention(nn.Module): outputs = (hidden_states,) if output_attentions: - outputs += attn_outputs[1] + outputs += (attn_outputs[1],) return outputs @@ -388,7 +386,9 @@ class FlaxRobertaLayerCollection(nn.Module): if output_hidden_states: all_hidden_states += (hidden_states,) - layer_outputs = layer(hidden_states, attention_mask, deterministic=deterministic) + layer_outputs = layer( + hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions + ) hidden_states = layer_outputs[0] @@ -570,11 +570,6 @@ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel): ) return_dict = return_dict if return_dict is not None else self.config.return_dict - if output_attentions: - raise NotImplementedError( - "Currently attention scores cannot be returned." "Please set `output_attentions` to False for now." 
- ) - # init input tensors if not passed if token_type_ids is None: token_type_ids = jnp.zeros_like(input_ids) diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index e1c0322699..7748c5b62f 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -79,8 +79,9 @@ class FlaxModelTesterMixin: if "ForMultipleChoice" in model_class.__name__: inputs_dict = { k: jnp.broadcast_to(v[:, None], (v.shape[0], self.model_tester.num_choices, v.shape[-1])) - for k, v in inputs_dict.items() if isinstance(v, (jax_xla.DeviceArray, np.ndarray)) + else v + for k, v in inputs_dict.items() } return inputs_dict @@ -310,3 +311,48 @@ class FlaxModelTesterMixin: config.output_hidden_states = True check_hidden_states_output(inputs_dict, config, model_class) + + def test_attention_outputs(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + seq_length = getattr(self.model_tester, "seq_length", None) + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + model = model_class(config) + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.attentions + self.assertEqual(len(attentions), self.model_tester.num_hidden_layers) + + self.assertListEqual( + list(attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, seq_length, seq_length], + ) + out_len = len(outputs) + + # Check attention is always last and order is fine + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True 
+ model = model_class(config) + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + added_hidden_states = 1 + self.assertEqual(out_len + added_hidden_states, len(outputs)) + + self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers) + + self.assertListEqual( + list(self_attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, seq_length, seq_length], + )