From 9fc34235fa3329c918d5ba67ce09a0cc8f399c59 Mon Sep 17 00:00:00 2001 From: amyeroberts Date: Thu, 9 Jun 2022 15:50:50 +0200 Subject: [PATCH] Use shape_list to safely get shapes for Swin (#17591) * Use shape_list to safely get shapes * Add relevant test * Tidy and add metrics * Resolve dynamic shaping issues and move test * Tidy up and all samples in batch * Formatting --- .../models/deberta/modeling_tf_deberta.py | 7 ++- .../deberta_v2/modeling_tf_deberta_v2.py | 18 +++++--- .../models/swin/modeling_tf_swin.py | 44 +++++++++++-------- tests/test_modeling_tf_common.py | 18 ++++++++ 4 files changed, 61 insertions(+), 26 deletions(-) diff --git a/src/transformers/models/deberta/modeling_tf_deberta.py b/src/transformers/models/deberta/modeling_tf_deberta.py index 2b369eef5d..d4f2b12b82 100644 --- a/src/transformers/models/deberta/modeling_tf_deberta.py +++ b/src/transformers/models/deberta/modeling_tf_deberta.py @@ -648,7 +648,12 @@ class TFDebertaDisentangledSelfAttention(tf.keras.layers.Layer): context_layer = tf.matmul(attention_probs, value_layer) context_layer = tf.transpose(context_layer, [0, 2, 1, 3]) - new_context_layer_shape = shape_list(context_layer)[:-2] + [-1] + context_layer_shape = shape_list(context_layer) + # Set the final dimension here explicitly. 
+ # Calling tf.reshape(context_layer, (*context_layer_shape[:-2], -1)) raises an error when executing + # the model in graph mode as context_layer is reshaped to (None, 7, None) and Dense layer in TFDebertaV2SelfOutput + # requires final input dimension to be defined + new_context_layer_shape = context_layer_shape[:-2] + [context_layer_shape[-2] * context_layer_shape[-1]] context_layer = tf.reshape(context_layer, new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) return outputs diff --git a/src/transformers/models/deberta_v2/modeling_tf_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_tf_deberta_v2.py index 953dd1d34f..3020a6d490 100644 --- a/src/transformers/models/deberta_v2/modeling_tf_deberta_v2.py +++ b/src/transformers/models/deberta_v2/modeling_tf_deberta_v2.py @@ -620,11 +620,15 @@ class TFDebertaV2DisentangledSelfAttention(tf.keras.layers.Layer): self.dropout = TFDebertaV2StableDropout(config.attention_probs_dropout_prob, name="dropout") def transpose_for_scores(self, tensor: tf.Tensor, attention_heads: int) -> tf.Tensor: - shape = shape_list(tensor)[:-1] + [attention_heads, -1] + tensor_shape = shape_list(tensor) + # In graph mode, we can't reshape with -1 as the final dimension if the first dimension (batch size) is None + shape = tensor_shape[:-1] + [attention_heads, tensor_shape[-1] // attention_heads] # Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size] tensor = tf.reshape(tensor=tensor, shape=shape) + tensor = tf.transpose(tensor, perm=[0, 2, 1, 3]) x_shape = shape_list(tensor) - return tf.reshape(tf.transpose(tensor, perm=[0, 2, 1, 3]), shape=[-1, x_shape[1], x_shape[-1]]) + tensor = tf.reshape(tensor, shape=[-1, x_shape[-2], x_shape[-1]]) + return tensor def call( self, @@ -686,7 +690,6 @@ class TFDebertaV2DisentangledSelfAttention(tf.keras.layers.Layer): if rel_att is not None: attention_scores = 
attention_scores + rel_att - attention_scores = attention_scores attention_scores = tf.reshape( attention_scores, (-1, self.num_attention_heads, shape_list(attention_scores)[-2], shape_list(attention_scores)[-1]), @@ -706,9 +709,12 @@ class TFDebertaV2DisentangledSelfAttention(tf.keras.layers.Layer): ), [0, 2, 1, 3], ) - new_context_layer_shape = shape_list(context_layer)[:-2] + [ - -1, - ] + # Set the final dimension here explicitly. + # Calling tf.reshape(context_layer, (*context_layer_shape[:-2], -1)) raises an error when executing + # the model in graph mode as context_layer is reshaped to (None, 7, None) and Dense layer in TFDebertaV2SelfOutput + # requires final input dimension to be defined + context_layer_shape = shape_list(context_layer) + new_context_layer_shape = context_layer_shape[:-2] + [context_layer_shape[-2] * context_layer_shape[-1]] context_layer = tf.reshape(context_layer, new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) return outputs diff --git a/src/transformers/models/swin/modeling_tf_swin.py b/src/transformers/models/swin/modeling_tf_swin.py index 68029074d3..45c555ee18 100644 --- a/src/transformers/models/swin/modeling_tf_swin.py +++ b/src/transformers/models/swin/modeling_tf_swin.py @@ -213,7 +213,7 @@ def window_partition(input_feature: tf.Tensor, window_size: int) -> tf.Tensor: """ Partitions the given input into windows. """ - batch_size, height, width, num_channels = input_feature.shape + batch_size, height, width, num_channels = shape_list(input_feature) input_feature = tf.reshape( input_feature, (batch_size, height // window_size, window_size, width // window_size, window_size, num_channels), @@ -227,7 +227,9 @@ def window_reverse(windows: tf.Tensor, window_size: int, height: int, width: int """ Merges windows to produce higher resolution features. 
""" - batch_size = int(windows.shape[0] / (height * width / window_size / window_size)) + x = shape_list(windows)[0] + y = tf.cast(height * width / window_size / window_size, tf.int32) + batch_size = int(x / y) windows = tf.reshape( windows, (batch_size, height // window_size, width // window_size, window_size, window_size, -1) ) @@ -245,7 +247,9 @@ def drop_path( if drop_prob == 0.0 or not training: return input keep_prob = 1 - drop_prob - shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + input_shape = shape_list(input) + ndim = len(input_shape) + shape = [input_shape[0]] + [1] * (ndim - 1) # work with diff dim tensors, not just 2D ConvNets random_tensor = tf.random.uniform(shape) random_tensor = tf.where(random_tensor <= keep_prob, 1.0, 0.0) if keep_prob > 0.0 and scale_by_keep: @@ -295,7 +299,7 @@ class TFSwinEmbeddings(tf.keras.layers.Layer): ) -> Tuple[tf.Tensor, Tuple[int, int]]: embeddings, output_dimensions = self.patch_embeddings(pixel_values, training=training) embeddings = self.norm(embeddings, training=training) - batch_size, seq_len, _ = embeddings.shape + batch_size, seq_len, _ = shape_list(embeddings) if bool_masked_pos is not None: mask_tokens = tf.repeat(self.mask_token, batch_size, 0) @@ -357,10 +361,10 @@ class TFSwinPatchEmbeddings(tf.keras.layers.Layer): # B,H,W,C -> B,C,H,W embeddings = tf.transpose(embeddings, (0, 3, 1, 2)) - _, _, height, width = embeddings.shape + batch_size, channels, height, width = shape_list(embeddings) output_dimensions = (height, width) - embeddings = tf.reshape(embeddings, (embeddings.shape[0], embeddings.shape[1], -1)) + embeddings = tf.reshape(embeddings, (batch_size, channels, -1)) embeddings = tf.transpose(embeddings, (0, 2, 1)) return embeddings, output_dimensions @@ -402,7 +406,7 @@ class TFSwinPatchMerging(tf.keras.layers.Layer): def call(self, input_feature: tf.Tensor, input_dimensions: Tuple[int, int], training: bool = False) -> tf.Tensor: height, width 
= input_dimensions # `dim` is height * width - batch_size, _, num_channels = input_feature.shape + batch_size, _, num_channels = shape_list(input_feature) input_feature = tf.reshape(input_feature, (batch_size, height, width, num_channels)) # pad input to be disible by width and height, if needed @@ -456,7 +460,7 @@ class TFSwinSelfAttention(tf.keras.layers.Layer): coords_h = tf.range(self.window_size[0]) coords_w = tf.range(self.window_size[1]) coords = tf.stack(tf.meshgrid(coords_h, coords_w, indexing="ij")) - coords_flatten = tf.reshape(coords, (coords.shape[0], -1)) + coords_flatten = tf.reshape(coords, (shape_list(coords)[0], -1)) relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] relative_coords = tf.transpose(relative_coords, (1, 2, 0)) @@ -497,7 +501,7 @@ class TFSwinSelfAttention(tf.keras.layers.Layer): super().build(input_shape) def transpose_for_scores(self, x: tf.Tensor) -> tf.Tensor: - new_x_shape = x.shape[:-1] + (self.num_attention_heads, self.attention_head_size) + new_x_shape = shape_list(x)[:-1] + [self.num_attention_heads, self.attention_head_size] x = tf.reshape(x, new_x_shape) return tf.transpose(x, (0, 2, 1, 3)) @@ -509,7 +513,7 @@ class TFSwinSelfAttention(tf.keras.layers.Layer): output_attentions: bool = False, training: bool = False, ) -> Tuple[tf.Tensor, ...]: - batch_size, dim, _ = hidden_states.shape + batch_size, dim, _ = shape_list(hidden_states) mixed_query_layer = self.query(hidden_states) key_layer = self.transpose_for_scores(self.key(hidden_states)) @@ -533,7 +537,7 @@ class TFSwinSelfAttention(tf.keras.layers.Layer): if attention_mask is not None: # Apply the attention mask is (precomputed for all layers in SwinModel forward() function) - mask_shape = attention_mask.shape[0] + mask_shape = shape_list(attention_mask)[0] attention_scores = tf.reshape( attention_scores, (batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim) ) @@ -555,7 +559,9 @@ class 
TFSwinSelfAttention(tf.keras.layers.Layer): context_layer = tf.matmul(attention_probs, value_layer) context_layer = tf.transpose(context_layer, (0, 2, 1, 3)) - new_context_layer_shape = context_layer.shape[:-2] + (self.all_head_size,) + new_context_layer_shape = shape_list(context_layer)[:-2] + [ + self.all_head_size, + ] context_layer = tf.reshape(context_layer, new_context_layer_shape) outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) @@ -720,7 +726,7 @@ class TFSwinLayer(tf.keras.layers.Layer): ) -> tf.Tensor: self.set_shift_and_window_size(input_dimensions) height, width = input_dimensions - batch_size, _, channels = hidden_states.shape + batch_size, _, channels = shape_list(hidden_states) shortcut = hidden_states hidden_states = self.layernorm_before(hidden_states, training=training) @@ -728,7 +734,7 @@ class TFSwinLayer(tf.keras.layers.Layer): # pad hidden_states to multiples of window size hidden_states, pad_values = self.maybe_pad(hidden_states, height, width) - _, height_pad, width_pad, _ = hidden_states.shape + _, height_pad, width_pad, _ = shape_list(hidden_states) # cyclic shift if self.shift_size > 0: shifted_hidden_states = tf.roll(hidden_states, shift=(-self.shift_size, -self.shift_size), axis=(1, 2)) @@ -881,7 +887,7 @@ class TFSwinEncoder(tf.keras.layers.Layer): all_self_attentions = () if output_attentions else None if output_hidden_states: - batch_size, _, hidden_size = hidden_states.shape + batch_size, _, hidden_size = shape_list(hidden_states) # rearrange b (h w) c -> b c h w reshaped_hidden_state = tf.reshape(hidden_states, (batch_size, *input_dimensions, hidden_size)) reshaped_hidden_state = tf.transpose(reshaped_hidden_state, (0, 3, 1, 2)) @@ -902,7 +908,7 @@ class TFSwinEncoder(tf.keras.layers.Layer): all_input_dimensions += (input_dimensions,) if output_hidden_states: - batch_size, _, hidden_size = hidden_states.shape + batch_size, _, hidden_size = shape_list(hidden_states) # rearrange b (h w) c -> b c h 
w reshaped_hidden_state = tf.reshape(hidden_states, (batch_size, *input_dimensions, hidden_size)) reshaped_hidden_state = tf.transpose(reshaped_hidden_state, (0, 3, 1, 2)) @@ -1152,7 +1158,7 @@ class TFSwinModel(TFSwinPreTrainedModel): pooled_output = None if self.pooler is not None: - batch_size, _, num_features = sequence_output.shape + batch_size, _, num_features = shape_list(sequence_output) pooled_output = self.pooler(sequence_output) pooled_output = tf.reshape(pooled_output, (batch_size, num_features)) @@ -1206,7 +1212,7 @@ class TFSwinDecoder(tf.keras.layers.Layer): # B,C,H,W -> B,H,W,C hidden_states = tf.transpose(hidden_states, (0, 2, 3, 1)) hidden_states = self.conv2d(hidden_states) - batch_size, _, _, num_input_channels = hidden_states.shape + batch_size, _, _, num_input_channels = shape_list(hidden_states) block_size_squared = self._block_size**2 output_depth = int(num_input_channels / block_size_squared) # When the number of output channels >= 2, PyTorch's PixelShuffle and @@ -1293,7 +1299,7 @@ class TFSwinForMaskedImageModeling(TFSwinPreTrainedModel): sequence_output = outputs[0] # Reshape to (batch_size, num_channels, height, width) sequence_output = tf.transpose(sequence_output, (0, 2, 1)) - batch_size, num_channels, sequence_length = sequence_output.shape + batch_size, num_channels, sequence_length = shape_list(sequence_output) height = width = int(sequence_length**0.5) sequence_output = tf.reshape(sequence_output, (batch_size, num_channels, height, width)) diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index fa439704a8..908d072220 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -1406,6 +1406,24 @@ class TFModelTesterMixin: if metrics: self.assertTrue(len(accuracy1) == len(accuracy2) > 0, "Missing metrics!") + # Make sure fit works with tf.data.Dataset and results are consistent + dataset = tf.data.Dataset.from_tensor_slices(prepared_for_class) + # Pass in all samples as a 
batch to match other `fit` calls + dataset = dataset.batch(len(dataset)) + history3 = model.fit( + dataset, + validation_data=dataset, + steps_per_epoch=1, + validation_steps=1, + shuffle=False, + ) + val_loss3 = history3.history["val_loss"][0] + accuracy3 = {key: val[0] for key, val in history3.history.items() if key.endswith("accuracy")} + self.assertTrue(np.allclose(val_loss1, val_loss3, atol=1e-2, rtol=1e-3)) + self.assertEqual(history1.history.keys(), history3.history.keys()) + if metrics: + self.assertTrue(len(accuracy1) == len(accuracy3) > 0, "Missing metrics!") + def test_int64_inputs(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: