Merge pull request #3191 from patrickvonplaten/add_integration_tests_lm_generate_torch_tf
Add integration tests lm generate torch tf
This commit is contained in:
@@ -408,7 +408,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
|
||||
inputs_embeds = self.embeddings(input_ids)
|
||||
|
||||
tensor = inputs_embeds + self.position_embeddings(position_ids)
|
||||
if langs is not None and self.use_lang_emb:
|
||||
if langs is not None and self.use_lang_emb and self.n_langs > 1:
|
||||
tensor = tensor + self.lang_embeddings(langs)
|
||||
if token_type_ids is not None:
|
||||
tensor = tensor + self.embeddings(token_type_ids)
|
||||
|
||||
@@ -219,30 +219,31 @@ class CTRLModelLanguageGenerationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_lm_generate_ctrl(self):
|
||||
model = CTRLLMHeadModel.from_pretrained("ctrl")
|
||||
input_ids = torch.Tensor([[11859, 586, 20984, 8]]).long() # Legal My neighbor is
|
||||
input_ids = torch.tensor(
|
||||
[[11859, 0, 1611, 8]], dtype=torch.long, device=torch_device
|
||||
) # Legal the president is
|
||||
expected_output_ids = [
|
||||
11859,
|
||||
586,
|
||||
20984,
|
||||
0,
|
||||
1611,
|
||||
8,
|
||||
13391,
|
||||
3,
|
||||
980,
|
||||
8258,
|
||||
72,
|
||||
327,
|
||||
148,
|
||||
5,
|
||||
150,
|
||||
26449,
|
||||
2,
|
||||
53,
|
||||
29,
|
||||
226,
|
||||
19,
|
||||
348,
|
||||
469,
|
||||
3,
|
||||
780,
|
||||
49,
|
||||
3,
|
||||
980,
|
||||
] # Legal My neighbor is refusing to pay rent after 2 years and we are having to force him to pay
|
||||
torch.manual_seed(0)
|
||||
2595,
|
||||
48,
|
||||
20740,
|
||||
246533,
|
||||
246533,
|
||||
19,
|
||||
30,
|
||||
5,
|
||||
] # Legal the president is a good guy and I don't want to lose my job. \n \n I have a
|
||||
|
||||
output_ids = model.generate(input_ids)
|
||||
output_ids = model.generate(input_ids, do_sample=False)
|
||||
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
|
||||
|
||||
@@ -223,7 +223,7 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
# append to next input_ids and attn_mask
|
||||
next_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
|
||||
attn_mask = torch.cat(
|
||||
[attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)], dim=1
|
||||
[attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)], dim=1,
|
||||
)
|
||||
|
||||
# get two different outputs
|
||||
@@ -343,39 +343,36 @@ class GPT2ModelLanguageGenerationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_lm_generate_gpt2(self):
|
||||
model = GPT2LMHeadModel.from_pretrained("gpt2")
|
||||
input_ids = torch.Tensor([[464, 3290, 318, 13779]]).long() # The dog is cute
|
||||
input_ids = torch.tensor([[464, 3290]], dtype=torch.long, device=torch_device) # The dog
|
||||
expected_output_ids = [
|
||||
464,
|
||||
3290,
|
||||
318,
|
||||
13779,
|
||||
1165,
|
||||
13,
|
||||
632,
|
||||
7832,
|
||||
284,
|
||||
6437,
|
||||
319,
|
||||
502,
|
||||
373,
|
||||
1043,
|
||||
287,
|
||||
257,
|
||||
2214,
|
||||
1474,
|
||||
262,
|
||||
16246,
|
||||
286,
|
||||
2688,
|
||||
290,
|
||||
318,
|
||||
922,
|
||||
329,
|
||||
502,
|
||||
357,
|
||||
1169,
|
||||
2688,
|
||||
27262,
|
||||
13,
|
||||
198,
|
||||
198,
|
||||
464,
|
||||
3290,
|
||||
] # The dog is cute too. It likes to rub on me and is good for me (the dog
|
||||
torch.manual_seed(0)
|
||||
|
||||
output_ids = model.generate(input_ids)
|
||||
|
||||
] # The dog was found in a field near the intersection of West and West Streets.\n\nThe dog
|
||||
output_ids = model.generate(input_ids, do_sample=False)
|
||||
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
|
||||
|
||||
@slow
|
||||
def test_lm_generate_distilgpt2(self):
|
||||
model = GPT2LMHeadModel.from_pretrained("distilgpt2")
|
||||
input_ids = torch.Tensor([[464, 1893]]).long() # The president
|
||||
input_ids = torch.tensor([[464, 1893]], dtype=torch.long, device=torch_device) # The president
|
||||
expected_output_ids = [
|
||||
464,
|
||||
1893,
|
||||
|
||||
@@ -123,7 +123,15 @@ class OpenAIGPTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
|
||||
|
||||
return config, input_ids, head_mask, token_type_ids, sequence_labels, token_labels, choice_labels
|
||||
return (
|
||||
config,
|
||||
input_ids,
|
||||
head_mask,
|
||||
token_type_ids,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
)
|
||||
|
||||
def check_loss_output(self, result):
|
||||
self.parent.assertListEqual(list(result["loss"].size()), [])
|
||||
@@ -139,7 +147,7 @@ class OpenAIGPTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
result = {"sequence_output": sequence_output}
|
||||
self.parent.assertListEqual(
|
||||
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
list(result["sequence_output"].size()), [self.batch_size, self.seq_length, self.hidden_size],
|
||||
)
|
||||
|
||||
def create_and_check_lm_head_model(self, config, input_ids, head_mask, token_type_ids, *args):
|
||||
@@ -153,7 +161,7 @@ class OpenAIGPTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
self.parent.assertListEqual(list(result["loss"].size()), [])
|
||||
self.parent.assertListEqual(
|
||||
list(result["lm_logits"].size()), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
list(result["lm_logits"].size()), [self.batch_size, self.seq_length, self.vocab_size],
|
||||
)
|
||||
|
||||
def create_and_check_double_lm_head_model(self, config, input_ids, head_mask, token_type_ids, *args):
|
||||
@@ -167,7 +175,7 @@ class OpenAIGPTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
self.parent.assertListEqual(list(result["loss"].size()), [])
|
||||
self.parent.assertListEqual(
|
||||
list(result["lm_logits"].size()), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
list(result["lm_logits"].size()), [self.batch_size, self.seq_length, self.vocab_size],
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
@@ -181,7 +189,11 @@ class OpenAIGPTModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
token_labels,
|
||||
choice_labels,
|
||||
) = config_and_inputs
|
||||
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "head_mask": head_mask}
|
||||
inputs_dict = {
|
||||
"input_ids": input_ids,
|
||||
"token_type_ids": token_type_ids,
|
||||
"head_mask": head_mask,
|
||||
}
|
||||
|
||||
return config, inputs_dict
|
||||
|
||||
@@ -215,30 +227,29 @@ class OPENAIGPTModelLanguageGenerationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_lm_generate_openai_gpt(self):
|
||||
model = OpenAIGPTLMHeadModel.from_pretrained("openai-gpt")
|
||||
input_ids = torch.Tensor([[481, 2585, 544, 4957]]).long() # The dog is cute
|
||||
input_ids = torch.tensor([[481, 4735, 544]], dtype=torch.long, device=torch_device) # the president is
|
||||
expected_output_ids = [
|
||||
481,
|
||||
2585,
|
||||
4735,
|
||||
544,
|
||||
4957,
|
||||
669,
|
||||
512,
|
||||
761,
|
||||
5990,
|
||||
271,
|
||||
645,
|
||||
246,
|
||||
963,
|
||||
870,
|
||||
762,
|
||||
239,
|
||||
244,
|
||||
40477,
|
||||
244,
|
||||
249,
|
||||
719,
|
||||
881,
|
||||
487,
|
||||
535,
|
||||
976,
|
||||
2479,
|
||||
544,
|
||||
240,
|
||||
487,
|
||||
804,
|
||||
1296,
|
||||
2891,
|
||||
512,
|
||||
] # the dog is cute when you're annoyed : if he's really stupid, he 'll stop fighting you
|
||||
torch.manual_seed(0)
|
||||
244,
|
||||
603,
|
||||
481,
|
||||
] # the president is a very good man. " \n " i\'m sure he is, " said the
|
||||
|
||||
output_ids = model.generate(input_ids)
|
||||
output_ids = model.generate(input_ids, do_sample=False)
|
||||
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
|
||||
|
||||
@@ -24,6 +24,7 @@ from .utils import CACHE_DIR, require_tf, slow
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
from transformers.modeling_tf_ctrl import TFCTRLModel, TFCTRLLMHeadModel, TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
|
||||
|
||||
@@ -202,3 +203,35 @@ class TFCTRLModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
for model_name in list(TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = TFCTRLModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
class TFCTRLModelLanguageGenerationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_lm_generate_ctrl(self):
|
||||
model = TFCTRLLMHeadModel.from_pretrained("ctrl")
|
||||
input_ids = tf.convert_to_tensor([[11859, 0, 1611, 8]], dtype=tf.int32) # Legal the president is
|
||||
expected_output_ids = [
|
||||
11859,
|
||||
0,
|
||||
1611,
|
||||
8,
|
||||
5,
|
||||
150,
|
||||
26449,
|
||||
2,
|
||||
19,
|
||||
348,
|
||||
469,
|
||||
3,
|
||||
2595,
|
||||
48,
|
||||
20740,
|
||||
246533,
|
||||
246533,
|
||||
19,
|
||||
30,
|
||||
5,
|
||||
] # Legal the president is a good guy and I don't want to lose my job. \n \n I have a
|
||||
|
||||
output_ids = model.generate(input_ids, do_sample=False)
|
||||
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
|
||||
|
||||
@@ -328,13 +328,35 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
def prepare_generation_special_tokens():
|
||||
return {"bos_token_id": 50256, "eos_token_id": 50256}
|
||||
|
||||
|
||||
class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
|
||||
|
||||
special_tokens = prepare_generation_special_tokens()
|
||||
@slow
|
||||
def test_lm_generate_gpt2(self):
|
||||
model = TFGPT2LMHeadModel.from_pretrained("gpt2")
|
||||
input_ids = tf.convert_to_tensor([[464, 3290]], dtype=tf.int32) # The dog
|
||||
expected_output_ids = [
|
||||
464,
|
||||
3290,
|
||||
373,
|
||||
1043,
|
||||
287,
|
||||
257,
|
||||
2214,
|
||||
1474,
|
||||
262,
|
||||
16246,
|
||||
286,
|
||||
2688,
|
||||
290,
|
||||
2688,
|
||||
27262,
|
||||
13,
|
||||
198,
|
||||
198,
|
||||
464,
|
||||
3290,
|
||||
] # The dog was found in a field near the intersection of West and West Streets.\n\nThe dog
|
||||
output_ids = model.generate(input_ids, do_sample=False)
|
||||
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
|
||||
|
||||
@slow
|
||||
def test_lm_generate_distilgpt2(self):
|
||||
@@ -363,11 +385,5 @@ class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
|
||||
2635,
|
||||
] # The president of the United States, and the president of the United Kingdom, have been in the White
|
||||
|
||||
output_ids = model.generate(
|
||||
input_ids,
|
||||
do_sample=False,
|
||||
bos_token_id=self.special_tokens["bos_token_id"],
|
||||
eos_token_ids=self.special_tokens["eos_token_id"],
|
||||
)
|
||||
|
||||
output_ids = model.generate(input_ids, do_sample=False)
|
||||
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
|
||||
|
||||
@@ -238,3 +238,35 @@ class TFOpenAIGPTModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
for model_name in list(TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = TFOpenAIGPTModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
class TFOPENAIGPTModelLanguageGenerationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_lm_generate_openai_gpt(self):
|
||||
model = TFOpenAIGPTLMHeadModel.from_pretrained("openai-gpt")
|
||||
input_ids = tf.convert_to_tensor([[481, 4735, 544]], dtype=tf.int32) # the president is
|
||||
expected_output_ids = [
|
||||
481,
|
||||
4735,
|
||||
544,
|
||||
246,
|
||||
963,
|
||||
870,
|
||||
762,
|
||||
239,
|
||||
244,
|
||||
40477,
|
||||
244,
|
||||
249,
|
||||
719,
|
||||
881,
|
||||
487,
|
||||
544,
|
||||
240,
|
||||
244,
|
||||
603,
|
||||
481,
|
||||
] # the president is a very good man. " \n " i\'m sure he is, " said the
|
||||
|
||||
output_ids = model.generate(input_ids, do_sample=False)
|
||||
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
|
||||
|
||||
@@ -212,3 +212,366 @@ class TFTransfoXLModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
for model_name in list(TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = TFTransfoXLModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
class TFTransfoXLModelLanguageGenerationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_lm_generate_transfo_xl_wt103(self):
|
||||
model = TFTransfoXLLMHeadModel.from_pretrained("transfo-xl-wt103")
|
||||
input_ids = tf.convert_to_tensor(
|
||||
[
|
||||
[
|
||||
33,
|
||||
1297,
|
||||
2,
|
||||
1,
|
||||
1009,
|
||||
4,
|
||||
1109,
|
||||
11739,
|
||||
4762,
|
||||
358,
|
||||
5,
|
||||
25,
|
||||
245,
|
||||
22,
|
||||
1706,
|
||||
17,
|
||||
20098,
|
||||
5,
|
||||
3215,
|
||||
21,
|
||||
37,
|
||||
1110,
|
||||
3,
|
||||
13,
|
||||
1041,
|
||||
4,
|
||||
24,
|
||||
603,
|
||||
490,
|
||||
2,
|
||||
71477,
|
||||
20098,
|
||||
104447,
|
||||
2,
|
||||
20961,
|
||||
1,
|
||||
2604,
|
||||
4,
|
||||
1,
|
||||
329,
|
||||
3,
|
||||
6224,
|
||||
831,
|
||||
16002,
|
||||
2,
|
||||
8,
|
||||
603,
|
||||
78967,
|
||||
29546,
|
||||
23,
|
||||
803,
|
||||
20,
|
||||
25,
|
||||
416,
|
||||
5,
|
||||
8,
|
||||
232,
|
||||
4,
|
||||
277,
|
||||
6,
|
||||
1855,
|
||||
4601,
|
||||
3,
|
||||
29546,
|
||||
54,
|
||||
8,
|
||||
3609,
|
||||
5,
|
||||
57211,
|
||||
49,
|
||||
4,
|
||||
1,
|
||||
277,
|
||||
18,
|
||||
8,
|
||||
1755,
|
||||
15691,
|
||||
3,
|
||||
341,
|
||||
25,
|
||||
416,
|
||||
693,
|
||||
42573,
|
||||
71,
|
||||
17,
|
||||
401,
|
||||
94,
|
||||
31,
|
||||
17919,
|
||||
2,
|
||||
29546,
|
||||
7873,
|
||||
18,
|
||||
1,
|
||||
435,
|
||||
23,
|
||||
11011,
|
||||
755,
|
||||
5,
|
||||
5167,
|
||||
3,
|
||||
7983,
|
||||
98,
|
||||
84,
|
||||
2,
|
||||
29546,
|
||||
3267,
|
||||
8,
|
||||
3609,
|
||||
4,
|
||||
1,
|
||||
4865,
|
||||
1075,
|
||||
2,
|
||||
6087,
|
||||
71,
|
||||
6,
|
||||
346,
|
||||
8,
|
||||
5854,
|
||||
3,
|
||||
29546,
|
||||
824,
|
||||
1400,
|
||||
1868,
|
||||
2,
|
||||
19,
|
||||
160,
|
||||
2,
|
||||
311,
|
||||
8,
|
||||
5496,
|
||||
2,
|
||||
20920,
|
||||
17,
|
||||
25,
|
||||
15097,
|
||||
3,
|
||||
24,
|
||||
24,
|
||||
0,
|
||||
]
|
||||
],
|
||||
dtype=tf.int31,
|
||||
)
|
||||
# In 1991 , the remains of Russian Tsar Nicholas II and his family
|
||||
# ( except for Alexei and Maria ) are discovered .
|
||||
# The voice of Nicholas's young son , Tsarevich Alexei Nikolaevich , narrates the
|
||||
# remainder of the story . 1883 Western Siberia ,
|
||||
# a young Grigori Rasputin is asked by his father and a group of men to perform magic .
|
||||
# Rasputin has a vision and denounces one of the men as a horse thief . Although his
|
||||
# father initially slaps him for making such an accusation , Rasputin watches as the
|
||||
# man is chased outside and beaten . Twenty years later , Rasputin sees a vision of
|
||||
# the Virgin Mary , prompting him to become a priest . Rasputin quickly becomes famous ,
|
||||
# with people , even a bishop , begging for his blessing . <eod> </s> <eos>
|
||||
|
||||
expected_output_ids = [
|
||||
33,
|
||||
1297,
|
||||
2,
|
||||
1,
|
||||
1009,
|
||||
4,
|
||||
1109,
|
||||
11739,
|
||||
4762,
|
||||
358,
|
||||
5,
|
||||
25,
|
||||
245,
|
||||
22,
|
||||
1706,
|
||||
17,
|
||||
20098,
|
||||
5,
|
||||
3215,
|
||||
21,
|
||||
37,
|
||||
1110,
|
||||
3,
|
||||
13,
|
||||
1041,
|
||||
4,
|
||||
24,
|
||||
603,
|
||||
490,
|
||||
2,
|
||||
71477,
|
||||
20098,
|
||||
104447,
|
||||
2,
|
||||
20961,
|
||||
1,
|
||||
2604,
|
||||
4,
|
||||
1,
|
||||
329,
|
||||
3,
|
||||
6224,
|
||||
831,
|
||||
16002,
|
||||
2,
|
||||
8,
|
||||
603,
|
||||
78967,
|
||||
29546,
|
||||
23,
|
||||
803,
|
||||
20,
|
||||
25,
|
||||
416,
|
||||
5,
|
||||
8,
|
||||
232,
|
||||
4,
|
||||
277,
|
||||
6,
|
||||
1855,
|
||||
4601,
|
||||
3,
|
||||
29546,
|
||||
54,
|
||||
8,
|
||||
3609,
|
||||
5,
|
||||
57211,
|
||||
49,
|
||||
4,
|
||||
1,
|
||||
277,
|
||||
18,
|
||||
8,
|
||||
1755,
|
||||
15691,
|
||||
3,
|
||||
341,
|
||||
25,
|
||||
416,
|
||||
693,
|
||||
42573,
|
||||
71,
|
||||
17,
|
||||
401,
|
||||
94,
|
||||
31,
|
||||
17919,
|
||||
2,
|
||||
29546,
|
||||
7873,
|
||||
18,
|
||||
1,
|
||||
435,
|
||||
23,
|
||||
11011,
|
||||
755,
|
||||
5,
|
||||
5167,
|
||||
3,
|
||||
7983,
|
||||
98,
|
||||
84,
|
||||
2,
|
||||
29546,
|
||||
3267,
|
||||
8,
|
||||
3609,
|
||||
4,
|
||||
1,
|
||||
4865,
|
||||
1075,
|
||||
2,
|
||||
6087,
|
||||
71,
|
||||
6,
|
||||
346,
|
||||
8,
|
||||
5854,
|
||||
3,
|
||||
29546,
|
||||
824,
|
||||
1400,
|
||||
1868,
|
||||
2,
|
||||
19,
|
||||
160,
|
||||
2,
|
||||
311,
|
||||
8,
|
||||
5496,
|
||||
2,
|
||||
20920,
|
||||
17,
|
||||
25,
|
||||
15097,
|
||||
3,
|
||||
24,
|
||||
24,
|
||||
0,
|
||||
33,
|
||||
1,
|
||||
1857,
|
||||
2,
|
||||
1,
|
||||
1009,
|
||||
4,
|
||||
1109,
|
||||
11739,
|
||||
4762,
|
||||
358,
|
||||
5,
|
||||
25,
|
||||
245,
|
||||
28,
|
||||
1110,
|
||||
3,
|
||||
13,
|
||||
1041,
|
||||
4,
|
||||
24,
|
||||
603,
|
||||
490,
|
||||
2,
|
||||
71477,
|
||||
20098,
|
||||
104447,
|
||||
2,
|
||||
20961,
|
||||
1,
|
||||
2604,
|
||||
4,
|
||||
1,
|
||||
329,
|
||||
3,
|
||||
0,
|
||||
]
|
||||
# In 1991, the remains of Russian Tsar Nicholas II and his family (
|
||||
# except for Alexei and Maria ) are discovered. The voice of young son,
|
||||
# Tsarevich Alexei Nikolaevich, narrates the remainder of the story.
|
||||
# 1883 Western Siberia, a young Grigori Rasputin is asked by his father
|
||||
# and a group of men to perform magic. Rasputin has a vision and
|
||||
# denounces one of the men as a horse thief. Although his father initially
|
||||
# slaps him for making such an accusation, Rasputin watches as the man
|
||||
# is chased outside and beaten. Twenty years later, Rasputin sees a vision
|
||||
# of the Virgin Mary, prompting him to become a priest.
|
||||
# Rasputin quickly becomes famous, with people, even a bishop, begging for
|
||||
# his blessing. <unk> <unk> <eos> In the 1990s, the remains of Russian Tsar
|
||||
# Nicholas II and his family were discovered. The voice of <unk> young son,
|
||||
# Tsarevich Alexei Nikolaevich, narrates the remainder of the story.<eos>
|
||||
|
||||
# TODO: add this test when trasnfo-xl-lmhead is implemented
|
||||
with self.assertRaises(NotImplementedError):
|
||||
model.generate(input_ids, max_length=200, do_sample=False)
|
||||
print(expected_output_ids)
|
||||
# self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids) TODO: (PVP) to add when transfo-xl is implemented
|
||||
|
||||
@@ -311,3 +311,35 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
for model_name in list(TF_XLM_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = TFXLMModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
class TFXLMModelLanguageGenerationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_lm_generate_xlm_mlm_en_2048(self):
|
||||
model = TFXLMWithLMHeadModel.from_pretrained("xlm-mlm-en-2048")
|
||||
input_ids = tf.convert_to_tensor([[14, 447]], dtype=tf.int32) # the president
|
||||
expected_output_ids = [
|
||||
14,
|
||||
447,
|
||||
14,
|
||||
447,
|
||||
14,
|
||||
447,
|
||||
14,
|
||||
447,
|
||||
14,
|
||||
447,
|
||||
14,
|
||||
447,
|
||||
14,
|
||||
447,
|
||||
14,
|
||||
447,
|
||||
14,
|
||||
447,
|
||||
14,
|
||||
447,
|
||||
] # the president the president the president the president the president the president the president the president the president the president
|
||||
# TODO(PVP): this and other input_ids I tried for generation give pretty bad results. Not sure why. Model might just not be made for auto-regressive inference
|
||||
output_ids = model.generate(input_ids, do_sample=False)
|
||||
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
|
||||
|
||||
@@ -413,3 +413,405 @@ class TFXLNetModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
for model_name in list(TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = TFXLNetModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
class TFXLNetModelLanguageGenerationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_lm_generate_xlnet_base_cased(self):
|
||||
model = TFXLNetLMHeadModel.from_pretrained("xlnet-base-cased")
|
||||
input_ids = tf.convert_to_tensor(
|
||||
[
|
||||
[
|
||||
67,
|
||||
2840,
|
||||
19,
|
||||
18,
|
||||
1484,
|
||||
20,
|
||||
965,
|
||||
29077,
|
||||
8719,
|
||||
1273,
|
||||
21,
|
||||
45,
|
||||
273,
|
||||
17,
|
||||
10,
|
||||
15048,
|
||||
28,
|
||||
27511,
|
||||
21,
|
||||
4185,
|
||||
11,
|
||||
41,
|
||||
2444,
|
||||
9,
|
||||
32,
|
||||
1025,
|
||||
20,
|
||||
8719,
|
||||
26,
|
||||
23,
|
||||
673,
|
||||
966,
|
||||
19,
|
||||
29077,
|
||||
20643,
|
||||
27511,
|
||||
20822,
|
||||
20643,
|
||||
19,
|
||||
17,
|
||||
6616,
|
||||
17511,
|
||||
18,
|
||||
8978,
|
||||
20,
|
||||
18,
|
||||
777,
|
||||
9,
|
||||
19233,
|
||||
1527,
|
||||
17669,
|
||||
19,
|
||||
24,
|
||||
673,
|
||||
17,
|
||||
28756,
|
||||
150,
|
||||
12943,
|
||||
4354,
|
||||
153,
|
||||
27,
|
||||
442,
|
||||
37,
|
||||
45,
|
||||
668,
|
||||
21,
|
||||
24,
|
||||
256,
|
||||
20,
|
||||
416,
|
||||
22,
|
||||
2771,
|
||||
4901,
|
||||
9,
|
||||
12943,
|
||||
4354,
|
||||
153,
|
||||
51,
|
||||
24,
|
||||
3004,
|
||||
21,
|
||||
28142,
|
||||
23,
|
||||
65,
|
||||
20,
|
||||
18,
|
||||
416,
|
||||
34,
|
||||
24,
|
||||
2958,
|
||||
22947,
|
||||
9,
|
||||
1177,
|
||||
45,
|
||||
668,
|
||||
3097,
|
||||
13768,
|
||||
23,
|
||||
103,
|
||||
28,
|
||||
441,
|
||||
148,
|
||||
48,
|
||||
20522,
|
||||
19,
|
||||
12943,
|
||||
4354,
|
||||
153,
|
||||
12860,
|
||||
34,
|
||||
18,
|
||||
326,
|
||||
27,
|
||||
17492,
|
||||
684,
|
||||
21,
|
||||
6709,
|
||||
9,
|
||||
8585,
|
||||
123,
|
||||
266,
|
||||
19,
|
||||
12943,
|
||||
4354,
|
||||
153,
|
||||
6872,
|
||||
24,
|
||||
3004,
|
||||
20,
|
||||
18,
|
||||
9225,
|
||||
2198,
|
||||
19,
|
||||
12717,
|
||||
103,
|
||||
22,
|
||||
401,
|
||||
24,
|
||||
6348,
|
||||
9,
|
||||
12943,
|
||||
4354,
|
||||
153,
|
||||
1068,
|
||||
2768,
|
||||
2286,
|
||||
19,
|
||||
33,
|
||||
104,
|
||||
19,
|
||||
176,
|
||||
24,
|
||||
9313,
|
||||
19,
|
||||
20086,
|
||||
28,
|
||||
45,
|
||||
10292,
|
||||
9,
|
||||
4,
|
||||
3,
|
||||
]
|
||||
],
|
||||
dtype=tf.int32,
|
||||
)
|
||||
# In 1991, the remains of Russian Tsar Nicholas II and his family
|
||||
# (except for Alexei and Maria) are discovered.
|
||||
# The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
|
||||
# remainder of the story. 1883 Western Siberia,
|
||||
# a young Grigori Rasputin is asked by his father and a group of men to perform magic.
|
||||
# Rasputin has a vision and denounces one of the men as a horse thief. Although his
|
||||
# father initially slaps him for making such an accusation, Rasputin watches as the
|
||||
# man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
|
||||
# the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
|
||||
# with people, even a bishop, begging for his blessing. """
|
||||
|
||||
expected_output_ids = [
|
||||
67,
|
||||
2840,
|
||||
19,
|
||||
18,
|
||||
1484,
|
||||
20,
|
||||
965,
|
||||
29077,
|
||||
8719,
|
||||
1273,
|
||||
21,
|
||||
45,
|
||||
273,
|
||||
17,
|
||||
10,
|
||||
15048,
|
||||
28,
|
||||
27511,
|
||||
21,
|
||||
4185,
|
||||
11,
|
||||
41,
|
||||
2444,
|
||||
9,
|
||||
32,
|
||||
1025,
|
||||
20,
|
||||
8719,
|
||||
26,
|
||||
23,
|
||||
673,
|
||||
966,
|
||||
19,
|
||||
29077,
|
||||
20643,
|
||||
27511,
|
||||
20822,
|
||||
20643,
|
||||
19,
|
||||
17,
|
||||
6616,
|
||||
17511,
|
||||
18,
|
||||
8978,
|
||||
20,
|
||||
18,
|
||||
777,
|
||||
9,
|
||||
19233,
|
||||
1527,
|
||||
17669,
|
||||
19,
|
||||
24,
|
||||
673,
|
||||
17,
|
||||
28756,
|
||||
150,
|
||||
12943,
|
||||
4354,
|
||||
153,
|
||||
27,
|
||||
442,
|
||||
37,
|
||||
45,
|
||||
668,
|
||||
21,
|
||||
24,
|
||||
256,
|
||||
20,
|
||||
416,
|
||||
22,
|
||||
2771,
|
||||
4901,
|
||||
9,
|
||||
12943,
|
||||
4354,
|
||||
153,
|
||||
51,
|
||||
24,
|
||||
3004,
|
||||
21,
|
||||
28142,
|
||||
23,
|
||||
65,
|
||||
20,
|
||||
18,
|
||||
416,
|
||||
34,
|
||||
24,
|
||||
2958,
|
||||
22947,
|
||||
9,
|
||||
1177,
|
||||
45,
|
||||
668,
|
||||
3097,
|
||||
13768,
|
||||
23,
|
||||
103,
|
||||
28,
|
||||
441,
|
||||
148,
|
||||
48,
|
||||
20522,
|
||||
19,
|
||||
12943,
|
||||
4354,
|
||||
153,
|
||||
12860,
|
||||
34,
|
||||
18,
|
||||
326,
|
||||
27,
|
||||
17492,
|
||||
684,
|
||||
21,
|
||||
6709,
|
||||
9,
|
||||
8585,
|
||||
123,
|
||||
266,
|
||||
19,
|
||||
12943,
|
||||
4354,
|
||||
153,
|
||||
6872,
|
||||
24,
|
||||
3004,
|
||||
20,
|
||||
18,
|
||||
9225,
|
||||
2198,
|
||||
19,
|
||||
12717,
|
||||
103,
|
||||
22,
|
||||
401,
|
||||
24,
|
||||
6348,
|
||||
9,
|
||||
12943,
|
||||
4354,
|
||||
153,
|
||||
1068,
|
||||
2768,
|
||||
2286,
|
||||
19,
|
||||
33,
|
||||
104,
|
||||
19,
|
||||
176,
|
||||
24,
|
||||
9313,
|
||||
19,
|
||||
20086,
|
||||
28,
|
||||
45,
|
||||
10292,
|
||||
9,
|
||||
4,
|
||||
3,
|
||||
19,
|
||||
12943,
|
||||
4354,
|
||||
153,
|
||||
27,
|
||||
442,
|
||||
22,
|
||||
2771,
|
||||
4901,
|
||||
9,
|
||||
69,
|
||||
27,
|
||||
50,
|
||||
551,
|
||||
22,
|
||||
2771,
|
||||
4901,
|
||||
19,
|
||||
21,
|
||||
45,
|
||||
668,
|
||||
21,
|
||||
18,
|
||||
416,
|
||||
41,
|
||||
1499,
|
||||
22,
|
||||
755,
|
||||
18,
|
||||
14285,
|
||||
9,
|
||||
12943,
|
||||
4354,
|
||||
153,
|
||||
27,
|
||||
1499,
|
||||
22,
|
||||
642,
|
||||
22,
|
||||
]
|
||||
# In 1991, the remains of Russian Tsar Nicholas II and his family (except for Alexei and Maria)
|
||||
# are discovered. The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich,
|
||||
# narrates the remainder of the story. 1883 Western Siberia, a young Grigori Rasputin
|
||||
# is asked by his father and a group of men to perform magic. Rasputin has a vision and
|
||||
# denounces one of the men as a horse thief. Although his father initially slaps
|
||||
# him for making such an accusation, Rasputin watches as the man is chased outside and beaten.
|
||||
# Twenty years later, Rasputin sees a vision of the Virgin Mary, prompting him to become a priest.
|
||||
# Rasputin quickly becomes famous, with people, even a bishop, begging for his blessing.
|
||||
# <sep><cls>, Rasputin is asked to perform magic.
|
||||
# He is not able to perform magic, and his father and
|
||||
# the men are forced to leave the monastery. Rasputin is forced to return to
|
||||
|
||||
output_ids = model.generate(input_ids, max_length=200, do_sample=False)
|
||||
|
||||
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
|
||||
|
||||
@@ -218,7 +218,7 @@ class TransfoXLModelLanguageGenerationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_lm_generate_transfo_xl_wt103(self):
|
||||
model = TransfoXLLMHeadModel.from_pretrained("transfo-xl-wt103")
|
||||
input_ids = torch.Tensor(
|
||||
input_ids = torch.tensor(
|
||||
[
|
||||
[
|
||||
33,
|
||||
@@ -363,8 +363,10 @@ class TransfoXLModelLanguageGenerationTest(unittest.TestCase):
|
||||
24,
|
||||
0,
|
||||
]
|
||||
]
|
||||
).long()
|
||||
],
|
||||
dtype=torch.long,
|
||||
device=torch_device,
|
||||
)
|
||||
# In 1991 , the remains of Russian Tsar Nicholas II and his family
|
||||
# ( except for Alexei and Maria ) are discovered .
|
||||
# The voice of Nicholas's young son , Tsarevich Alexei Nikolaevich , narrates the
|
||||
@@ -374,6 +376,7 @@ class TransfoXLModelLanguageGenerationTest(unittest.TestCase):
|
||||
# father initially slaps him for making such an accusation , Rasputin watches as the
|
||||
# man is chased outside and beaten . Twenty years later , Rasputin sees a vision of
|
||||
# the Virgin Mary , prompting him to become a priest . Rasputin quickly becomes famous ,
|
||||
|
||||
# with people , even a bishop , begging for his blessing . <eod> </s> <eos>
|
||||
|
||||
expected_output_ids = [
|
||||
@@ -518,20 +521,10 @@ class TransfoXLModelLanguageGenerationTest(unittest.TestCase):
|
||||
24,
|
||||
24,
|
||||
0,
|
||||
29546,
|
||||
40,
|
||||
1092,
|
||||
18,
|
||||
8,
|
||||
5854,
|
||||
7,
|
||||
1143,
|
||||
2,
|
||||
7,
|
||||
33,
|
||||
1,
|
||||
159,
|
||||
99,
|
||||
16,
|
||||
1857,
|
||||
2,
|
||||
1,
|
||||
1009,
|
||||
4,
|
||||
@@ -545,14 +538,23 @@ class TransfoXLModelLanguageGenerationTest(unittest.TestCase):
|
||||
28,
|
||||
1110,
|
||||
3,
|
||||
57,
|
||||
629,
|
||||
38,
|
||||
3493,
|
||||
47,
|
||||
1094,
|
||||
7,
|
||||
1297,
|
||||
13,
|
||||
1041,
|
||||
4,
|
||||
24,
|
||||
603,
|
||||
490,
|
||||
2,
|
||||
71477,
|
||||
20098,
|
||||
104447,
|
||||
2,
|
||||
20961,
|
||||
1,
|
||||
2604,
|
||||
4,
|
||||
1,
|
||||
329,
|
||||
3,
|
||||
0,
|
||||
]
|
||||
@@ -566,10 +568,9 @@ class TransfoXLModelLanguageGenerationTest(unittest.TestCase):
|
||||
# is chased outside and beaten. Twenty years later, Rasputin sees a vision
|
||||
# of the Virgin Mary, prompting him to become a priest.
|
||||
# Rasputin quickly becomes famous, with people, even a bishop, begging for
|
||||
# his blessing. Rasputin first appears as a priest in 1996, in the same year
|
||||
# that the remains of Russian Tsar Nicholas II and his family were discovered. H
|
||||
# his blessing. <unk> <unk> <eos> In the 1990s, the remains of Russian Tsar
|
||||
# Nicholas II and his family were discovered. The voice of <unk> young son,
|
||||
# Tsarevich Alexei Nikolaevich, narrates the remainder of the story.<eos>
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
output_ids = model.generate(input_ids, max_length=200)
|
||||
output_ids = model.generate(input_ids, max_length=200, do_sample=False)
|
||||
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
|
||||
|
||||
@@ -403,31 +403,29 @@ class XLMModelLanguageGenerationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_lm_generate_xlm_mlm_en_2048(self):
|
||||
model = XLMWithLMHeadModel.from_pretrained("xlm-mlm-en-2048")
|
||||
input_ids = torch.Tensor([[1, 14, 2232, 26, 1]]).long() # The dog is cute
|
||||
input_ids = torch.tensor([[14, 447]], dtype=torch.long, device=torch_device) # the president
|
||||
expected_output_ids = [
|
||||
1,
|
||||
14,
|
||||
2232,
|
||||
26,
|
||||
1,
|
||||
567,
|
||||
26,
|
||||
32,
|
||||
149,
|
||||
149,
|
||||
149,
|
||||
149,
|
||||
149,
|
||||
149,
|
||||
149,
|
||||
149,
|
||||
149,
|
||||
149,
|
||||
149,
|
||||
149,
|
||||
] # The dog is nothing is it!!!!!!!!!!!! TODO (PVP): this sentence (and others I tried) does not make much sense, there seems to be a problem with xlm language generation.
|
||||
torch.manual_seed(0)
|
||||
|
||||
output_ids = model.generate(input_ids)
|
||||
|
||||
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
|
||||
447,
|
||||
14,
|
||||
447,
|
||||
14,
|
||||
447,
|
||||
14,
|
||||
447,
|
||||
14,
|
||||
447,
|
||||
14,
|
||||
447,
|
||||
14,
|
||||
447,
|
||||
14,
|
||||
447,
|
||||
14,
|
||||
447,
|
||||
14,
|
||||
447,
|
||||
] # the president the president the president the president the president the president the president the president the president the president
|
||||
# TODO(PVP): this and other input_ids I tried for generation give pretty bad results. Not sure why. Model might just not be made for auto-regressive inference
|
||||
output_ids = model.generate(input_ids, do_sample=False)
|
||||
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
|
||||
|
||||
@@ -119,11 +119,11 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
input_ids_q = ids_tensor([self.batch_size, self.seq_length + 1], self.vocab_size)
|
||||
perm_mask = torch.zeros(
|
||||
self.batch_size, self.seq_length + 1, self.seq_length + 1, dtype=torch.float, device=torch_device
|
||||
self.batch_size, self.seq_length + 1, self.seq_length + 1, dtype=torch.float, device=torch_device,
|
||||
)
|
||||
perm_mask[:, :, -1] = 1.0 # Previous tokens don't see last token
|
||||
target_mapping = torch.zeros(
|
||||
self.batch_size, 1, self.seq_length + 1, dtype=torch.float, device=torch_device
|
||||
self.batch_size, 1, self.seq_length + 1, dtype=torch.float, device=torch_device,
|
||||
)
|
||||
target_mapping[:, 0, -1] = 1.0 # predict last token
|
||||
|
||||
@@ -212,7 +212,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
self.parent.assertEqual(len(no_mems_outputs), 1)
|
||||
|
||||
self.parent.assertListEqual(
|
||||
list(result["outputs"].size()), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
list(result["outputs"].size()), [self.batch_size, self.seq_length, self.hidden_size],
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_1"]),
|
||||
@@ -283,7 +283,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
self.parent.assertListEqual(list(result["loss_1"].size()), [])
|
||||
self.parent.assertListEqual(
|
||||
list(result["all_logits_1"].size()), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
list(result["all_logits_1"].size()), [self.batch_size, self.seq_length, self.vocab_size],
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_1"]),
|
||||
@@ -292,7 +292,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
self.parent.assertListEqual(list(result["loss_2"].size()), [])
|
||||
self.parent.assertListEqual(
|
||||
list(result["all_logits_2"].size()), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
list(result["all_logits_2"].size()), [self.batch_size, self.seq_length, self.vocab_size],
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_2"]),
|
||||
@@ -319,7 +319,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
model.eval()
|
||||
|
||||
outputs = model(input_ids_1)
|
||||
start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems = outputs
|
||||
(start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits, mems,) = outputs
|
||||
|
||||
outputs = model(
|
||||
input_ids_1,
|
||||
@@ -340,7 +340,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
total_loss, mems = outputs
|
||||
|
||||
outputs = model(input_ids_1, start_positions=sequence_labels, end_positions=sequence_labels)
|
||||
outputs = model(input_ids_1, start_positions=sequence_labels, end_positions=sequence_labels,)
|
||||
|
||||
total_loss, mems = outputs
|
||||
|
||||
@@ -356,10 +356,10 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
self.parent.assertListEqual(list(result["loss"].size()), [])
|
||||
self.parent.assertListEqual(
|
||||
list(result["start_top_log_probs"].size()), [self.batch_size, model.config.start_n_top]
|
||||
list(result["start_top_log_probs"].size()), [self.batch_size, model.config.start_n_top],
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(result["start_top_index"].size()), [self.batch_size, model.config.start_n_top]
|
||||
list(result["start_top_index"].size()), [self.batch_size, model.config.start_n_top],
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(result["end_top_log_probs"].size()),
|
||||
@@ -405,7 +405,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
self.parent.assertListEqual(list(result["loss"].size()), [])
|
||||
self.parent.assertListEqual(
|
||||
list(result["logits"].size()), [self.batch_size, self.seq_length, self.type_sequence_label_size]
|
||||
list(result["logits"].size()), [self.batch_size, self.seq_length, self.type_sequence_label_size],
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_1"]),
|
||||
@@ -442,7 +442,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
self.parent.assertListEqual(list(result["loss"].size()), [])
|
||||
self.parent.assertListEqual(
|
||||
list(result["logits"].size()), [self.batch_size, self.type_sequence_label_size]
|
||||
list(result["logits"].size()), [self.batch_size, self.type_sequence_label_size],
|
||||
)
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.size()) for mem in result["mems_1"]),
|
||||
@@ -517,7 +517,7 @@ class XLNetModelLanguageGenerationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_lm_generate_xlnet_base_cased(self):
|
||||
model = XLNetLMHeadModel.from_pretrained("xlnet-base-cased")
|
||||
input_ids = torch.Tensor(
|
||||
input_ids = torch.tensor(
|
||||
[
|
||||
[
|
||||
67,
|
||||
@@ -682,8 +682,10 @@ class XLNetModelLanguageGenerationTest(unittest.TestCase):
|
||||
4,
|
||||
3,
|
||||
]
|
||||
]
|
||||
).long()
|
||||
],
|
||||
dtype=torch.long,
|
||||
device=torch_device,
|
||||
)
|
||||
# In 1991, the remains of Russian Tsar Nicholas II and his family
|
||||
# (except for Alexei and Maria) are discovered.
|
||||
# The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
|
||||
@@ -857,45 +859,45 @@ class XLNetModelLanguageGenerationTest(unittest.TestCase):
|
||||
9,
|
||||
4,
|
||||
3,
|
||||
1722,
|
||||
19,
|
||||
24,
|
||||
6348,
|
||||
61,
|
||||
977,
|
||||
176,
|
||||
1772,
|
||||
33,
|
||||
45,
|
||||
970,
|
||||
19,
|
||||
4185,
|
||||
19,
|
||||
12943,
|
||||
4354,
|
||||
153,
|
||||
27,
|
||||
442,
|
||||
22,
|
||||
2771,
|
||||
4901,
|
||||
25,
|
||||
18,
|
||||
2059,
|
||||
20,
|
||||
24,
|
||||
303,
|
||||
1775,
|
||||
691,
|
||||
9,
|
||||
1147,
|
||||
69,
|
||||
27,
|
||||
50,
|
||||
551,
|
||||
22,
|
||||
2771,
|
||||
4901,
|
||||
19,
|
||||
634,
|
||||
19,
|
||||
43,
|
||||
51,
|
||||
54,
|
||||
6157,
|
||||
2999,
|
||||
33,
|
||||
4185,
|
||||
21,
|
||||
45,
|
||||
668,
|
||||
21,
|
||||
18,
|
||||
416,
|
||||
41,
|
||||
1499,
|
||||
22,
|
||||
755,
|
||||
18,
|
||||
14285,
|
||||
9,
|
||||
12943,
|
||||
4354,
|
||||
153,
|
||||
27,
|
||||
1499,
|
||||
22,
|
||||
642,
|
||||
22,
|
||||
]
|
||||
# In 1991, the remains of Russian Tsar Nicholas II and his family (except for Alexei and Maria)
|
||||
# are discovered. The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich,
|
||||
@@ -905,11 +907,9 @@ class XLNetModelLanguageGenerationTest(unittest.TestCase):
|
||||
# him for making such an accusation, Rasputin watches as the man is chased outside and beaten.
|
||||
# Twenty years later, Rasputin sees a vision of the Virgin Mary, prompting him to become a priest.
|
||||
# Rasputin quickly becomes famous, with people, even a bishop, begging for his blessing.
|
||||
# 1990, a priest who cannot even walk with his wife, Maria, is asked to perform magic
|
||||
# in the presence of a local religious leader.
|
||||
# Since, however, he has had difficulty walking with Maria
|
||||
|
||||
torch.manual_seed(0)
|
||||
output_ids = model.generate(input_ids, max_length=200)
|
||||
# <sep><cls>, Rasputin is asked to perform magic.
|
||||
# He is not able to perform magic, and his father and
|
||||
# the men are forced to leave the monastery. Rasputin is forced to return to
|
||||
|
||||
output_ids = model.generate(input_ids, max_length=200, do_sample=False)
|
||||
self.assertListEqual(output_ids[0].tolist(), expected_output_ids)
|
||||
|
||||
Reference in New Issue
Block a user