Adding LM Head to Transfo-XL and first step to fixing problem with Adaptive Embeddings in TransfoXL (#3286)
* first commit * work in progress * make language generation task pass * update to working version for LM * delete print * remove dead code * make style
This commit is contained in:
committed by
GitHub
parent
efdb46b6e2
commit
292186a3e7
@@ -26,7 +26,7 @@ from .utils import CACHE_DIR, require_tf, slow
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
from transformers.modeling_tf_transfo_xl import (
|
||||
from transformers import (
|
||||
TFTransfoXLModel,
|
||||
TFTransfoXLLMHeadModel,
|
||||
TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
@@ -364,7 +364,7 @@ class TFTransfoXLModelLanguageGenerationTest(unittest.TestCase):
|
||||
0,
|
||||
]
|
||||
],
|
||||
dtype=tf.int31,
|
||||
dtype=tf.int32,
|
||||
)
|
||||
# In 1991 , the remains of Russian Tsar Nicholas II and his family
|
||||
# ( except for Alexei and Maria ) are discovered .
|
||||
@@ -570,8 +570,5 @@ class TFTransfoXLModelLanguageGenerationTest(unittest.TestCase):
|
||||
# Nicholas II and his family were discovered. The voice of <unk> young son,
|
||||
# Tsarevich Alexei Nikolaevich, narrates the remainder of the story.<eos>
|
||||
|
||||
# TODO: add this test when trasnfo-xl-lmhead is implemented
|
||||
with self.assertRaises(NotImplementedError):
|
||||
model.generate(input_ids, max_length=200, do_sample=False)
|
||||
print(expected_output_ids)
|
||||
# self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids) TODO: (PVP) to add when transfo-xl is implemented
|
||||
output_ids = model.generate(input_ids, max_length=200, do_sample=False)
|
||||
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
|
||||
|
||||
Reference in New Issue
Block a user