updated all tests

This commit is contained in:
patrickvonplaten
2020-03-08 15:29:10 +01:00
parent e03129ad44
commit 575976144a
13 changed files with 1465 additions and 112 deletions

View File

@@ -238,3 +238,35 @@ class TFOpenAIGPTModelTest(TFModelTesterMixin, unittest.TestCase):
for model_name in list(TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = TFOpenAIGPTModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
class TFOPENAIGPTModelLanguageGenerationTest(unittest.TestCase):
@slow
def test_lm_generate_openai_gpt(self):
model = TFOpenAIGPTLMHeadModel.from_pretrained("openai-gpt")
input_ids = tf.convert_to_tensor([[481, 4735, 544]], dtype=tf.int32) # the president is
expected_output_ids = [
481,
4735,
544,
246,
963,
870,
762,
239,
244,
40477,
244,
249,
719,
881,
487,
544,
240,
244,
603,
481,
] # the president is a very good man. " \n " i\'m sure he is, " said the
output_ids = model.generate(input_ids, do_sample=False)
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)