[FlaxWav2Vec2Model] Fix bug in attention mask (#16725)
* [FlaxWav2Vec2Model] Fix bug in attention mask * more fixes * add (Flax)SpeechEncoderDecoderModel PT-FX cross-test
This commit is contained in:
@@ -920,7 +920,7 @@ class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel):
|
|||||||
def _get_feat_extract_output_lengths(
|
def _get_feat_extract_output_lengths(
|
||||||
self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
|
self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
|
||||||
):
|
):
|
||||||
return self.module._get_feat_extract_output_lengths(input_lengths)
|
return self.module._get_feat_extract_output_lengths(input_lengths, add_adapter=add_adapter)
|
||||||
|
|
||||||
|
|
||||||
class FlaxWav2Vec2Module(nn.Module):
|
class FlaxWav2Vec2Module(nn.Module):
|
||||||
@@ -956,15 +956,10 @@ class FlaxWav2Vec2Module(nn.Module):
|
|||||||
|
|
||||||
# make sure that no loss is computed on padded inputs
|
# make sure that no loss is computed on padded inputs
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# compute real output lengths according to convolution formula
|
# compute reduced attention_mask corresponding to feature vectors
|
||||||
output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1).astype("i4"))
|
attention_mask = self._get_feature_vector_attention_mask(
|
||||||
|
extract_features.shape[1], attention_mask, add_adapter=False
|
||||||
attention_mask = jnp.zeros(extract_features.shape[:2], dtype=self.dtype)
|
)
|
||||||
|
|
||||||
# these two operations makes sure that all values
|
|
||||||
# before the output lengths indices are attended to
|
|
||||||
attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
|
|
||||||
attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
|
|
||||||
|
|
||||||
hidden_states, extract_features = self.feature_projection(extract_features, deterministic=deterministic)
|
hidden_states, extract_features = self.feature_projection(extract_features, deterministic=deterministic)
|
||||||
if mask_time_indices is not None: # apply SpecAugment along time axis with given indices
|
if mask_time_indices is not None: # apply SpecAugment along time axis with given indices
|
||||||
@@ -1034,12 +1029,10 @@ class FlaxWav2Vec2Module(nn.Module):
|
|||||||
batch_size = attention_mask.shape[0]
|
batch_size = attention_mask.shape[0]
|
||||||
|
|
||||||
attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
|
attention_mask = jnp.zeros((batch_size, feature_vector_length), dtype=attention_mask.dtype)
|
||||||
# these two operations makes sure that all values before the output lengths idxs are attended to
|
# these two operations makes sure that all values
|
||||||
idx = (jnp.arange(attention_mask.shape[0]), output_lengths - 1)
|
# before the output lengths indices are attended to
|
||||||
attention_mask = attention_mask.at[idx].set(1)
|
attention_mask = attention_mask.at[jnp.arange(attention_mask.shape[0]), output_lengths - 1].set(1)
|
||||||
attention_mask = jnp.flip(jnp.flip(attention_mask, axis=-1).cumsum(axis=-1), axis=-1)
|
attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
|
||||||
|
|
||||||
attention_mask = jnp.array(attention_mask, dtype=bool)
|
|
||||||
return attention_mask
|
return attention_mask
|
||||||
|
|
||||||
|
|
||||||
@@ -1286,11 +1279,15 @@ class FlaxWav2Vec2ForPreTrainingModule(nn.Module):
|
|||||||
attentions=outputs.attentions,
|
attentions=outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]):
|
def _get_feat_extract_output_lengths(
|
||||||
|
self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Computes the output length of the convolutional layers
|
Computes the output length of the convolutional layers
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
|
||||||
|
|
||||||
def _conv_out_length(input_length, kernel_size, stride):
|
def _conv_out_length(input_length, kernel_size, stride):
|
||||||
# 1D convolutional layer output length formula taken
|
# 1D convolutional layer output length formula taken
|
||||||
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
|
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
|
||||||
@@ -1299,6 +1296,10 @@ class FlaxWav2Vec2ForPreTrainingModule(nn.Module):
|
|||||||
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
|
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
|
||||||
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
|
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
|
||||||
|
|
||||||
|
if add_adapter:
|
||||||
|
for _ in range(self.config.num_adapter_layers):
|
||||||
|
input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
|
||||||
|
|
||||||
return input_lengths
|
return input_lengths
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -539,6 +539,12 @@ class FlaxEncoderDecoderMixin:
|
|||||||
self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
|
self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
|
||||||
self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict)
|
self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict)
|
||||||
|
|
||||||
|
# check `add_adapter` works as expected
|
||||||
|
config.add_adapter = True
|
||||||
|
self.assertTrue(config.add_adapter)
|
||||||
|
self.check_equivalence_pt_to_flax(config, decoder_config, inputs_dict)
|
||||||
|
self.check_equivalence_flax_to_pt(config, decoder_config, inputs_dict)
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_real_model_save_load_from_pretrained(self):
|
def test_real_model_save_load_from_pretrained(self):
|
||||||
model_2 = self.get_pretrained_model()
|
model_2 = self.get_pretrained_model()
|
||||||
|
|||||||
Reference in New Issue
Block a user