updated all tests

This commit is contained in:
patrickvonplaten
2020-03-08 15:29:10 +01:00
parent e03129ad44
commit 575976144a
13 changed files with 1465 additions and 112 deletions

View File

@@ -311,3 +311,34 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
for model_name in list(TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = TFXLMModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
class TFXLMModelLanguageGenerationTest(unittest.TestCase):
@slow
def test_lm_generate_xlm_mlm_en_2048(self):
model = TFXLMWithLMHeadModel.from_pretrained("xlm-mlm-en-2048")
input_ids = tf.convert_to_tensor([[1, 14, 2232, 26, 1]], dtype=tf.int32) # the dog is cute
expected_output_ids = [
1,
14,
2232,
26,
1,
567,
26,
32,
149,
149,
149,
149,
149,
149,
149,
149,
149,
149,
149,
149,
] # The dog is nothing is it!!!!!!!!!!!! TODO (PVP): this sentence (and others I tried) does not make much sense, there seems to be a problem with xlm language generation.
output_ids = model.generate(input_ids)
self.assertListEqual(output_ids[0].tolist(), expected_output_ids, do_sample=False)