Transformer-XL: Remove unused parameters (#7087)
* Removed 'tgt_len' and 'ext_len' from Transformer-XL * Some changes are still to be done * Removed 'tgt_len' and 'ext_len' from Transformer-XL (2) * Removed comments * Fixed quality * Changed warning to info
This commit is contained in:
@@ -88,7 +88,7 @@ def main():
|
||||
)
|
||||
)
|
||||
|
||||
model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
|
||||
model.reset_memory_length(args.mem_len)
|
||||
if args.clamp_len > 0:
|
||||
model.clamp_len = args.clamp_len
|
||||
if args.same_length:
|
||||
|
||||
@@ -62,10 +62,6 @@ class TransfoXLConfig(PretrainedConfig):
|
||||
Apply LayerNorm to the input instead of the output
|
||||
n_layer (:obj:`int`, optional, defaults to 18):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
tgt_len (:obj:`int`, optional, defaults to 128):
|
||||
Number of tokens to predict
|
||||
ext_len (:obj:`int`, optional, defaults to 0):
|
||||
Length of the extended context
|
||||
mem_len (:obj:`int`, optional, defaults to 1600):
|
||||
Length of the retained previous heads
|
||||
clamp_len (:obj:`int`, optional, defaults to 1000):
|
||||
@@ -125,8 +121,6 @@ class TransfoXLConfig(PretrainedConfig):
|
||||
div_val=4,
|
||||
pre_lnorm=False,
|
||||
n_layer=18,
|
||||
tgt_len=128,
|
||||
ext_len=0,
|
||||
mem_len=1600,
|
||||
clamp_len=1000,
|
||||
same_length=True,
|
||||
@@ -168,8 +162,6 @@ class TransfoXLConfig(PretrainedConfig):
|
||||
self.pre_lnorm = pre_lnorm
|
||||
self.n_layer = n_layer
|
||||
self.n_head = n_head
|
||||
self.tgt_len = tgt_len
|
||||
self.ext_len = ext_len
|
||||
self.mem_len = mem_len
|
||||
self.same_length = same_length
|
||||
self.attn_type = attn_type
|
||||
@@ -187,7 +179,9 @@ class TransfoXLConfig(PretrainedConfig):
|
||||
|
||||
@property
|
||||
def max_position_embeddings(self):
|
||||
return self.tgt_len + self.ext_len + self.mem_len
|
||||
# Message copied from Transformer-XL documentation
|
||||
logger.info(f"The model {self.model_type} is one of the few models that has no sequence length limit.")
|
||||
return -1
|
||||
|
||||
@property
|
||||
def n_token(self): # Backward compatibility
|
||||
|
||||
@@ -15,8 +15,7 @@
|
||||
# limitations under the License.
|
||||
""" TF 2.0 Transformer XL model.
|
||||
"""
|
||||
|
||||
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
@@ -107,10 +106,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
|
||||
d_model,
|
||||
d_head,
|
||||
dropout,
|
||||
dropatt=0,
|
||||
tgt_len=None,
|
||||
ext_len=None,
|
||||
mem_len=None,
|
||||
dropatt=0.0,
|
||||
pre_lnorm=False,
|
||||
r_r_bias=None,
|
||||
r_w_bias=None,
|
||||
@@ -261,9 +257,6 @@ class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer):
|
||||
d_head,
|
||||
d_inner,
|
||||
dropout,
|
||||
tgt_len=None,
|
||||
ext_len=None,
|
||||
mem_len=None,
|
||||
dropatt=0.0,
|
||||
pre_lnorm=False,
|
||||
r_w_bias=None,
|
||||
@@ -280,9 +273,6 @@ class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer):
|
||||
d_model,
|
||||
d_head,
|
||||
dropout,
|
||||
tgt_len=tgt_len,
|
||||
ext_len=ext_len,
|
||||
mem_len=mem_len,
|
||||
dropatt=dropatt,
|
||||
pre_lnorm=pre_lnorm,
|
||||
r_w_bias=r_w_bias,
|
||||
@@ -414,12 +404,7 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
|
||||
self.drop = tf.keras.layers.Dropout(config.dropout)
|
||||
|
||||
self.n_layer = config.n_layer
|
||||
|
||||
self.tgt_len = config.tgt_len
|
||||
self.mem_len = config.mem_len
|
||||
self.ext_len = config.ext_len
|
||||
self.max_klen = config.tgt_len + config.ext_len + config.mem_len
|
||||
|
||||
self.attn_type = config.attn_type
|
||||
|
||||
self.layers = []
|
||||
@@ -432,9 +417,6 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
|
||||
config.d_head,
|
||||
config.d_inner,
|
||||
config.dropout,
|
||||
tgt_len=config.tgt_len,
|
||||
ext_len=config.ext_len,
|
||||
mem_len=config.mem_len,
|
||||
dropatt=config.dropatt,
|
||||
pre_lnorm=config.pre_lnorm,
|
||||
r_w_bias=None if self.untie_r else self.r_w_bias,
|
||||
@@ -478,10 +460,8 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
|
||||
def backward_compatible(self):
|
||||
self.sample_softmax = -1
|
||||
|
||||
def reset_length(self, tgt_len, ext_len, mem_len):
|
||||
self.tgt_len = tgt_len
|
||||
def reset_memory_length(self, mem_len):
|
||||
self.mem_len = mem_len
|
||||
self.ext_len = ext_len
|
||||
|
||||
def _prune_heads(self, heads):
|
||||
raise NotImplementedError
|
||||
@@ -506,12 +486,8 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
|
||||
assert len(hids) == len(mems), "len(hids) != len(mems)"
|
||||
|
||||
# There are `mlen + qlen` steps that can be cached into mems
|
||||
# For the next step, the last `ext_len` of the `qlen` tokens
|
||||
# will be used as the extended context. Hence, we only cache
|
||||
# the tokens from `mlen + qlen - self.ext_len - self.mem_len`
|
||||
# to `mlen + qlen - self.ext_len`.
|
||||
new_mems = []
|
||||
end_idx = mlen + max(0, qlen - 0 - self.ext_len)
|
||||
end_idx = mlen + max(0, qlen)
|
||||
beg_idx = max(0, end_idx - self.mem_len)
|
||||
for i in range(len(hids)):
|
||||
|
||||
@@ -867,7 +843,14 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
|
||||
return None
|
||||
|
||||
def reset_length(self, tgt_len, ext_len, mem_len):
|
||||
self.transformer.reset_length(tgt_len, ext_len, mem_len)
|
||||
warnings.warn(
|
||||
"The method `reset_length` is deprecated and will be removed in a future version, use `reset_memory_length` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
self.transformer.reset_memory_length(mem_len)
|
||||
|
||||
def reset_memory_length(self, mem_len):
|
||||
self.transformer.reset_memory_length(mem_len)
|
||||
|
||||
def init_mems(self, bsz):
|
||||
return self.transformer.init_mems(bsz)
|
||||
|
||||
@@ -17,8 +17,7 @@
|
||||
Adapted from https://github.com/kimiyoung/transformer-xl.
|
||||
In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py
|
||||
"""
|
||||
|
||||
|
||||
import warnings
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
@@ -234,9 +233,6 @@ class RelPartialLearnableMultiHeadAttn(nn.Module):
|
||||
d_head,
|
||||
dropout,
|
||||
dropatt=0,
|
||||
tgt_len=None,
|
||||
ext_len=None,
|
||||
mem_len=None,
|
||||
pre_lnorm=False,
|
||||
r_r_bias=None,
|
||||
r_w_bias=None,
|
||||
@@ -737,12 +733,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
|
||||
self.drop = nn.Dropout(config.dropout)
|
||||
|
||||
self.n_layer = config.n_layer
|
||||
|
||||
self.tgt_len = config.tgt_len
|
||||
self.mem_len = config.mem_len
|
||||
self.ext_len = config.ext_len
|
||||
self.max_klen = config.tgt_len + config.ext_len + config.mem_len
|
||||
|
||||
self.attn_type = config.attn_type
|
||||
|
||||
if not config.untie_r:
|
||||
@@ -759,9 +750,6 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
|
||||
config.d_head,
|
||||
config.d_inner,
|
||||
config.dropout,
|
||||
tgt_len=config.tgt_len,
|
||||
ext_len=config.ext_len,
|
||||
mem_len=config.mem_len,
|
||||
dropatt=config.dropatt,
|
||||
pre_lnorm=config.pre_lnorm,
|
||||
r_w_bias=None if config.untie_r else self.r_w_bias,
|
||||
@@ -791,10 +779,8 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
|
||||
def backward_compatible(self):
|
||||
self.sample_softmax = -1
|
||||
|
||||
def reset_length(self, tgt_len, ext_len, mem_len):
|
||||
self.tgt_len = tgt_len
|
||||
def reset_memory_length(self, mem_len):
|
||||
self.mem_len = mem_len
|
||||
self.ext_len = ext_len
|
||||
|
||||
def _prune_heads(self, heads):
|
||||
logger.info("Head pruning is not implemented for Transformer-XL model")
|
||||
@@ -821,13 +807,9 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
|
||||
assert len(hids) == len(mems), "len(hids) != len(mems)"
|
||||
|
||||
# There are `mlen + qlen` steps that can be cached into mems
|
||||
# For the next step, the last `ext_len` of the `qlen` tokens
|
||||
# will be used as the extended context. Hence, we only cache
|
||||
# the tokens from `mlen + qlen - self.ext_len - self.mem_len`
|
||||
# to `mlen + qlen - self.ext_len`.
|
||||
with torch.no_grad():
|
||||
new_mems = []
|
||||
end_idx = mlen + max(0, qlen - 0 - self.ext_len)
|
||||
end_idx = mlen + max(0, qlen)
|
||||
beg_idx = max(0, end_idx - self.mem_len)
|
||||
for i in range(len(hids)):
|
||||
|
||||
@@ -1010,7 +992,14 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
|
||||
self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[i]
|
||||
|
||||
def reset_length(self, tgt_len, ext_len, mem_len):
|
||||
self.transformer.reset_length(tgt_len, ext_len, mem_len)
|
||||
warnings.warn(
|
||||
"The method `reset_length` is deprecated and will be removed in a future version, use `reset_memory_length` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
self.transformer.reset_memory_length(mem_len)
|
||||
|
||||
def reset_memory_length(self, mem_len):
|
||||
self.transformer.reset_memory_length(mem_len)
|
||||
|
||||
def init_mems(self, bsz):
|
||||
return self.transformer.init_mems(bsz)
|
||||
|
||||
Reference in New Issue
Block a user