Transformer-XL: Remove unused parameters (#7087)
* Removed 'tgt_len' and 'ext_len' from Transformer-XL * Some changes are still to be done * Removed 'tgt_len' and 'ext_len' from Transformer-XL (2) * Removed comments * Fixed quality * Changed warning to info
This commit is contained in:
@@ -88,7 +88,7 @@ def main():
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
|
model.reset_memory_length(args.mem_len)
|
||||||
if args.clamp_len > 0:
|
if args.clamp_len > 0:
|
||||||
model.clamp_len = args.clamp_len
|
model.clamp_len = args.clamp_len
|
||||||
if args.same_length:
|
if args.same_length:
|
||||||
|
|||||||
@@ -62,10 +62,6 @@ class TransfoXLConfig(PretrainedConfig):
|
|||||||
Apply LayerNorm to the input instead of the output
|
Apply LayerNorm to the input instead of the output
|
||||||
n_layer (:obj:`int`, optional, defaults to 18):
|
n_layer (:obj:`int`, optional, defaults to 18):
|
||||||
Number of hidden layers in the Transformer encoder.
|
Number of hidden layers in the Transformer encoder.
|
||||||
tgt_len (:obj:`int`, optional, defaults to 128):
|
|
||||||
Number of tokens to predict
|
|
||||||
ext_len (:obj:`int`, optional, defaults to 0):
|
|
||||||
Length of the extended context
|
|
||||||
mem_len (:obj:`int`, optional, defaults to 1600):
|
mem_len (:obj:`int`, optional, defaults to 1600):
|
||||||
Length of the retained previous heads
|
Length of the retained previous heads
|
||||||
clamp_len (:obj:`int`, optional, defaults to 1000):
|
clamp_len (:obj:`int`, optional, defaults to 1000):
|
||||||
@@ -125,8 +121,6 @@ class TransfoXLConfig(PretrainedConfig):
|
|||||||
div_val=4,
|
div_val=4,
|
||||||
pre_lnorm=False,
|
pre_lnorm=False,
|
||||||
n_layer=18,
|
n_layer=18,
|
||||||
tgt_len=128,
|
|
||||||
ext_len=0,
|
|
||||||
mem_len=1600,
|
mem_len=1600,
|
||||||
clamp_len=1000,
|
clamp_len=1000,
|
||||||
same_length=True,
|
same_length=True,
|
||||||
@@ -168,8 +162,6 @@ class TransfoXLConfig(PretrainedConfig):
|
|||||||
self.pre_lnorm = pre_lnorm
|
self.pre_lnorm = pre_lnorm
|
||||||
self.n_layer = n_layer
|
self.n_layer = n_layer
|
||||||
self.n_head = n_head
|
self.n_head = n_head
|
||||||
self.tgt_len = tgt_len
|
|
||||||
self.ext_len = ext_len
|
|
||||||
self.mem_len = mem_len
|
self.mem_len = mem_len
|
||||||
self.same_length = same_length
|
self.same_length = same_length
|
||||||
self.attn_type = attn_type
|
self.attn_type = attn_type
|
||||||
@@ -187,7 +179,9 @@ class TransfoXLConfig(PretrainedConfig):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def max_position_embeddings(self):
|
def max_position_embeddings(self):
|
||||||
return self.tgt_len + self.ext_len + self.mem_len
|
# Message copied from Transformer-XL documentation
|
||||||
|
logger.info(f"The model {self.model_type} is one of the few models that has no sequence length limit.")
|
||||||
|
return -1
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def n_token(self): # Backward compatibility
|
def n_token(self): # Backward compatibility
|
||||||
|
|||||||
@@ -15,8 +15,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" TF 2.0 Transformer XL model.
|
""" TF 2.0 Transformer XL model.
|
||||||
"""
|
"""
|
||||||
|
import warnings
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
@@ -107,10 +106,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
|
|||||||
d_model,
|
d_model,
|
||||||
d_head,
|
d_head,
|
||||||
dropout,
|
dropout,
|
||||||
dropatt=0,
|
dropatt=0.0,
|
||||||
tgt_len=None,
|
|
||||||
ext_len=None,
|
|
||||||
mem_len=None,
|
|
||||||
pre_lnorm=False,
|
pre_lnorm=False,
|
||||||
r_r_bias=None,
|
r_r_bias=None,
|
||||||
r_w_bias=None,
|
r_w_bias=None,
|
||||||
@@ -261,9 +257,6 @@ class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer):
|
|||||||
d_head,
|
d_head,
|
||||||
d_inner,
|
d_inner,
|
||||||
dropout,
|
dropout,
|
||||||
tgt_len=None,
|
|
||||||
ext_len=None,
|
|
||||||
mem_len=None,
|
|
||||||
dropatt=0.0,
|
dropatt=0.0,
|
||||||
pre_lnorm=False,
|
pre_lnorm=False,
|
||||||
r_w_bias=None,
|
r_w_bias=None,
|
||||||
@@ -280,9 +273,6 @@ class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer):
|
|||||||
d_model,
|
d_model,
|
||||||
d_head,
|
d_head,
|
||||||
dropout,
|
dropout,
|
||||||
tgt_len=tgt_len,
|
|
||||||
ext_len=ext_len,
|
|
||||||
mem_len=mem_len,
|
|
||||||
dropatt=dropatt,
|
dropatt=dropatt,
|
||||||
pre_lnorm=pre_lnorm,
|
pre_lnorm=pre_lnorm,
|
||||||
r_w_bias=r_w_bias,
|
r_w_bias=r_w_bias,
|
||||||
@@ -414,12 +404,7 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
|
|||||||
self.drop = tf.keras.layers.Dropout(config.dropout)
|
self.drop = tf.keras.layers.Dropout(config.dropout)
|
||||||
|
|
||||||
self.n_layer = config.n_layer
|
self.n_layer = config.n_layer
|
||||||
|
|
||||||
self.tgt_len = config.tgt_len
|
|
||||||
self.mem_len = config.mem_len
|
self.mem_len = config.mem_len
|
||||||
self.ext_len = config.ext_len
|
|
||||||
self.max_klen = config.tgt_len + config.ext_len + config.mem_len
|
|
||||||
|
|
||||||
self.attn_type = config.attn_type
|
self.attn_type = config.attn_type
|
||||||
|
|
||||||
self.layers = []
|
self.layers = []
|
||||||
@@ -432,9 +417,6 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
|
|||||||
config.d_head,
|
config.d_head,
|
||||||
config.d_inner,
|
config.d_inner,
|
||||||
config.dropout,
|
config.dropout,
|
||||||
tgt_len=config.tgt_len,
|
|
||||||
ext_len=config.ext_len,
|
|
||||||
mem_len=config.mem_len,
|
|
||||||
dropatt=config.dropatt,
|
dropatt=config.dropatt,
|
||||||
pre_lnorm=config.pre_lnorm,
|
pre_lnorm=config.pre_lnorm,
|
||||||
r_w_bias=None if self.untie_r else self.r_w_bias,
|
r_w_bias=None if self.untie_r else self.r_w_bias,
|
||||||
@@ -478,10 +460,8 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
|
|||||||
def backward_compatible(self):
|
def backward_compatible(self):
|
||||||
self.sample_softmax = -1
|
self.sample_softmax = -1
|
||||||
|
|
||||||
def reset_length(self, tgt_len, ext_len, mem_len):
|
def reset_memory_length(self, mem_len):
|
||||||
self.tgt_len = tgt_len
|
|
||||||
self.mem_len = mem_len
|
self.mem_len = mem_len
|
||||||
self.ext_len = ext_len
|
|
||||||
|
|
||||||
def _prune_heads(self, heads):
|
def _prune_heads(self, heads):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
@@ -506,12 +486,8 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
|
|||||||
assert len(hids) == len(mems), "len(hids) != len(mems)"
|
assert len(hids) == len(mems), "len(hids) != len(mems)"
|
||||||
|
|
||||||
# There are `mlen + qlen` steps that can be cached into mems
|
# There are `mlen + qlen` steps that can be cached into mems
|
||||||
# For the next step, the last `ext_len` of the `qlen` tokens
|
|
||||||
# will be used as the extended context. Hence, we only cache
|
|
||||||
# the tokens from `mlen + qlen - self.ext_len - self.mem_len`
|
|
||||||
# to `mlen + qlen - self.ext_len`.
|
|
||||||
new_mems = []
|
new_mems = []
|
||||||
end_idx = mlen + max(0, qlen - 0 - self.ext_len)
|
end_idx = mlen + max(0, qlen)
|
||||||
beg_idx = max(0, end_idx - self.mem_len)
|
beg_idx = max(0, end_idx - self.mem_len)
|
||||||
for i in range(len(hids)):
|
for i in range(len(hids)):
|
||||||
|
|
||||||
@@ -867,7 +843,14 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def reset_length(self, tgt_len, ext_len, mem_len):
|
def reset_length(self, tgt_len, ext_len, mem_len):
|
||||||
self.transformer.reset_length(tgt_len, ext_len, mem_len)
|
warnings.warn(
|
||||||
|
"The method `reset_length` is deprecated and will be removed in a future version, use `reset_memory_length` instead.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
self.transformer.reset_memory_length(mem_len)
|
||||||
|
|
||||||
|
def reset_memory_length(self, mem_len):
|
||||||
|
self.transformer.reset_memory_length(mem_len)
|
||||||
|
|
||||||
def init_mems(self, bsz):
|
def init_mems(self, bsz):
|
||||||
return self.transformer.init_mems(bsz)
|
return self.transformer.init_mems(bsz)
|
||||||
|
|||||||
@@ -17,8 +17,7 @@
|
|||||||
Adapted from https://github.com/kimiyoung/transformer-xl.
|
Adapted from https://github.com/kimiyoung/transformer-xl.
|
||||||
In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py
|
In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py
|
||||||
"""
|
"""
|
||||||
|
import warnings
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
@@ -234,9 +233,6 @@ class RelPartialLearnableMultiHeadAttn(nn.Module):
|
|||||||
d_head,
|
d_head,
|
||||||
dropout,
|
dropout,
|
||||||
dropatt=0,
|
dropatt=0,
|
||||||
tgt_len=None,
|
|
||||||
ext_len=None,
|
|
||||||
mem_len=None,
|
|
||||||
pre_lnorm=False,
|
pre_lnorm=False,
|
||||||
r_r_bias=None,
|
r_r_bias=None,
|
||||||
r_w_bias=None,
|
r_w_bias=None,
|
||||||
@@ -737,12 +733,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
|
|||||||
self.drop = nn.Dropout(config.dropout)
|
self.drop = nn.Dropout(config.dropout)
|
||||||
|
|
||||||
self.n_layer = config.n_layer
|
self.n_layer = config.n_layer
|
||||||
|
|
||||||
self.tgt_len = config.tgt_len
|
|
||||||
self.mem_len = config.mem_len
|
self.mem_len = config.mem_len
|
||||||
self.ext_len = config.ext_len
|
|
||||||
self.max_klen = config.tgt_len + config.ext_len + config.mem_len
|
|
||||||
|
|
||||||
self.attn_type = config.attn_type
|
self.attn_type = config.attn_type
|
||||||
|
|
||||||
if not config.untie_r:
|
if not config.untie_r:
|
||||||
@@ -759,9 +750,6 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
|
|||||||
config.d_head,
|
config.d_head,
|
||||||
config.d_inner,
|
config.d_inner,
|
||||||
config.dropout,
|
config.dropout,
|
||||||
tgt_len=config.tgt_len,
|
|
||||||
ext_len=config.ext_len,
|
|
||||||
mem_len=config.mem_len,
|
|
||||||
dropatt=config.dropatt,
|
dropatt=config.dropatt,
|
||||||
pre_lnorm=config.pre_lnorm,
|
pre_lnorm=config.pre_lnorm,
|
||||||
r_w_bias=None if config.untie_r else self.r_w_bias,
|
r_w_bias=None if config.untie_r else self.r_w_bias,
|
||||||
@@ -791,10 +779,8 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
|
|||||||
def backward_compatible(self):
|
def backward_compatible(self):
|
||||||
self.sample_softmax = -1
|
self.sample_softmax = -1
|
||||||
|
|
||||||
def reset_length(self, tgt_len, ext_len, mem_len):
|
def reset_memory_length(self, mem_len):
|
||||||
self.tgt_len = tgt_len
|
|
||||||
self.mem_len = mem_len
|
self.mem_len = mem_len
|
||||||
self.ext_len = ext_len
|
|
||||||
|
|
||||||
def _prune_heads(self, heads):
|
def _prune_heads(self, heads):
|
||||||
logger.info("Head pruning is not implemented for Transformer-XL model")
|
logger.info("Head pruning is not implemented for Transformer-XL model")
|
||||||
@@ -821,13 +807,9 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
|
|||||||
assert len(hids) == len(mems), "len(hids) != len(mems)"
|
assert len(hids) == len(mems), "len(hids) != len(mems)"
|
||||||
|
|
||||||
# There are `mlen + qlen` steps that can be cached into mems
|
# There are `mlen + qlen` steps that can be cached into mems
|
||||||
# For the next step, the last `ext_len` of the `qlen` tokens
|
|
||||||
# will be used as the extended context. Hence, we only cache
|
|
||||||
# the tokens from `mlen + qlen - self.ext_len - self.mem_len`
|
|
||||||
# to `mlen + qlen - self.ext_len`.
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
new_mems = []
|
new_mems = []
|
||||||
end_idx = mlen + max(0, qlen - 0 - self.ext_len)
|
end_idx = mlen + max(0, qlen)
|
||||||
beg_idx = max(0, end_idx - self.mem_len)
|
beg_idx = max(0, end_idx - self.mem_len)
|
||||||
for i in range(len(hids)):
|
for i in range(len(hids)):
|
||||||
|
|
||||||
@@ -1010,7 +992,14 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
|
|||||||
self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[i]
|
self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[i]
|
||||||
|
|
||||||
def reset_length(self, tgt_len, ext_len, mem_len):
|
def reset_length(self, tgt_len, ext_len, mem_len):
|
||||||
self.transformer.reset_length(tgt_len, ext_len, mem_len)
|
warnings.warn(
|
||||||
|
"The method `reset_length` is deprecated and will be removed in a future version, use `reset_memory_length` instead.",
|
||||||
|
FutureWarning,
|
||||||
|
)
|
||||||
|
self.transformer.reset_memory_length(mem_len)
|
||||||
|
|
||||||
|
def reset_memory_length(self, mem_len):
|
||||||
|
self.transformer.reset_memory_length(mem_len)
|
||||||
|
|
||||||
def init_mems(self, bsz):
|
def init_mems(self, bsz):
|
||||||
return self.transformer.init_mems(bsz)
|
return self.transformer.init_mems(bsz)
|
||||||
|
|||||||
Reference in New Issue
Block a user