Adding LM Head to Transfo-XL and first step to fixing problem with Adaptive Embeddings in TransfoXL (#3286)

* first commit

* work in progress

* make language generation task pass

* update to working version for LM

* delete print

* remove dead code

* make style
This commit is contained in:
Patrick von Platen
2020-03-18 14:24:27 +01:00
committed by GitHub
parent efdb46b6e2
commit 292186a3e7
7 changed files with 79 additions and 147 deletions

View File

@@ -26,7 +26,7 @@ from .utils import CACHE_DIR, require_tf, slow
if is_tf_available():
import tensorflow as tf
from transformers.modeling_tf_transfo_xl import (
from transformers import (
TFTransfoXLModel,
TFTransfoXLLMHeadModel,
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
@@ -364,7 +364,7 @@ class TFTransfoXLModelLanguageGenerationTest(unittest.TestCase):
0,
]
],
dtype=tf.int31,
dtype=tf.int32,
)
# In 1991 , the remains of Russian Tsar Nicholas II and his family
# ( except for Alexei and Maria ) are discovered .
@@ -570,8 +570,5 @@ class TFTransfoXLModelLanguageGenerationTest(unittest.TestCase):
# Nicholas II and his family were discovered. The voice of <unk> young son,
# Tsarevich Alexei Nikolaevich, narrates the remainder of the story.<eos>
# TODO: add this test when trasnfo-xl-lmhead is implemented
with self.assertRaises(NotImplementedError):
model.generate(input_ids, max_length=200, do_sample=False)
print(expected_output_ids)
# self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids) TODO: (PVP) to add when transfo-xl is implemented
output_ids = model.generate(input_ids, max_length=200, do_sample=False)
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)