Add slow generate tests for pretrained lm models (#2909)
* add slow generate lm_model tests * fix conflicts * merge conflicts * fix conflicts * add slow generate lm_model tests * make style * delete unused variable * fix conflicts * fix conflicts * fix conflicts * delete unused variable * fix conflicts * finished hard coded tests
This commit is contained in:
committed by
GitHub
parent
8194df8e0c
commit
17c45c39ed
@@ -512,7 +512,7 @@ class XLMModel(XLMPreTrainedModel):
|
|||||||
inputs_embeds = self.embeddings(input_ids)
|
inputs_embeds = self.embeddings(input_ids)
|
||||||
|
|
||||||
tensor = inputs_embeds + self.position_embeddings(position_ids).expand_as(inputs_embeds)
|
tensor = inputs_embeds + self.position_embeddings(position_ids).expand_as(inputs_embeds)
|
||||||
if langs is not None and self.use_lang_emb:
|
if langs is not None and self.use_lang_emb and self.n_langs > 1:
|
||||||
tensor = tensor + self.lang_embeddings(langs)
|
tensor = tensor + self.lang_embeddings(langs)
|
||||||
if token_type_ids is not None:
|
if token_type_ids is not None:
|
||||||
tensor = tensor + self.embeddings(token_type_ids)
|
tensor = tensor + self.embeddings(token_type_ids)
|
||||||
|
|||||||
@@ -641,7 +641,7 @@ global_rng = random.Random()
|
|||||||
|
|
||||||
|
|
||||||
def ids_tensor(shape, vocab_size, rng=None, name=None):
|
def ids_tensor(shape, vocab_size, rng=None, name=None):
|
||||||
"""Creates a random int32 tensor of the shape within the vocab size."""
|
# Creates a random int32 tensor of the shape within the vocab size
|
||||||
if rng is None:
|
if rng is None:
|
||||||
rng = global_rng
|
rng = global_rng
|
||||||
|
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from .utils import CACHE_DIR, require_torch, slow, torch_device
|
|||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
from transformers import CTRLConfig, CTRLModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, CTRLLMHeadModel
|
from transformers import CTRLConfig, CTRLModel, CTRL_PRETRAINED_MODEL_ARCHIVE_MAP, CTRLLMHeadModel
|
||||||
|
|
||||||
|
|
||||||
@@ -212,3 +213,36 @@ class CTRLModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
for model_name in list(CTRL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
for model_name in list(CTRL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||||
model = CTRLModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
model = CTRLModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
|
||||||
|
class CTRLModelLanguageGenerationTest(unittest.TestCase):
|
||||||
|
@slow
|
||||||
|
def test_lm_generate_ctrl(self):
|
||||||
|
model = CTRLLMHeadModel.from_pretrained("ctrl")
|
||||||
|
input_ids = torch.Tensor([[11859, 586, 20984, 8]]).long() # Legal My neighbor is
|
||||||
|
expected_output_ids = [
|
||||||
|
11859,
|
||||||
|
586,
|
||||||
|
20984,
|
||||||
|
8,
|
||||||
|
13391,
|
||||||
|
3,
|
||||||
|
980,
|
||||||
|
8258,
|
||||||
|
72,
|
||||||
|
327,
|
||||||
|
148,
|
||||||
|
2,
|
||||||
|
53,
|
||||||
|
29,
|
||||||
|
226,
|
||||||
|
3,
|
||||||
|
780,
|
||||||
|
49,
|
||||||
|
3,
|
||||||
|
980,
|
||||||
|
] # Legal My neighbor is refusing to pay rent after 2 years and we are having to force him to pay
|
||||||
|
torch.manual_seed(0)
|
||||||
|
|
||||||
|
output_ids = model.generate(input_ids)
|
||||||
|
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from .utils import CACHE_DIR, require_torch, slow, torch_device
|
|||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
from transformers import (
|
from transformers import (
|
||||||
GPT2Config,
|
GPT2Config,
|
||||||
GPT2Model,
|
GPT2Model,
|
||||||
@@ -165,7 +166,7 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
"presents": presents,
|
"presents": presents,
|
||||||
}
|
}
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size],
|
||||||
)
|
)
|
||||||
self.parent.assertEqual(len(result["presents"]), config.n_layer)
|
self.parent.assertEqual(len(result["presents"]), config.n_layer)
|
||||||
|
|
||||||
@@ -180,7 +181,7 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
self.parent.assertListEqual(list(result["loss"].size()), [])
|
self.parent.assertListEqual(list(result["loss"].size()), [])
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(result["lm_logits"].size()), [self.batch_size, self.seq_length, self.vocab_size]
|
list(result["lm_logits"].size()), [self.batch_size, self.seq_length, self.vocab_size],
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_and_check_double_lm_head_model(
|
def create_and_check_double_lm_head_model(
|
||||||
@@ -208,7 +209,8 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
|
|
||||||
self.parent.assertListEqual(list(result["loss"].size()), [])
|
self.parent.assertListEqual(list(result["loss"].size()), [])
|
||||||
self.parent.assertListEqual(
|
self.parent.assertListEqual(
|
||||||
list(result["lm_logits"].size()), [self.batch_size, self.num_choices, self.seq_length, self.vocab_size]
|
list(result["lm_logits"].size()),
|
||||||
|
[self.batch_size, self.num_choices, self.seq_length, self.vocab_size],
|
||||||
)
|
)
|
||||||
self.parent.assertListEqual(list(result["mc_logits"].size()), [self.batch_size, self.num_choices])
|
self.parent.assertListEqual(list(result["mc_logits"].size()), [self.batch_size, self.num_choices])
|
||||||
|
|
||||||
@@ -227,7 +229,11 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
choice_labels,
|
choice_labels,
|
||||||
) = config_and_inputs
|
) = config_and_inputs
|
||||||
|
|
||||||
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "head_mask": head_mask}
|
inputs_dict = {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"token_type_ids": token_type_ids,
|
||||||
|
"head_mask": head_mask,
|
||||||
|
}
|
||||||
|
|
||||||
return config, inputs_dict
|
return config, inputs_dict
|
||||||
|
|
||||||
@@ -255,3 +261,84 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
for model_name in list(GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
for model_name in list(GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||||
model = GPT2Model.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
model = GPT2Model.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_generation_special_tokens():
|
||||||
|
return {"bos_token_id": 50256, "eos_token_id": 50256}
|
||||||
|
|
||||||
|
|
||||||
|
class GPT2ModelLanguageGenerationTest(unittest.TestCase):
|
||||||
|
|
||||||
|
special_tokens = prepare_generation_special_tokens()
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_lm_generate_gpt2(self):
|
||||||
|
model = GPT2LMHeadModel.from_pretrained("gpt2")
|
||||||
|
input_ids = torch.Tensor([[464, 3290, 318, 13779]]).long() # The dog is cute
|
||||||
|
expected_output_ids = [
|
||||||
|
464,
|
||||||
|
3290,
|
||||||
|
318,
|
||||||
|
13779,
|
||||||
|
1165,
|
||||||
|
13,
|
||||||
|
632,
|
||||||
|
7832,
|
||||||
|
284,
|
||||||
|
6437,
|
||||||
|
319,
|
||||||
|
502,
|
||||||
|
290,
|
||||||
|
318,
|
||||||
|
922,
|
||||||
|
329,
|
||||||
|
502,
|
||||||
|
357,
|
||||||
|
1169,
|
||||||
|
3290,
|
||||||
|
] # The dog is cute too. It likes to rub on me and is good for me (the dog
|
||||||
|
torch.manual_seed(0)
|
||||||
|
|
||||||
|
output_ids = model.generate(
|
||||||
|
input_ids,
|
||||||
|
bos_token_id=self.special_tokens["bos_token_id"],
|
||||||
|
eos_token_ids=self.special_tokens["eos_token_id"],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_lm_generate_distilgpt2(self):
|
||||||
|
model = GPT2LMHeadModel.from_pretrained("distilgpt2")
|
||||||
|
input_ids = torch.Tensor([[464, 3290, 318, 13779]]).long() # The dog is cute
|
||||||
|
expected_output_ids = [
|
||||||
|
464,
|
||||||
|
3290,
|
||||||
|
318,
|
||||||
|
13779,
|
||||||
|
996,
|
||||||
|
339,
|
||||||
|
460,
|
||||||
|
3360,
|
||||||
|
655,
|
||||||
|
2513,
|
||||||
|
287,
|
||||||
|
262,
|
||||||
|
3952,
|
||||||
|
13,
|
||||||
|
632,
|
||||||
|
318,
|
||||||
|
407,
|
||||||
|
845,
|
||||||
|
3621,
|
||||||
|
284,
|
||||||
|
] # The dog is cute though he can sometimes just walk in the park. It is not very nice to
|
||||||
|
torch.manual_seed(0)
|
||||||
|
|
||||||
|
output_ids = model.generate(
|
||||||
|
input_ids,
|
||||||
|
bos_token_id=self.special_tokens["bos_token_id"],
|
||||||
|
eos_token_ids=self.special_tokens["eos_token_id"],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from .utils import CACHE_DIR, require_torch, slow, torch_device
|
|||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
from transformers import (
|
from transformers import (
|
||||||
OpenAIGPTConfig,
|
OpenAIGPTConfig,
|
||||||
OpenAIGPTModel,
|
OpenAIGPTModel,
|
||||||
@@ -208,3 +209,36 @@ class OpenAIGPTModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
for model_name in list(OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
for model_name in list(OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||||
model = OpenAIGPTModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
model = OpenAIGPTModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
|
||||||
|
class OPENAIGPTModelLanguageGenerationTest(unittest.TestCase):
|
||||||
|
@slow
|
||||||
|
def test_lm_generate_openai_gpt(self):
|
||||||
|
model = OpenAIGPTLMHeadModel.from_pretrained("openai-gpt")
|
||||||
|
input_ids = torch.Tensor([[481, 2585, 544, 4957]]).long() # The dog is cute
|
||||||
|
expected_output_ids = [
|
||||||
|
481,
|
||||||
|
2585,
|
||||||
|
544,
|
||||||
|
4957,
|
||||||
|
669,
|
||||||
|
512,
|
||||||
|
761,
|
||||||
|
5990,
|
||||||
|
271,
|
||||||
|
645,
|
||||||
|
487,
|
||||||
|
535,
|
||||||
|
976,
|
||||||
|
2479,
|
||||||
|
240,
|
||||||
|
487,
|
||||||
|
804,
|
||||||
|
1296,
|
||||||
|
2891,
|
||||||
|
512,
|
||||||
|
] # the dog is cute when you're annoyed : if he's really stupid, he 'll stop fighting you
|
||||||
|
torch.manual_seed(0)
|
||||||
|
|
||||||
|
output_ids = model.generate(input_ids)
|
||||||
|
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
|
||||||
|
|||||||
@@ -212,3 +212,372 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
for model_name in list(TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
for model_name in list(TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||||
model = TransfoXLModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
model = TransfoXLModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_generation_special_tokens():
|
||||||
|
return {"eos_token_id": 0}
|
||||||
|
|
||||||
|
|
||||||
|
class TransfoXLModelLanguageGenerationTest(unittest.TestCase):
|
||||||
|
|
||||||
|
special_tokens = prepare_generation_special_tokens()
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_lm_generate_transfo_xl_wt103(self):
|
||||||
|
model = TransfoXLLMHeadModel.from_pretrained("transfo-xl-wt103")
|
||||||
|
input_ids = torch.Tensor(
|
||||||
|
[
|
||||||
|
[
|
||||||
|
33,
|
||||||
|
1297,
|
||||||
|
2,
|
||||||
|
1,
|
||||||
|
1009,
|
||||||
|
4,
|
||||||
|
1109,
|
||||||
|
11739,
|
||||||
|
4762,
|
||||||
|
358,
|
||||||
|
5,
|
||||||
|
25,
|
||||||
|
245,
|
||||||
|
22,
|
||||||
|
1706,
|
||||||
|
17,
|
||||||
|
20098,
|
||||||
|
5,
|
||||||
|
3215,
|
||||||
|
21,
|
||||||
|
37,
|
||||||
|
1110,
|
||||||
|
3,
|
||||||
|
13,
|
||||||
|
1041,
|
||||||
|
4,
|
||||||
|
24,
|
||||||
|
603,
|
||||||
|
490,
|
||||||
|
2,
|
||||||
|
71477,
|
||||||
|
20098,
|
||||||
|
104447,
|
||||||
|
2,
|
||||||
|
20961,
|
||||||
|
1,
|
||||||
|
2604,
|
||||||
|
4,
|
||||||
|
1,
|
||||||
|
329,
|
||||||
|
3,
|
||||||
|
6224,
|
||||||
|
831,
|
||||||
|
16002,
|
||||||
|
2,
|
||||||
|
8,
|
||||||
|
603,
|
||||||
|
78967,
|
||||||
|
29546,
|
||||||
|
23,
|
||||||
|
803,
|
||||||
|
20,
|
||||||
|
25,
|
||||||
|
416,
|
||||||
|
5,
|
||||||
|
8,
|
||||||
|
232,
|
||||||
|
4,
|
||||||
|
277,
|
||||||
|
6,
|
||||||
|
1855,
|
||||||
|
4601,
|
||||||
|
3,
|
||||||
|
29546,
|
||||||
|
54,
|
||||||
|
8,
|
||||||
|
3609,
|
||||||
|
5,
|
||||||
|
57211,
|
||||||
|
49,
|
||||||
|
4,
|
||||||
|
1,
|
||||||
|
277,
|
||||||
|
18,
|
||||||
|
8,
|
||||||
|
1755,
|
||||||
|
15691,
|
||||||
|
3,
|
||||||
|
341,
|
||||||
|
25,
|
||||||
|
416,
|
||||||
|
693,
|
||||||
|
42573,
|
||||||
|
71,
|
||||||
|
17,
|
||||||
|
401,
|
||||||
|
94,
|
||||||
|
31,
|
||||||
|
17919,
|
||||||
|
2,
|
||||||
|
29546,
|
||||||
|
7873,
|
||||||
|
18,
|
||||||
|
1,
|
||||||
|
435,
|
||||||
|
23,
|
||||||
|
11011,
|
||||||
|
755,
|
||||||
|
5,
|
||||||
|
5167,
|
||||||
|
3,
|
||||||
|
7983,
|
||||||
|
98,
|
||||||
|
84,
|
||||||
|
2,
|
||||||
|
29546,
|
||||||
|
3267,
|
||||||
|
8,
|
||||||
|
3609,
|
||||||
|
4,
|
||||||
|
1,
|
||||||
|
4865,
|
||||||
|
1075,
|
||||||
|
2,
|
||||||
|
6087,
|
||||||
|
71,
|
||||||
|
6,
|
||||||
|
346,
|
||||||
|
8,
|
||||||
|
5854,
|
||||||
|
3,
|
||||||
|
29546,
|
||||||
|
824,
|
||||||
|
1400,
|
||||||
|
1868,
|
||||||
|
2,
|
||||||
|
19,
|
||||||
|
160,
|
||||||
|
2,
|
||||||
|
311,
|
||||||
|
8,
|
||||||
|
5496,
|
||||||
|
2,
|
||||||
|
20920,
|
||||||
|
17,
|
||||||
|
25,
|
||||||
|
15097,
|
||||||
|
3,
|
||||||
|
24,
|
||||||
|
24,
|
||||||
|
0,
|
||||||
|
]
|
||||||
|
]
|
||||||
|
).long()
|
||||||
|
# In 1991 , the remains of Russian Tsar Nicholas II and his family
|
||||||
|
# ( except for Alexei and Maria ) are discovered .
|
||||||
|
# The voice of Nicholas's young son , Tsarevich Alexei Nikolaevich , narrates the
|
||||||
|
# remainder of the story . 1883 Western Siberia ,
|
||||||
|
# a young Grigori Rasputin is asked by his father and a group of men to perform magic .
|
||||||
|
# Rasputin has a vision and denounces one of the men as a horse thief . Although his
|
||||||
|
# father initially slaps him for making such an accusation , Rasputin watches as the
|
||||||
|
# man is chased outside and beaten . Twenty years later , Rasputin sees a vision of
|
||||||
|
# the Virgin Mary , prompting him to become a priest . Rasputin quickly becomes famous ,
|
||||||
|
# with people , even a bishop , begging for his blessing . <eod> </s> <eos>
|
||||||
|
|
||||||
|
expected_output_ids = [
|
||||||
|
33,
|
||||||
|
1297,
|
||||||
|
2,
|
||||||
|
1,
|
||||||
|
1009,
|
||||||
|
4,
|
||||||
|
1109,
|
||||||
|
11739,
|
||||||
|
4762,
|
||||||
|
358,
|
||||||
|
5,
|
||||||
|
25,
|
||||||
|
245,
|
||||||
|
22,
|
||||||
|
1706,
|
||||||
|
17,
|
||||||
|
20098,
|
||||||
|
5,
|
||||||
|
3215,
|
||||||
|
21,
|
||||||
|
37,
|
||||||
|
1110,
|
||||||
|
3,
|
||||||
|
13,
|
||||||
|
1041,
|
||||||
|
4,
|
||||||
|
24,
|
||||||
|
603,
|
||||||
|
490,
|
||||||
|
2,
|
||||||
|
71477,
|
||||||
|
20098,
|
||||||
|
104447,
|
||||||
|
2,
|
||||||
|
20961,
|
||||||
|
1,
|
||||||
|
2604,
|
||||||
|
4,
|
||||||
|
1,
|
||||||
|
329,
|
||||||
|
3,
|
||||||
|
6224,
|
||||||
|
831,
|
||||||
|
16002,
|
||||||
|
2,
|
||||||
|
8,
|
||||||
|
603,
|
||||||
|
78967,
|
||||||
|
29546,
|
||||||
|
23,
|
||||||
|
803,
|
||||||
|
20,
|
||||||
|
25,
|
||||||
|
416,
|
||||||
|
5,
|
||||||
|
8,
|
||||||
|
232,
|
||||||
|
4,
|
||||||
|
277,
|
||||||
|
6,
|
||||||
|
1855,
|
||||||
|
4601,
|
||||||
|
3,
|
||||||
|
29546,
|
||||||
|
54,
|
||||||
|
8,
|
||||||
|
3609,
|
||||||
|
5,
|
||||||
|
57211,
|
||||||
|
49,
|
||||||
|
4,
|
||||||
|
1,
|
||||||
|
277,
|
||||||
|
18,
|
||||||
|
8,
|
||||||
|
1755,
|
||||||
|
15691,
|
||||||
|
3,
|
||||||
|
341,
|
||||||
|
25,
|
||||||
|
416,
|
||||||
|
693,
|
||||||
|
42573,
|
||||||
|
71,
|
||||||
|
17,
|
||||||
|
401,
|
||||||
|
94,
|
||||||
|
31,
|
||||||
|
17919,
|
||||||
|
2,
|
||||||
|
29546,
|
||||||
|
7873,
|
||||||
|
18,
|
||||||
|
1,
|
||||||
|
435,
|
||||||
|
23,
|
||||||
|
11011,
|
||||||
|
755,
|
||||||
|
5,
|
||||||
|
5167,
|
||||||
|
3,
|
||||||
|
7983,
|
||||||
|
98,
|
||||||
|
84,
|
||||||
|
2,
|
||||||
|
29546,
|
||||||
|
3267,
|
||||||
|
8,
|
||||||
|
3609,
|
||||||
|
4,
|
||||||
|
1,
|
||||||
|
4865,
|
||||||
|
1075,
|
||||||
|
2,
|
||||||
|
6087,
|
||||||
|
71,
|
||||||
|
6,
|
||||||
|
346,
|
||||||
|
8,
|
||||||
|
5854,
|
||||||
|
3,
|
||||||
|
29546,
|
||||||
|
824,
|
||||||
|
1400,
|
||||||
|
1868,
|
||||||
|
2,
|
||||||
|
19,
|
||||||
|
160,
|
||||||
|
2,
|
||||||
|
311,
|
||||||
|
8,
|
||||||
|
5496,
|
||||||
|
2,
|
||||||
|
20920,
|
||||||
|
17,
|
||||||
|
25,
|
||||||
|
15097,
|
||||||
|
3,
|
||||||
|
24,
|
||||||
|
24,
|
||||||
|
0,
|
||||||
|
29546,
|
||||||
|
40,
|
||||||
|
1092,
|
||||||
|
18,
|
||||||
|
8,
|
||||||
|
5854,
|
||||||
|
7,
|
||||||
|
1143,
|
||||||
|
2,
|
||||||
|
7,
|
||||||
|
1,
|
||||||
|
159,
|
||||||
|
99,
|
||||||
|
16,
|
||||||
|
1,
|
||||||
|
1009,
|
||||||
|
4,
|
||||||
|
1109,
|
||||||
|
11739,
|
||||||
|
4762,
|
||||||
|
358,
|
||||||
|
5,
|
||||||
|
25,
|
||||||
|
245,
|
||||||
|
28,
|
||||||
|
1110,
|
||||||
|
3,
|
||||||
|
57,
|
||||||
|
629,
|
||||||
|
38,
|
||||||
|
3493,
|
||||||
|
47,
|
||||||
|
1094,
|
||||||
|
7,
|
||||||
|
1297,
|
||||||
|
3,
|
||||||
|
0,
|
||||||
|
]
|
||||||
|
# In 1991, the remains of Russian Tsar Nicholas II and his family (
|
||||||
|
# except for Alexei and Maria ) are discovered. The voice of young son,
|
||||||
|
# Tsarevich Alexei Nikolaevich, narrates the remainder of the story.
|
||||||
|
# 1883 Western Siberia, a young Grigori Rasputin is asked by his father
|
||||||
|
# and a group of men to perform magic. Rasputin has a vision and
|
||||||
|
# denounces one of the men as a horse thief. Although his father initially
|
||||||
|
# slaps him for making such an accusation, Rasputin watches as the man
|
||||||
|
# is chased outside and beaten. Twenty years later, Rasputin sees a vision
|
||||||
|
# of the Virgin Mary, prompting him to become a priest.
|
||||||
|
# Rasputin quickly becomes famous, with people, even a bishop, begging for
|
||||||
|
# his blessing. Rasputin first appears as a priest in 1996, in the same year
|
||||||
|
# that the remains of Russian Tsar Nicholas II and his family were discovered. H
|
||||||
|
|
||||||
|
torch.manual_seed(0)
|
||||||
|
|
||||||
|
output_ids = model.generate(input_ids, eos_token_ids=self.special_tokens["eos_token_id"], max_length=200)
|
||||||
|
|
||||||
|
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from .utils import CACHE_DIR, require_torch, slow, torch_device
|
|||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
from transformers import (
|
from transformers import (
|
||||||
XLMConfig,
|
XLMConfig,
|
||||||
XLMModel,
|
XLMModel,
|
||||||
@@ -396,3 +397,48 @@ class XLMModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
for model_name in list(XLM_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
for model_name in list(XLM_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||||
model = XLMModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
model = XLMModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_generation_special_tokens():
|
||||||
|
return {"bos_token_id": 0, "pad_token_id": 2}
|
||||||
|
|
||||||
|
|
||||||
|
class XLMModelLanguageGenerationTest(unittest.TestCase):
|
||||||
|
|
||||||
|
special_tokens = prepare_generation_special_tokens()
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_lm_generate_xlm_mlm_en_2048(self):
|
||||||
|
model = XLMWithLMHeadModel.from_pretrained("xlm-mlm-en-2048")
|
||||||
|
input_ids = torch.Tensor([[1, 14, 2232, 26, 1]]).long() # The dog is cute
|
||||||
|
expected_output_ids = [
|
||||||
|
1,
|
||||||
|
14,
|
||||||
|
2232,
|
||||||
|
26,
|
||||||
|
1,
|
||||||
|
567,
|
||||||
|
26,
|
||||||
|
32,
|
||||||
|
149,
|
||||||
|
149,
|
||||||
|
149,
|
||||||
|
149,
|
||||||
|
149,
|
||||||
|
149,
|
||||||
|
149,
|
||||||
|
149,
|
||||||
|
149,
|
||||||
|
149,
|
||||||
|
149,
|
||||||
|
149,
|
||||||
|
] # The dog is nothing is it!!!!!!!!!!!! TODO (PVP): this sentence (and others I tried) does not make much sense, there seems to be a problem with xlm language generation.
|
||||||
|
torch.manual_seed(0)
|
||||||
|
|
||||||
|
output_ids = model.generate(
|
||||||
|
input_ids,
|
||||||
|
bos_token_id=self.special_tokens["bos_token_id"],
|
||||||
|
pad_token_id=self.special_tokens["pad_token_id"],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
|
||||||
|
|||||||
@@ -511,3 +511,418 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
for model_name in list(XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
for model_name in list(XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||||
model = XLNetModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
model = XLNetModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_generation_special_tokens():
|
||||||
|
return {"bos_token_id": 1, "pad_token_id": 5, "eos_token_id": 2}
|
||||||
|
|
||||||
|
|
||||||
|
class XLNetModelLanguageGenerationTest(unittest.TestCase):
|
||||||
|
|
||||||
|
special_tokens = prepare_generation_special_tokens()
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_lm_generate_xlnet_base_cased(self):
|
||||||
|
model = XLNetLMHeadModel.from_pretrained("xlnet-base-cased")
|
||||||
|
input_ids = torch.Tensor(
|
||||||
|
[
|
||||||
|
[
|
||||||
|
67,
|
||||||
|
2840,
|
||||||
|
19,
|
||||||
|
18,
|
||||||
|
1484,
|
||||||
|
20,
|
||||||
|
965,
|
||||||
|
29077,
|
||||||
|
8719,
|
||||||
|
1273,
|
||||||
|
21,
|
||||||
|
45,
|
||||||
|
273,
|
||||||
|
17,
|
||||||
|
10,
|
||||||
|
15048,
|
||||||
|
28,
|
||||||
|
27511,
|
||||||
|
21,
|
||||||
|
4185,
|
||||||
|
11,
|
||||||
|
41,
|
||||||
|
2444,
|
||||||
|
9,
|
||||||
|
32,
|
||||||
|
1025,
|
||||||
|
20,
|
||||||
|
8719,
|
||||||
|
26,
|
||||||
|
23,
|
||||||
|
673,
|
||||||
|
966,
|
||||||
|
19,
|
||||||
|
29077,
|
||||||
|
20643,
|
||||||
|
27511,
|
||||||
|
20822,
|
||||||
|
20643,
|
||||||
|
19,
|
||||||
|
17,
|
||||||
|
6616,
|
||||||
|
17511,
|
||||||
|
18,
|
||||||
|
8978,
|
||||||
|
20,
|
||||||
|
18,
|
||||||
|
777,
|
||||||
|
9,
|
||||||
|
19233,
|
||||||
|
1527,
|
||||||
|
17669,
|
||||||
|
19,
|
||||||
|
24,
|
||||||
|
673,
|
||||||
|
17,
|
||||||
|
28756,
|
||||||
|
150,
|
||||||
|
12943,
|
||||||
|
4354,
|
||||||
|
153,
|
||||||
|
27,
|
||||||
|
442,
|
||||||
|
37,
|
||||||
|
45,
|
||||||
|
668,
|
||||||
|
21,
|
||||||
|
24,
|
||||||
|
256,
|
||||||
|
20,
|
||||||
|
416,
|
||||||
|
22,
|
||||||
|
2771,
|
||||||
|
4901,
|
||||||
|
9,
|
||||||
|
12943,
|
||||||
|
4354,
|
||||||
|
153,
|
||||||
|
51,
|
||||||
|
24,
|
||||||
|
3004,
|
||||||
|
21,
|
||||||
|
28142,
|
||||||
|
23,
|
||||||
|
65,
|
||||||
|
20,
|
||||||
|
18,
|
||||||
|
416,
|
||||||
|
34,
|
||||||
|
24,
|
||||||
|
2958,
|
||||||
|
22947,
|
||||||
|
9,
|
||||||
|
1177,
|
||||||
|
45,
|
||||||
|
668,
|
||||||
|
3097,
|
||||||
|
13768,
|
||||||
|
23,
|
||||||
|
103,
|
||||||
|
28,
|
||||||
|
441,
|
||||||
|
148,
|
||||||
|
48,
|
||||||
|
20522,
|
||||||
|
19,
|
||||||
|
12943,
|
||||||
|
4354,
|
||||||
|
153,
|
||||||
|
12860,
|
||||||
|
34,
|
||||||
|
18,
|
||||||
|
326,
|
||||||
|
27,
|
||||||
|
17492,
|
||||||
|
684,
|
||||||
|
21,
|
||||||
|
6709,
|
||||||
|
9,
|
||||||
|
8585,
|
||||||
|
123,
|
||||||
|
266,
|
||||||
|
19,
|
||||||
|
12943,
|
||||||
|
4354,
|
||||||
|
153,
|
||||||
|
6872,
|
||||||
|
24,
|
||||||
|
3004,
|
||||||
|
20,
|
||||||
|
18,
|
||||||
|
9225,
|
||||||
|
2198,
|
||||||
|
19,
|
||||||
|
12717,
|
||||||
|
103,
|
||||||
|
22,
|
||||||
|
401,
|
||||||
|
24,
|
||||||
|
6348,
|
||||||
|
9,
|
||||||
|
12943,
|
||||||
|
4354,
|
||||||
|
153,
|
||||||
|
1068,
|
||||||
|
2768,
|
||||||
|
2286,
|
||||||
|
19,
|
||||||
|
33,
|
||||||
|
104,
|
||||||
|
19,
|
||||||
|
176,
|
||||||
|
24,
|
||||||
|
9313,
|
||||||
|
19,
|
||||||
|
20086,
|
||||||
|
28,
|
||||||
|
45,
|
||||||
|
10292,
|
||||||
|
9,
|
||||||
|
4,
|
||||||
|
3,
|
||||||
|
]
|
||||||
|
]
|
||||||
|
).long()
|
||||||
|
# In 1991, the remains of Russian Tsar Nicholas II and his family
|
||||||
|
# (except for Alexei and Maria) are discovered.
|
||||||
|
# The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
|
||||||
|
# remainder of the story. 1883 Western Siberia,
|
||||||
|
# a young Grigori Rasputin is asked by his father and a group of men to perform magic.
|
||||||
|
# Rasputin has a vision and denounces one of the men as a horse thief. Although his
|
||||||
|
# father initially slaps him for making such an accusation, Rasputin watches as the
|
||||||
|
# man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
|
||||||
|
# the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
|
||||||
|
# with people, even a bishop, begging for his blessing. """
|
||||||
|
|
||||||
|
expected_output_ids = [
|
||||||
|
67,
|
||||||
|
2840,
|
||||||
|
19,
|
||||||
|
18,
|
||||||
|
1484,
|
||||||
|
20,
|
||||||
|
965,
|
||||||
|
29077,
|
||||||
|
8719,
|
||||||
|
1273,
|
||||||
|
21,
|
||||||
|
45,
|
||||||
|
273,
|
||||||
|
17,
|
||||||
|
10,
|
||||||
|
15048,
|
||||||
|
28,
|
||||||
|
27511,
|
||||||
|
21,
|
||||||
|
4185,
|
||||||
|
11,
|
||||||
|
41,
|
||||||
|
2444,
|
||||||
|
9,
|
||||||
|
32,
|
||||||
|
1025,
|
||||||
|
20,
|
||||||
|
8719,
|
||||||
|
26,
|
||||||
|
23,
|
||||||
|
673,
|
||||||
|
966,
|
||||||
|
19,
|
||||||
|
29077,
|
||||||
|
20643,
|
||||||
|
27511,
|
||||||
|
20822,
|
||||||
|
20643,
|
||||||
|
19,
|
||||||
|
17,
|
||||||
|
6616,
|
||||||
|
17511,
|
||||||
|
18,
|
||||||
|
8978,
|
||||||
|
20,
|
||||||
|
18,
|
||||||
|
777,
|
||||||
|
9,
|
||||||
|
19233,
|
||||||
|
1527,
|
||||||
|
17669,
|
||||||
|
19,
|
||||||
|
24,
|
||||||
|
673,
|
||||||
|
17,
|
||||||
|
28756,
|
||||||
|
150,
|
||||||
|
12943,
|
||||||
|
4354,
|
||||||
|
153,
|
||||||
|
27,
|
||||||
|
442,
|
||||||
|
37,
|
||||||
|
45,
|
||||||
|
668,
|
||||||
|
21,
|
||||||
|
24,
|
||||||
|
256,
|
||||||
|
20,
|
||||||
|
416,
|
||||||
|
22,
|
||||||
|
2771,
|
||||||
|
4901,
|
||||||
|
9,
|
||||||
|
12943,
|
||||||
|
4354,
|
||||||
|
153,
|
||||||
|
51,
|
||||||
|
24,
|
||||||
|
3004,
|
||||||
|
21,
|
||||||
|
28142,
|
||||||
|
23,
|
||||||
|
65,
|
||||||
|
20,
|
||||||
|
18,
|
||||||
|
416,
|
||||||
|
34,
|
||||||
|
24,
|
||||||
|
2958,
|
||||||
|
22947,
|
||||||
|
9,
|
||||||
|
1177,
|
||||||
|
45,
|
||||||
|
668,
|
||||||
|
3097,
|
||||||
|
13768,
|
||||||
|
23,
|
||||||
|
103,
|
||||||
|
28,
|
||||||
|
441,
|
||||||
|
148,
|
||||||
|
48,
|
||||||
|
20522,
|
||||||
|
19,
|
||||||
|
12943,
|
||||||
|
4354,
|
||||||
|
153,
|
||||||
|
12860,
|
||||||
|
34,
|
||||||
|
18,
|
||||||
|
326,
|
||||||
|
27,
|
||||||
|
17492,
|
||||||
|
684,
|
||||||
|
21,
|
||||||
|
6709,
|
||||||
|
9,
|
||||||
|
8585,
|
||||||
|
123,
|
||||||
|
266,
|
||||||
|
19,
|
||||||
|
12943,
|
||||||
|
4354,
|
||||||
|
153,
|
||||||
|
6872,
|
||||||
|
24,
|
||||||
|
3004,
|
||||||
|
20,
|
||||||
|
18,
|
||||||
|
9225,
|
||||||
|
2198,
|
||||||
|
19,
|
||||||
|
12717,
|
||||||
|
103,
|
||||||
|
22,
|
||||||
|
401,
|
||||||
|
24,
|
||||||
|
6348,
|
||||||
|
9,
|
||||||
|
12943,
|
||||||
|
4354,
|
||||||
|
153,
|
||||||
|
1068,
|
||||||
|
2768,
|
||||||
|
2286,
|
||||||
|
19,
|
||||||
|
33,
|
||||||
|
104,
|
||||||
|
19,
|
||||||
|
176,
|
||||||
|
24,
|
||||||
|
9313,
|
||||||
|
19,
|
||||||
|
20086,
|
||||||
|
28,
|
||||||
|
45,
|
||||||
|
10292,
|
||||||
|
9,
|
||||||
|
4,
|
||||||
|
3,
|
||||||
|
1722,
|
||||||
|
19,
|
||||||
|
24,
|
||||||
|
6348,
|
||||||
|
61,
|
||||||
|
977,
|
||||||
|
176,
|
||||||
|
1772,
|
||||||
|
33,
|
||||||
|
45,
|
||||||
|
970,
|
||||||
|
19,
|
||||||
|
4185,
|
||||||
|
19,
|
||||||
|
27,
|
||||||
|
442,
|
||||||
|
22,
|
||||||
|
2771,
|
||||||
|
4901,
|
||||||
|
25,
|
||||||
|
18,
|
||||||
|
2059,
|
||||||
|
20,
|
||||||
|
24,
|
||||||
|
303,
|
||||||
|
1775,
|
||||||
|
691,
|
||||||
|
9,
|
||||||
|
1147,
|
||||||
|
19,
|
||||||
|
634,
|
||||||
|
19,
|
||||||
|
43,
|
||||||
|
51,
|
||||||
|
54,
|
||||||
|
6157,
|
||||||
|
2999,
|
||||||
|
33,
|
||||||
|
4185,
|
||||||
|
]
|
||||||
|
# In 1991, the remains of Russian Tsar Nicholas II and his family (except for Alexei and Maria)
|
||||||
|
# are discovered. The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich,
|
||||||
|
# narrates the remainder of the story. 1883 Western Siberia, a young Grigori Rasputin
|
||||||
|
# is asked by his father and a group of men to perform magic. Rasputin has a vision and
|
||||||
|
# denounces one of the men as a horse thief. Although his father initially slaps
|
||||||
|
# him for making such an accusation, Rasputin watches as the man is chased outside and beaten.
|
||||||
|
# Twenty years later, Rasputin sees a vision of the Virgin Mary, prompting him to become a priest.
|
||||||
|
# Rasputin quickly becomes famous, with people, even a bishop, begging for his blessing.
|
||||||
|
# 1990, a priest who cannot even walk with his wife, Maria, is asked to perform magic
|
||||||
|
# in the presence of a local religious leader.
|
||||||
|
# Since, however, he has had difficulty walking with Maria
|
||||||
|
|
||||||
|
torch.manual_seed(0)
|
||||||
|
output_ids = model.generate(
|
||||||
|
input_ids,
|
||||||
|
bos_token_id=self.special_tokens["bos_token_id"],
|
||||||
|
pad_token_id=self.special_tokens["pad_token_id"],
|
||||||
|
eos_token_ids=self.special_tokens["eos_token_id"],
|
||||||
|
max_length=200,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
|
||||||
|
|||||||
Reference in New Issue
Block a user