Correct NATTEN function signatures and force new version (#22298)
This commit is contained in:
2
setup.py
2
setup.py
@@ -129,7 +129,7 @@ _deps = [
|
|||||||
"keras-nlp>=0.3.1",
|
"keras-nlp>=0.3.1",
|
||||||
"librosa",
|
"librosa",
|
||||||
"nltk",
|
"nltk",
|
||||||
"natten>=0.14.5",
|
"natten>=0.14.6",
|
||||||
"numpy>=1.17",
|
"numpy>=1.17",
|
||||||
"onnxconverter-common",
|
"onnxconverter-common",
|
||||||
"onnxruntime-tools>=1.4.2",
|
"onnxruntime-tools>=1.4.2",
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ deps = {
|
|||||||
"keras-nlp": "keras-nlp>=0.3.1",
|
"keras-nlp": "keras-nlp>=0.3.1",
|
||||||
"librosa": "librosa",
|
"librosa": "librosa",
|
||||||
"nltk": "nltk",
|
"nltk": "nltk",
|
||||||
"natten": "natten>=0.14.5",
|
"natten": "natten>=0.14.6",
|
||||||
"numpy": "numpy>=1.17",
|
"numpy": "numpy>=1.17",
|
||||||
"onnxconverter-common": "onnxconverter-common",
|
"onnxconverter-common": "onnxconverter-common",
|
||||||
"onnxruntime-tools": "onnxruntime-tools>=1.4.2",
|
"onnxruntime-tools": "onnxruntime-tools>=1.4.2",
|
||||||
|
|||||||
@@ -356,7 +356,7 @@ class NeighborhoodAttention(nn.Module):
|
|||||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||||
attention_probs = self.dropout(attention_probs)
|
attention_probs = self.dropout(attention_probs)
|
||||||
|
|
||||||
context_layer = natten2dav(attention_probs, value_layer, self.dilation)
|
context_layer = natten2dav(attention_probs, value_layer, self.kernel_size, self.dilation)
|
||||||
context_layer = context_layer.permute(0, 2, 3, 1, 4).contiguous()
|
context_layer = context_layer.permute(0, 2, 3, 1, 4).contiguous()
|
||||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||||
context_layer = context_layer.view(new_context_layer_shape)
|
context_layer = context_layer.view(new_context_layer_shape)
|
||||||
|
|||||||
@@ -348,7 +348,7 @@ class NeighborhoodAttention(nn.Module):
|
|||||||
# seem a bit unusual, but is taken from the original Transformer paper.
|
# seem a bit unusual, but is taken from the original Transformer paper.
|
||||||
attention_probs = self.dropout(attention_probs)
|
attention_probs = self.dropout(attention_probs)
|
||||||
|
|
||||||
context_layer = natten2dav(attention_probs, value_layer, 1)
|
context_layer = natten2dav(attention_probs, value_layer, self.kernel_size, 1)
|
||||||
context_layer = context_layer.permute(0, 2, 3, 1, 4).contiguous()
|
context_layer = context_layer.permute(0, 2, 3, 1, 4).contiguous()
|
||||||
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
||||||
context_layer = context_layer.view(new_context_layer_shape)
|
context_layer = context_layer.view(new_context_layer_shape)
|
||||||
|
|||||||
Reference in New Issue
Block a user