Use shape_list to safely get shapes for Swin (#17591)
* Use shape_list to safely get shapes * Add relevant test * Tidy and add metrics * Resolve dynamic shaping issues and move test * Tidy up and all samples in batch * Formatting
This commit is contained in:
@@ -648,7 +648,12 @@ class TFDebertaDisentangledSelfAttention(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
context_layer = tf.matmul(attention_probs, value_layer)
|
context_layer = tf.matmul(attention_probs, value_layer)
|
||||||
context_layer = tf.transpose(context_layer, [0, 2, 1, 3])
|
context_layer = tf.transpose(context_layer, [0, 2, 1, 3])
|
||||||
new_context_layer_shape = shape_list(context_layer)[:-2] + [-1]
|
context_layer_shape = shape_list(context_layer)
|
||||||
|
# Set the final dimension here explicitly.
|
||||||
|
# Calling tf.reshape(context_layer, (*context_layer_shape[:-2], -1)) raises an error when executing
|
||||||
|
# the model in graph mode as context_layer is reshaped to (None, 7, None) and Dense layer in TFDebertaV2SelfOutput
|
||||||
|
# requires final input dimension to be defined
|
||||||
|
new_context_layer_shape = context_layer_shape[:-2] + [context_layer_shape[-2] * context_layer_shape[-1]]
|
||||||
context_layer = tf.reshape(context_layer, new_context_layer_shape)
|
context_layer = tf.reshape(context_layer, new_context_layer_shape)
|
||||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||||
return outputs
|
return outputs
|
||||||
|
|||||||
@@ -620,11 +620,15 @@ class TFDebertaV2DisentangledSelfAttention(tf.keras.layers.Layer):
|
|||||||
self.dropout = TFDebertaV2StableDropout(config.attention_probs_dropout_prob, name="dropout")
|
self.dropout = TFDebertaV2StableDropout(config.attention_probs_dropout_prob, name="dropout")
|
||||||
|
|
||||||
def transpose_for_scores(self, tensor: tf.Tensor, attention_heads: int) -> tf.Tensor:
|
def transpose_for_scores(self, tensor: tf.Tensor, attention_heads: int) -> tf.Tensor:
|
||||||
shape = shape_list(tensor)[:-1] + [attention_heads, -1]
|
tensor_shape = shape_list(tensor)
|
||||||
|
# In graph mode, we can't reshape with -1 as the final dimension if the first dimension (batch size) is None
|
||||||
|
shape = tensor_shape[:-1] + [attention_heads, tensor_shape[-1] // attention_heads]
|
||||||
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
|
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
|
||||||
tensor = tf.reshape(tensor=tensor, shape=shape)
|
tensor = tf.reshape(tensor=tensor, shape=shape)
|
||||||
|
tensor = tf.transpose(tensor, perm=[0, 2, 1, 3])
|
||||||
x_shape = shape_list(tensor)
|
x_shape = shape_list(tensor)
|
||||||
return tf.reshape(tf.transpose(tensor, perm=[0, 2, 1, 3]), shape=[-1, x_shape[1], x_shape[-1]])
|
tensor = tf.reshape(tensor, shape=[-1, x_shape[-2], x_shape[-1]])
|
||||||
|
return tensor
|
||||||
|
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
@@ -686,7 +690,6 @@ class TFDebertaV2DisentangledSelfAttention(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
if rel_att is not None:
|
if rel_att is not None:
|
||||||
attention_scores = attention_scores + rel_att
|
attention_scores = attention_scores + rel_att
|
||||||
attention_scores = attention_scores
|
|
||||||
attention_scores = tf.reshape(
|
attention_scores = tf.reshape(
|
||||||
attention_scores,
|
attention_scores,
|
||||||
(-1, self.num_attention_heads, shape_list(attention_scores)[-2], shape_list(attention_scores)[-1]),
|
(-1, self.num_attention_heads, shape_list(attention_scores)[-2], shape_list(attention_scores)[-1]),
|
||||||
@@ -706,9 +709,12 @@ class TFDebertaV2DisentangledSelfAttention(tf.keras.layers.Layer):
|
|||||||
),
|
),
|
||||||
[0, 2, 1, 3],
|
[0, 2, 1, 3],
|
||||||
)
|
)
|
||||||
new_context_layer_shape = shape_list(context_layer)[:-2] + [
|
# Set the final dimension here explicitly.
|
||||||
-1,
|
# Calling tf.reshape(context_layer, (*context_layer_shape[:-2], -1)) raises an error when executing
|
||||||
]
|
# the model in graph mode as context_layer is reshaped to (None, 7, None) and Dense layer in TFDebertaV2SelfOutput
|
||||||
|
# requires final input dimension to be defined
|
||||||
|
context_layer_shape = shape_list(context_layer)
|
||||||
|
new_context_layer_shape = context_layer_shape[:-2] + [context_layer_shape[-2] * context_layer_shape[-1]]
|
||||||
context_layer = tf.reshape(context_layer, new_context_layer_shape)
|
context_layer = tf.reshape(context_layer, new_context_layer_shape)
|
||||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||||
return outputs
|
return outputs
|
||||||
|
|||||||
@@ -213,7 +213,7 @@ def window_partition(input_feature: tf.Tensor, window_size: int) -> tf.Tensor:
|
|||||||
"""
|
"""
|
||||||
Partitions the given input into windows.
|
Partitions the given input into windows.
|
||||||
"""
|
"""
|
||||||
batch_size, height, width, num_channels = input_feature.shape
|
batch_size, height, width, num_channels = shape_list(input_feature)
|
||||||
input_feature = tf.reshape(
|
input_feature = tf.reshape(
|
||||||
input_feature,
|
input_feature,
|
||||||
(batch_size, height // window_size, window_size, width // window_size, window_size, num_channels),
|
(batch_size, height // window_size, window_size, width // window_size, window_size, num_channels),
|
||||||
@@ -227,7 +227,9 @@ def window_reverse(windows: tf.Tensor, window_size: int, height: int, width: int
|
|||||||
"""
|
"""
|
||||||
Merges windows to produce higher resolution features.
|
Merges windows to produce higher resolution features.
|
||||||
"""
|
"""
|
||||||
batch_size = int(windows.shape[0] / (height * width / window_size / window_size))
|
x = shape_list(windows)[0]
|
||||||
|
y = tf.cast(height * width / window_size / window_size, tf.int32)
|
||||||
|
batch_size = int(x / y)
|
||||||
windows = tf.reshape(
|
windows = tf.reshape(
|
||||||
windows, (batch_size, height // window_size, width // window_size, window_size, window_size, -1)
|
windows, (batch_size, height // window_size, width // window_size, window_size, window_size, -1)
|
||||||
)
|
)
|
||||||
@@ -245,7 +247,9 @@ def drop_path(
|
|||||||
if drop_prob == 0.0 or not training:
|
if drop_prob == 0.0 or not training:
|
||||||
return input
|
return input
|
||||||
keep_prob = 1 - drop_prob
|
keep_prob = 1 - drop_prob
|
||||||
shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
input_shape = shape_list(input)
|
||||||
|
ndim = len(input_shape)
|
||||||
|
shape = [input_shape[0]] + [1] * (ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
||||||
random_tensor = tf.random.uniform(shape)
|
random_tensor = tf.random.uniform(shape)
|
||||||
random_tensor = tf.where(random_tensor <= keep_prob, 1.0, 0.0)
|
random_tensor = tf.where(random_tensor <= keep_prob, 1.0, 0.0)
|
||||||
if keep_prob > 0.0 and scale_by_keep:
|
if keep_prob > 0.0 and scale_by_keep:
|
||||||
@@ -295,7 +299,7 @@ class TFSwinEmbeddings(tf.keras.layers.Layer):
|
|||||||
) -> Tuple[tf.Tensor, Tuple[int, int]]:
|
) -> Tuple[tf.Tensor, Tuple[int, int]]:
|
||||||
embeddings, output_dimensions = self.patch_embeddings(pixel_values, training=training)
|
embeddings, output_dimensions = self.patch_embeddings(pixel_values, training=training)
|
||||||
embeddings = self.norm(embeddings, training=training)
|
embeddings = self.norm(embeddings, training=training)
|
||||||
batch_size, seq_len, _ = embeddings.shape
|
batch_size, seq_len, _ = shape_list(embeddings)
|
||||||
|
|
||||||
if bool_masked_pos is not None:
|
if bool_masked_pos is not None:
|
||||||
mask_tokens = tf.repeat(self.mask_token, batch_size, 0)
|
mask_tokens = tf.repeat(self.mask_token, batch_size, 0)
|
||||||
@@ -357,10 +361,10 @@ class TFSwinPatchEmbeddings(tf.keras.layers.Layer):
|
|||||||
# B,H,W,C -> B,C,H,W
|
# B,H,W,C -> B,C,H,W
|
||||||
embeddings = tf.transpose(embeddings, (0, 3, 1, 2))
|
embeddings = tf.transpose(embeddings, (0, 3, 1, 2))
|
||||||
|
|
||||||
_, _, height, width = embeddings.shape
|
batch_size, channels, height, width = shape_list(embeddings)
|
||||||
output_dimensions = (height, width)
|
output_dimensions = (height, width)
|
||||||
|
|
||||||
embeddings = tf.reshape(embeddings, (embeddings.shape[0], embeddings.shape[1], -1))
|
embeddings = tf.reshape(embeddings, (batch_size, channels, -1))
|
||||||
embeddings = tf.transpose(embeddings, (0, 2, 1))
|
embeddings = tf.transpose(embeddings, (0, 2, 1))
|
||||||
return embeddings, output_dimensions
|
return embeddings, output_dimensions
|
||||||
|
|
||||||
@@ -402,7 +406,7 @@ class TFSwinPatchMerging(tf.keras.layers.Layer):
|
|||||||
def call(self, input_feature: tf.Tensor, input_dimensions: Tuple[int, int], training: bool = False) -> tf.Tensor:
|
def call(self, input_feature: tf.Tensor, input_dimensions: Tuple[int, int], training: bool = False) -> tf.Tensor:
|
||||||
height, width = input_dimensions
|
height, width = input_dimensions
|
||||||
# `dim` is height * width
|
# `dim` is height * width
|
||||||
batch_size, _, num_channels = input_feature.shape
|
batch_size, _, num_channels = shape_list(input_feature)
|
||||||
|
|
||||||
input_feature = tf.reshape(input_feature, (batch_size, height, width, num_channels))
|
input_feature = tf.reshape(input_feature, (batch_size, height, width, num_channels))
|
||||||
# pad input to be divisible by width and height, if needed
|
# pad input to be divisible by width and height, if needed
|
||||||
@@ -456,7 +460,7 @@ class TFSwinSelfAttention(tf.keras.layers.Layer):
|
|||||||
coords_h = tf.range(self.window_size[0])
|
coords_h = tf.range(self.window_size[0])
|
||||||
coords_w = tf.range(self.window_size[1])
|
coords_w = tf.range(self.window_size[1])
|
||||||
coords = tf.stack(tf.meshgrid(coords_h, coords_w, indexing="ij"))
|
coords = tf.stack(tf.meshgrid(coords_h, coords_w, indexing="ij"))
|
||||||
coords_flatten = tf.reshape(coords, (coords.shape[0], -1))
|
coords_flatten = tf.reshape(coords, (shape_list(coords)[0], -1))
|
||||||
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
|
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
|
||||||
relative_coords = tf.transpose(relative_coords, (1, 2, 0))
|
relative_coords = tf.transpose(relative_coords, (1, 2, 0))
|
||||||
|
|
||||||
@@ -497,7 +501,7 @@ class TFSwinSelfAttention(tf.keras.layers.Layer):
|
|||||||
super().build(input_shape)
|
super().build(input_shape)
|
||||||
|
|
||||||
def transpose_for_scores(self, x: tf.Tensor) -> tf.Tensor:
|
def transpose_for_scores(self, x: tf.Tensor) -> tf.Tensor:
|
||||||
new_x_shape = x.shape[:-1] + (self.num_attention_heads, self.attention_head_size)
|
new_x_shape = shape_list(x)[:-1] + [self.num_attention_heads, self.attention_head_size]
|
||||||
x = tf.reshape(x, new_x_shape)
|
x = tf.reshape(x, new_x_shape)
|
||||||
return tf.transpose(x, (0, 2, 1, 3))
|
return tf.transpose(x, (0, 2, 1, 3))
|
||||||
|
|
||||||
@@ -509,7 +513,7 @@ class TFSwinSelfAttention(tf.keras.layers.Layer):
|
|||||||
output_attentions: bool = False,
|
output_attentions: bool = False,
|
||||||
training: bool = False,
|
training: bool = False,
|
||||||
) -> Tuple[tf.Tensor, ...]:
|
) -> Tuple[tf.Tensor, ...]:
|
||||||
batch_size, dim, _ = hidden_states.shape
|
batch_size, dim, _ = shape_list(hidden_states)
|
||||||
mixed_query_layer = self.query(hidden_states)
|
mixed_query_layer = self.query(hidden_states)
|
||||||
|
|
||||||
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
||||||
@@ -533,7 +537,7 @@ class TFSwinSelfAttention(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# Apply the attention mask (precomputed for all layers in SwinModel forward() function)
|
# Apply the attention mask (precomputed for all layers in SwinModel forward() function)
|
||||||
mask_shape = attention_mask.shape[0]
|
mask_shape = shape_list(attention_mask)[0]
|
||||||
attention_scores = tf.reshape(
|
attention_scores = tf.reshape(
|
||||||
attention_scores, (batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim)
|
attention_scores, (batch_size // mask_shape, mask_shape, self.num_attention_heads, dim, dim)
|
||||||
)
|
)
|
||||||
@@ -555,7 +559,9 @@ class TFSwinSelfAttention(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
context_layer = tf.matmul(attention_probs, value_layer)
|
context_layer = tf.matmul(attention_probs, value_layer)
|
||||||
context_layer = tf.transpose(context_layer, (0, 2, 1, 3))
|
context_layer = tf.transpose(context_layer, (0, 2, 1, 3))
|
||||||
new_context_layer_shape = context_layer.shape[:-2] + (self.all_head_size,)
|
new_context_layer_shape = shape_list(context_layer)[:-2] + [
|
||||||
|
self.all_head_size,
|
||||||
|
]
|
||||||
context_layer = tf.reshape(context_layer, new_context_layer_shape)
|
context_layer = tf.reshape(context_layer, new_context_layer_shape)
|
||||||
|
|
||||||
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||||
@@ -720,7 +726,7 @@ class TFSwinLayer(tf.keras.layers.Layer):
|
|||||||
) -> tf.Tensor:
|
) -> tf.Tensor:
|
||||||
self.set_shift_and_window_size(input_dimensions)
|
self.set_shift_and_window_size(input_dimensions)
|
||||||
height, width = input_dimensions
|
height, width = input_dimensions
|
||||||
batch_size, _, channels = hidden_states.shape
|
batch_size, _, channels = shape_list(hidden_states)
|
||||||
shortcut = hidden_states
|
shortcut = hidden_states
|
||||||
|
|
||||||
hidden_states = self.layernorm_before(hidden_states, training=training)
|
hidden_states = self.layernorm_before(hidden_states, training=training)
|
||||||
@@ -728,7 +734,7 @@ class TFSwinLayer(tf.keras.layers.Layer):
|
|||||||
# pad hidden_states to multiples of window size
|
# pad hidden_states to multiples of window size
|
||||||
hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
|
hidden_states, pad_values = self.maybe_pad(hidden_states, height, width)
|
||||||
|
|
||||||
_, height_pad, width_pad, _ = hidden_states.shape
|
_, height_pad, width_pad, _ = shape_list(hidden_states)
|
||||||
# cyclic shift
|
# cyclic shift
|
||||||
if self.shift_size > 0:
|
if self.shift_size > 0:
|
||||||
shifted_hidden_states = tf.roll(hidden_states, shift=(-self.shift_size, -self.shift_size), axis=(1, 2))
|
shifted_hidden_states = tf.roll(hidden_states, shift=(-self.shift_size, -self.shift_size), axis=(1, 2))
|
||||||
@@ -881,7 +887,7 @@ class TFSwinEncoder(tf.keras.layers.Layer):
|
|||||||
all_self_attentions = () if output_attentions else None
|
all_self_attentions = () if output_attentions else None
|
||||||
|
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
batch_size, _, hidden_size = hidden_states.shape
|
batch_size, _, hidden_size = shape_list(hidden_states)
|
||||||
# rearrange b (h w) c -> b c h w
|
# rearrange b (h w) c -> b c h w
|
||||||
reshaped_hidden_state = tf.reshape(hidden_states, (batch_size, *input_dimensions, hidden_size))
|
reshaped_hidden_state = tf.reshape(hidden_states, (batch_size, *input_dimensions, hidden_size))
|
||||||
reshaped_hidden_state = tf.transpose(reshaped_hidden_state, (0, 3, 1, 2))
|
reshaped_hidden_state = tf.transpose(reshaped_hidden_state, (0, 3, 1, 2))
|
||||||
@@ -902,7 +908,7 @@ class TFSwinEncoder(tf.keras.layers.Layer):
|
|||||||
all_input_dimensions += (input_dimensions,)
|
all_input_dimensions += (input_dimensions,)
|
||||||
|
|
||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
batch_size, _, hidden_size = hidden_states.shape
|
batch_size, _, hidden_size = shape_list(hidden_states)
|
||||||
# rearrange b (h w) c -> b c h w
|
# rearrange b (h w) c -> b c h w
|
||||||
reshaped_hidden_state = tf.reshape(hidden_states, (batch_size, *input_dimensions, hidden_size))
|
reshaped_hidden_state = tf.reshape(hidden_states, (batch_size, *input_dimensions, hidden_size))
|
||||||
reshaped_hidden_state = tf.transpose(reshaped_hidden_state, (0, 3, 1, 2))
|
reshaped_hidden_state = tf.transpose(reshaped_hidden_state, (0, 3, 1, 2))
|
||||||
@@ -1152,7 +1158,7 @@ class TFSwinModel(TFSwinPreTrainedModel):
|
|||||||
|
|
||||||
pooled_output = None
|
pooled_output = None
|
||||||
if self.pooler is not None:
|
if self.pooler is not None:
|
||||||
batch_size, _, num_features = sequence_output.shape
|
batch_size, _, num_features = shape_list(sequence_output)
|
||||||
pooled_output = self.pooler(sequence_output)
|
pooled_output = self.pooler(sequence_output)
|
||||||
pooled_output = tf.reshape(pooled_output, (batch_size, num_features))
|
pooled_output = tf.reshape(pooled_output, (batch_size, num_features))
|
||||||
|
|
||||||
@@ -1206,7 +1212,7 @@ class TFSwinDecoder(tf.keras.layers.Layer):
|
|||||||
# B,C,H,W -> B,H,W,C
|
# B,C,H,W -> B,H,W,C
|
||||||
hidden_states = tf.transpose(hidden_states, (0, 2, 3, 1))
|
hidden_states = tf.transpose(hidden_states, (0, 2, 3, 1))
|
||||||
hidden_states = self.conv2d(hidden_states)
|
hidden_states = self.conv2d(hidden_states)
|
||||||
batch_size, _, _, num_input_channels = hidden_states.shape
|
batch_size, _, _, num_input_channels = shape_list(hidden_states)
|
||||||
block_size_squared = self._block_size**2
|
block_size_squared = self._block_size**2
|
||||||
output_depth = int(num_input_channels / block_size_squared)
|
output_depth = int(num_input_channels / block_size_squared)
|
||||||
# When the number of output channels >= 2, PyTorch's PixelShuffle and
|
# When the number of output channels >= 2, PyTorch's PixelShuffle and
|
||||||
@@ -1293,7 +1299,7 @@ class TFSwinForMaskedImageModeling(TFSwinPreTrainedModel):
|
|||||||
sequence_output = outputs[0]
|
sequence_output = outputs[0]
|
||||||
# Reshape to (batch_size, num_channels, height, width)
|
# Reshape to (batch_size, num_channels, height, width)
|
||||||
sequence_output = tf.transpose(sequence_output, (0, 2, 1))
|
sequence_output = tf.transpose(sequence_output, (0, 2, 1))
|
||||||
batch_size, num_channels, sequence_length = sequence_output.shape
|
batch_size, num_channels, sequence_length = shape_list(sequence_output)
|
||||||
height = width = int(sequence_length**0.5)
|
height = width = int(sequence_length**0.5)
|
||||||
sequence_output = tf.reshape(sequence_output, (batch_size, num_channels, height, width))
|
sequence_output = tf.reshape(sequence_output, (batch_size, num_channels, height, width))
|
||||||
|
|
||||||
|
|||||||
@@ -1406,6 +1406,24 @@ class TFModelTesterMixin:
|
|||||||
if metrics:
|
if metrics:
|
||||||
self.assertTrue(len(accuracy1) == len(accuracy2) > 0, "Missing metrics!")
|
self.assertTrue(len(accuracy1) == len(accuracy2) > 0, "Missing metrics!")
|
||||||
|
|
||||||
|
# Make sure fit works with tf.data.Dataset and results are consistent
|
||||||
|
dataset = tf.data.Dataset.from_tensor_slices(prepared_for_class)
|
||||||
|
# Pass in all samples as a batch to match other `fit` calls
|
||||||
|
dataset = dataset.batch(len(dataset))
|
||||||
|
history3 = model.fit(
|
||||||
|
dataset,
|
||||||
|
validation_data=dataset,
|
||||||
|
steps_per_epoch=1,
|
||||||
|
validation_steps=1,
|
||||||
|
shuffle=False,
|
||||||
|
)
|
||||||
|
val_loss3 = history3.history["val_loss"][0]
|
||||||
|
accuracy3 = {key: val[0] for key, val in history3.history.items() if key.endswith("accuracy")}
|
||||||
|
self.assertTrue(np.allclose(val_loss1, val_loss3, atol=1e-2, rtol=1e-3))
|
||||||
|
self.assertEqual(history1.history.keys(), history3.history.keys())
|
||||||
|
if metrics:
|
||||||
|
self.assertTrue(len(accuracy1) == len(accuracy3) > 0, "Missing metrics!")
|
||||||
|
|
||||||
def test_int64_inputs(self):
|
def test_int64_inputs(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
|
|||||||
Reference in New Issue
Block a user