Add slow generate tests for pretrained lm models (#2909)

* add slow generate lm_model tests

* fix conflicts

* merge conflicts

* fix conflicts

* add slow generate lm_model tests

* make style

* delete unused variable

* fix conflicts

* fix conflicts

* fix conflicts

* delete unused variable

* fix conflicts

* finished hard coded tests
This commit is contained in:
Patrick von Platen
2020-02-24 17:51:57 +01:00
committed by GitHub
parent 8194df8e0c
commit 17c45c39ed
8 changed files with 991 additions and 6 deletions

View File

@@ -24,6 +24,7 @@ from .utils import CACHE_DIR, require_torch, slow, torch_device
if is_torch_available():
import torch
from transformers import (
OpenAIGPTConfig,
OpenAIGPTModel,
@@ -208,3 +209,36 @@ class OpenAIGPTModelTest(ModelTesterMixin, unittest.TestCase):
for model_name in list(OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
model = OpenAIGPTModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
self.assertIsNotNone(model)
class OPENAIGPTModelLanguageGenerationTest(unittest.TestCase):
@slow
def test_lm_generate_openai_gpt(self):
model = OpenAIGPTLMHeadModel.from_pretrained("openai-gpt")
input_ids = torch.Tensor([[481, 2585, 544, 4957]]).long() # The dog is cute
expected_output_ids = [
481,
2585,
544,
4957,
669,
512,
761,
5990,
271,
645,
487,
535,
976,
2479,
240,
487,
804,
1296,
2891,
512,
] # the dog is cute when you're annoyed : if he's really stupid, he 'll stop fighting you
torch.manual_seed(0)
output_ids = model.generate(input_ids)
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)