From 709745927b119a0414a167a10e363d9a9ef1ef38 Mon Sep 17 00:00:00 2001 From: RafaelWO <38643099+RafaelWO@users.noreply.github.com> Date: Thu, 17 Sep 2020 12:10:34 +0200 Subject: [PATCH] Transformer-XL: Remove unused parameters (#7087) * Removed 'tgt_len' and 'ext_len' from Transformer-XL * Some changes are still to be done * Removed 'tgt_len' and 'ext_len' from Transformer-XL (2) * Removed comments * Fixed quality * Changed warning to info --- examples/contrib/run_transfo_xl.py | 2 +- src/transformers/configuration_transfo_xl.py | 12 ++---- src/transformers/modeling_tf_transfo_xl.py | 41 ++++++-------------- src/transformers/modeling_transfo_xl.py | 33 ++++++---------- 4 files changed, 27 insertions(+), 61 deletions(-) diff --git a/examples/contrib/run_transfo_xl.py b/examples/contrib/run_transfo_xl.py index a28637c596..db3375a20a 100644 --- a/examples/contrib/run_transfo_xl.py +++ b/examples/contrib/run_transfo_xl.py @@ -88,7 +88,7 @@ def main(): ) ) - model.reset_length(args.tgt_len, args.ext_len, args.mem_len) + model.reset_memory_length(args.mem_len) if args.clamp_len > 0: model.clamp_len = args.clamp_len if args.same_length: diff --git a/src/transformers/configuration_transfo_xl.py b/src/transformers/configuration_transfo_xl.py index 4fbf599fe1..4864ec80ca 100644 --- a/src/transformers/configuration_transfo_xl.py +++ b/src/transformers/configuration_transfo_xl.py @@ -62,10 +62,6 @@ class TransfoXLConfig(PretrainedConfig): Apply LayerNorm to the input instead of the output n_layer (:obj:`int`, optional, defaults to 18): Number of hidden layers in the Transformer encoder. 
- tgt_len (:obj:`int`, optional, defaults to 128): - Number of tokens to predict - ext_len (:obj:`int`, optional, defaults to 0): - Length of the extended context mem_len (:obj:`int`, optional, defaults to 1600): Length of the retained previous heads clamp_len (:obj:`int`, optional, defaults to 1000): @@ -125,8 +121,6 @@ class TransfoXLConfig(PretrainedConfig): div_val=4, pre_lnorm=False, n_layer=18, - tgt_len=128, - ext_len=0, mem_len=1600, clamp_len=1000, same_length=True, @@ -168,8 +162,6 @@ class TransfoXLConfig(PretrainedConfig): self.pre_lnorm = pre_lnorm self.n_layer = n_layer self.n_head = n_head - self.tgt_len = tgt_len - self.ext_len = ext_len self.mem_len = mem_len self.same_length = same_length self.attn_type = attn_type @@ -187,7 +179,9 @@ class TransfoXLConfig(PretrainedConfig): @property def max_position_embeddings(self): - return self.tgt_len + self.ext_len + self.mem_len + # Message copied from Transformer-XL documentation + logger.info(f"The model {self.model_type} is one of the few models that has no sequence length limit.") + return -1 @property def n_token(self): # Backward compatibility diff --git a/src/transformers/modeling_tf_transfo_xl.py b/src/transformers/modeling_tf_transfo_xl.py index 0b3330518b..68c0e85fb3 100644 --- a/src/transformers/modeling_tf_transfo_xl.py +++ b/src/transformers/modeling_tf_transfo_xl.py @@ -15,8 +15,7 @@ # limitations under the License. """ TF 2.0 Transformer XL model. 
""" - - +import warnings from dataclasses import dataclass from typing import List, Optional, Tuple @@ -107,10 +106,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer): d_model, d_head, dropout, - dropatt=0, - tgt_len=None, - ext_len=None, - mem_len=None, + dropatt=0.0, pre_lnorm=False, r_r_bias=None, r_w_bias=None, @@ -261,9 +257,6 @@ class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer): d_head, d_inner, dropout, - tgt_len=None, - ext_len=None, - mem_len=None, dropatt=0.0, pre_lnorm=False, r_w_bias=None, @@ -280,9 +273,6 @@ class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer): d_model, d_head, dropout, - tgt_len=tgt_len, - ext_len=ext_len, - mem_len=mem_len, dropatt=dropatt, pre_lnorm=pre_lnorm, r_w_bias=r_w_bias, @@ -414,12 +404,7 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer): self.drop = tf.keras.layers.Dropout(config.dropout) self.n_layer = config.n_layer - - self.tgt_len = config.tgt_len self.mem_len = config.mem_len - self.ext_len = config.ext_len - self.max_klen = config.tgt_len + config.ext_len + config.mem_len - self.attn_type = config.attn_type self.layers = [] @@ -432,9 +417,6 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer): config.d_head, config.d_inner, config.dropout, - tgt_len=config.tgt_len, - ext_len=config.ext_len, - mem_len=config.mem_len, dropatt=config.dropatt, pre_lnorm=config.pre_lnorm, r_w_bias=None if self.untie_r else self.r_w_bias, @@ -478,10 +460,8 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer): def backward_compatible(self): self.sample_softmax = -1 - def reset_length(self, tgt_len, ext_len, mem_len): - self.tgt_len = tgt_len + def reset_memory_length(self, mem_len): self.mem_len = mem_len - self.ext_len = ext_len def _prune_heads(self, heads): raise NotImplementedError @@ -506,12 +486,8 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer): assert len(hids) == len(mems), "len(hids) != len(mems)" # There are `mlen + qlen` steps that can be cached into mems - # For the next step, the 
last `ext_len` of the `qlen` tokens - # will be used as the extended context. Hence, we only cache - # the tokens from `mlen + qlen - self.ext_len - self.mem_len` - # to `mlen + qlen - self.ext_len`. new_mems = [] - end_idx = mlen + max(0, qlen - 0 - self.ext_len) + end_idx = mlen + max(0, qlen) beg_idx = max(0, end_idx - self.mem_len) for i in range(len(hids)): @@ -867,7 +843,14 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel): return None def reset_length(self, tgt_len, ext_len, mem_len): - self.transformer.reset_length(tgt_len, ext_len, mem_len) + warnings.warn( + "The method `reset_length` is deprecated and will be removed in a future version, use `reset_memory_length` instead.", + FutureWarning, + ) + self.transformer.reset_memory_length(mem_len) + + def reset_memory_length(self, mem_len): + self.transformer.reset_memory_length(mem_len) def init_mems(self, bsz): return self.transformer.init_mems(bsz) diff --git a/src/transformers/modeling_transfo_xl.py b/src/transformers/modeling_transfo_xl.py index 26cbb59274..83c253be54 100644 --- a/src/transformers/modeling_transfo_xl.py +++ b/src/transformers/modeling_transfo_xl.py @@ -17,8 +17,7 @@ Adapted from https://github.com/kimiyoung/transformer-xl. 
In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py """ - - +import warnings from dataclasses import dataclass from typing import List, Optional, Tuple @@ -234,9 +233,6 @@ class RelPartialLearnableMultiHeadAttn(nn.Module): d_head, dropout, dropatt=0, - tgt_len=None, - ext_len=None, - mem_len=None, pre_lnorm=False, r_r_bias=None, r_w_bias=None, @@ -737,12 +733,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel): self.drop = nn.Dropout(config.dropout) self.n_layer = config.n_layer - - self.tgt_len = config.tgt_len self.mem_len = config.mem_len - self.ext_len = config.ext_len - self.max_klen = config.tgt_len + config.ext_len + config.mem_len - self.attn_type = config.attn_type if not config.untie_r: @@ -759,9 +750,6 @@ class TransfoXLModel(TransfoXLPreTrainedModel): config.d_head, config.d_inner, config.dropout, - tgt_len=config.tgt_len, - ext_len=config.ext_len, - mem_len=config.mem_len, dropatt=config.dropatt, pre_lnorm=config.pre_lnorm, r_w_bias=None if config.untie_r else self.r_w_bias, @@ -791,10 +779,8 @@ class TransfoXLModel(TransfoXLPreTrainedModel): def backward_compatible(self): self.sample_softmax = -1 - def reset_length(self, tgt_len, ext_len, mem_len): - self.tgt_len = tgt_len + def reset_memory_length(self, mem_len): self.mem_len = mem_len - self.ext_len = ext_len def _prune_heads(self, heads): logger.info("Head pruning is not implemented for Transformer-XL model") @@ -821,13 +807,9 @@ class TransfoXLModel(TransfoXLPreTrainedModel): assert len(hids) == len(mems), "len(hids) != len(mems)" # There are `mlen + qlen` steps that can be cached into mems - # For the next step, the last `ext_len` of the `qlen` tokens - # will be used as the extended context. Hence, we only cache - # the tokens from `mlen + qlen - self.ext_len - self.mem_len` - # to `mlen + qlen - self.ext_len`. 
with torch.no_grad(): new_mems = [] - end_idx = mlen + max(0, qlen - 0 - self.ext_len) + end_idx = mlen + max(0, qlen) beg_idx = max(0, end_idx - self.mem_len) for i in range(len(hids)): @@ -1010,7 +992,14 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel): self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[i] def reset_length(self, tgt_len, ext_len, mem_len): - self.transformer.reset_length(tgt_len, ext_len, mem_len) + warnings.warn( + "The method `reset_length` is deprecated and will be removed in a future version, use `reset_memory_length` instead.", + FutureWarning, + ) + self.transformer.reset_memory_length(mem_len) + + def reset_memory_length(self, mem_len): + self.transformer.reset_memory_length(mem_len) def init_mems(self, bsz): return self.transformer.init_mems(bsz)