Adding LM Head to Transfo-XL and first step to fixing problem with Adaptive Embeddings in TransfoXL (#3286)
* first commit * work in progress * make language generation task pass * update to working version for LM * delete print * remove dead code * make style
This commit is contained in:
committed by
GitHub
parent
efdb46b6e2
commit
292186a3e7
@@ -30,7 +30,7 @@ if is_tf_available():
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
|
||||
from transformers import tf_top_k_top_p_filtering
|
||||
from transformers import tf_top_k_top_p_filtering, TFAdaptiveEmbedding
|
||||
|
||||
if _tf_gpu_memory_limit is not None:
|
||||
gpus = tf.config.list_physical_devices("GPU")
|
||||
@@ -348,7 +348,7 @@ class TFModelTesterMixin:
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
|
||||
assert isinstance(model.get_input_embeddings(), (tf.keras.layers.Layer, TFAdaptiveEmbedding))
|
||||
x = model.get_output_embeddings()
|
||||
assert x is None or isinstance(x, tf.keras.layers.Layer)
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ from .utils import CACHE_DIR, require_tf, slow
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
from transformers.modeling_tf_transfo_xl import (
|
||||
from transformers import (
|
||||
TFTransfoXLModel,
|
||||
TFTransfoXLLMHeadModel,
|
||||
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
@@ -364,7 +364,7 @@ class TFTransfoXLModelLanguageGenerationTest(unittest.TestCase):
|
||||
0,
|
||||
]
|
||||
],
|
||||
dtype=tf.int31,
|
||||
dtype=tf.int32,
|
||||
)
|
||||
# In 1991 , the remains of Russian Tsar Nicholas II and his family
|
||||
# ( except for Alexei and Maria ) are discovered .
|
||||
@@ -570,8 +570,5 @@ class TFTransfoXLModelLanguageGenerationTest(unittest.TestCase):
|
||||
# Nicholas II and his family were discovered. The voice of <unk> young son,
|
||||
# Tsarevich Alexei Nikolaevich, narrates the remainder of the story.<eos>
|
||||
|
||||
# TODO: add this test when trasnfo-xl-lmhead is implemented
|
||||
with self.assertRaises(NotImplementedError):
|
||||
model.generate(input_ids, max_length=200, do_sample=False)
|
||||
print(expected_output_ids)
|
||||
# self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids) TODO: (PVP) to add when transfo-xl is implemented
|
||||
output_ids = model.generate(input_ids, max_length=200, do_sample=False)
|
||||
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
|
||||
|
||||
@@ -129,10 +129,10 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
def check_transfo_xl_model_output(self, result):
|
||||
self.parent.assertListEqual(
|
||||
list(result["hidden_states_1"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
list(result["hidden_states_1"].size()), [self.batch_size, self.seq_length, self.hidden_size],
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(result["hidden_states_2"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
list(result["hidden_states_2"].size()), [self.batch_size, self.seq_length, self.hidden_size],
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_1"]),
|
||||
@@ -166,7 +166,7 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
def check_transfo_xl_lm_head_output(self, result):
|
||||
self.parent.assertListEqual(list(result["loss_1"].size()), [self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(
|
||||
list(result["lm_logits_1"].size()), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
list(result["lm_logits_1"].size()), [self.batch_size, self.seq_length, self.vocab_size],
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_1"]),
|
||||
@@ -175,7 +175,7 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
self.parent.assertListEqual(list(result["loss_2"].size()), [self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(
|
||||
list(result["lm_logits_2"].size()), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
list(result["lm_logits_2"].size()), [self.batch_size, self.seq_length, self.vocab_size],
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_2"]),
|
||||
|
||||
Reference in New Issue
Block a user