Add generate() functionality to TF 2.0 (#3063)
* add first copy past test to tf 2 generate * add tf top_k_top_p_filter fn * add generate function for TF * add generate function for TF * implemented generate for all models expect transfoXL * implemented generate for all models expect transfoXL * implemented generate for all models expect transfoXL * make style * change permission of test file to correct ones * delete ipdb * delete ipdb * fix bug and finish simple gpt2 integration test * clean test file * clean test file * make style * make style * make style * make style * change import style * change import style * make style * make style * add decorators * add decorators * fix tf ctrl bug dim => axis in TF * make style * make style * refactored test file * refactored test file * take out test_torch_tf_conversion if nothing is defined * take out test_torch_tf_conversion if nothing is defined * remove useless files * remove useless files * fix conflicts * fix conflicts * fix conflicts * fix conflicts * fix conflicts * solve conflicts * solve conflicts * fix conflicts * fix conflicts * merge conflicts * delete ipdb * exposed top_k_top_p_filtering fns * delete weirdly created w! file * add comment to test tf common modeling * fix conflicts * fix conflicts * make style * merge conflicts * make style * change tf.tensor.shape to shape_list(tensor)
This commit is contained in:
committed by
GitHub
parent
b31f715019
commit
4134100363
@@ -37,7 +37,7 @@ if is_tf_available():
|
||||
class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
|
||||
all_model_classes = (TFGPT2Model, TFGPT2LMHeadModel, TFGPT2DoubleHeadsModel) if is_tf_available() else ()
|
||||
# all_model_classes = (TFGPT2Model, TFGPT2LMHeadModel) if is_tf_available() else ()
|
||||
all_generative_model_classes = (TFGPT2LMHeadModel,) if is_tf_available() else ()
|
||||
|
||||
class TFGPT2ModelTester(object):
|
||||
def __init__(
|
||||
@@ -89,6 +89,8 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
self.num_labels = num_labels
|
||||
self.num_choices = num_choices
|
||||
self.scope = scope
|
||||
self.bos_token_id = vocab_size - 1
|
||||
self.eos_token_id = vocab_size - 1
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
@@ -123,9 +125,11 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
# hidden_dropout_prob=self.hidden_dropout_prob,
|
||||
# attention_probs_dropout_prob=self.attention_probs_dropout_prob,
|
||||
n_positions=self.max_position_embeddings,
|
||||
n_ctx=self.max_position_embeddings
|
||||
n_ctx=self.max_position_embeddings,
|
||||
# type_vocab_size=self.type_vocab_size,
|
||||
# initializer_range=self.initializer_range
|
||||
bos_token_id=self.bos_token_id,
|
||||
eos_token_ids=self.eos_token_id,
|
||||
)
|
||||
|
||||
head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
|
||||
@@ -144,7 +148,11 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
|
||||
def create_and_check_gpt2_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
|
||||
model = TFGPT2Model(config=config)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
inputs = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": input_mask,
|
||||
"token_type_ids": token_type_ids,
|
||||
}
|
||||
sequence_output = model(inputs)[0]
|
||||
|
||||
inputs = [input_ids, None, input_mask] # None is the input for 'past'
|
||||
@@ -156,18 +164,22 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
"sequence_output": sequence_output.numpy(),
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["sequence_output"].shape), [self.batch_size, self.seq_length, self.hidden_size]
|
||||
list(result["sequence_output"].shape), [self.batch_size, self.seq_length, self.hidden_size],
|
||||
)
|
||||
|
||||
def create_and_check_gpt2_lm_head(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
|
||||
model = TFGPT2LMHeadModel(config=config)
|
||||
inputs = {"input_ids": input_ids, "attention_mask": input_mask, "token_type_ids": token_type_ids}
|
||||
inputs = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": input_mask,
|
||||
"token_type_ids": token_type_ids,
|
||||
}
|
||||
prediction_scores = model(inputs)[0]
|
||||
result = {
|
||||
"prediction_scores": prediction_scores.numpy(),
|
||||
}
|
||||
self.parent.assertListEqual(
|
||||
list(result["prediction_scores"].shape), [self.batch_size, self.seq_length, self.vocab_size]
|
||||
list(result["prediction_scores"].shape), [self.batch_size, self.seq_length, self.vocab_size],
|
||||
)
|
||||
|
||||
def create_and_check_gpt2_double_head(
|
||||
@@ -188,7 +200,7 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
lm_logits, mc_logits = model(inputs)[:2]
|
||||
result = {"lm_logits": lm_logits.numpy(), "mc_logits": mc_logits.numpy()}
|
||||
self.parent.assertListEqual(
|
||||
list(result["lm_logits"].shape), [self.batch_size, self.num_choices, self.seq_length, self.vocab_size]
|
||||
list(result["lm_logits"].shape), [self.batch_size, self.num_choices, self.seq_length, self.vocab_size],
|
||||
)
|
||||
self.parent.assertListEqual(list(result["mc_logits"].shape), [self.batch_size, self.num_choices])
|
||||
|
||||
@@ -207,7 +219,11 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
choice_labels,
|
||||
) = config_and_inputs
|
||||
|
||||
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
|
||||
inputs_dict = {
|
||||
"input_ids": input_ids,
|
||||
"token_type_ids": token_type_ids,
|
||||
"attention_mask": input_mask,
|
||||
}
|
||||
return config, inputs_dict
|
||||
|
||||
def setUp(self):
|
||||
@@ -234,3 +250,48 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
for model_name in list(TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
model = TFGPT2Model.from_pretrained(model_name, cache_dir=CACHE_DIR)
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
|
||||
def prepare_generation_special_tokens():
|
||||
return {"bos_token_id": 50256, "eos_token_id": 50256}
|
||||
|
||||
|
||||
class TFGPT2ModelLanguageGenerationTest(unittest.TestCase):
|
||||
|
||||
special_tokens = prepare_generation_special_tokens()
|
||||
|
||||
@slow
|
||||
def test_lm_generate_distilgpt2(self):
|
||||
model = TFGPT2LMHeadModel.from_pretrained("distilgpt2")
|
||||
input_ids = tf.convert_to_tensor([[464, 1893]], dtype=tf.int32) # The president
|
||||
expected_output_ids = [
|
||||
464,
|
||||
1893,
|
||||
286,
|
||||
262,
|
||||
1578,
|
||||
1829,
|
||||
11,
|
||||
290,
|
||||
262,
|
||||
1893,
|
||||
286,
|
||||
262,
|
||||
1578,
|
||||
7526,
|
||||
11,
|
||||
423,
|
||||
587,
|
||||
287,
|
||||
262,
|
||||
2635,
|
||||
] # The president of the United States, and the president of the United Kingdom, have been in the White
|
||||
|
||||
output_ids = model.generate(
|
||||
input_ids,
|
||||
do_sample=False,
|
||||
bos_token_id=self.special_tokens["bos_token_id"],
|
||||
eos_token_ids=self.special_tokens["eos_token_id"],
|
||||
)
|
||||
|
||||
self.assertListEqual(output_ids[0].numpy().tolist(), expected_output_ids)
|
||||
|
||||
Reference in New Issue
Block a user