Add generate() functionality to TF 2.0 (#3063)

* add first copy past test to tf 2 generate

* add tf top_k_top_p_filter fn

* add generate function for TF

* add generate function for TF

* implemented generate for all models expect transfoXL

* implemented generate for all models expect transfoXL

* implemented generate for all models expect transfoXL

* make style

* change permission of test file to correct ones

* delete ipdb

* delete ipdb

* fix bug and finish simple gpt2 integration test

* clean test file

* clean test file

* make style

* make style

* make style

* make style

* change import style

* change import style

* make style

* make style

* add decorators

* add decorators

* fix tf ctrl bug dim => axis in TF

* make style

* make style

* refactored test file

* refactored test file

* take out test_torch_tf_conversion if nothing is defined

* take out test_torch_tf_conversion if nothing is defined

* remove useless files

* remove useless files

* fix conflicts

* fix conflicts

* fix conflicts

* fix conflicts

* fix conflicts

* solve conflicts

* solve conflicts

* fix conflicts

* fix conflicts

* merge conflicts

* delete ipdb

* exposed top_k_top_p_filtering fns

* delete weirdly created w! file

* add comment to test tf common modeling

* fix conflicts

* fix conflicts

* make style

* merge conflicts

* make style

* change tf.tensor.shape to shape_list(tensor)
This commit is contained in:
Patrick von Platen
2020-03-03 15:42:15 +01:00
committed by GitHub
parent b31f715019
commit 4134100363
20 changed files with 892 additions and 62 deletions

View File

@@ -43,6 +43,9 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available()
else ()
)
all_generative_model_classes = (
(TFXLMWithLMHeadModel,) if is_tf_available() else ()
) # TODO (PVP): Check other models whether language generation is also applicable
class TFXLMModelTester(object):
def __init__(
@@ -75,6 +78,7 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
summary_type="last",
use_proj=True,
scope=None,
bos_token_id=0,
):
self.parent = parent
self.batch_size = batch_size
@@ -105,6 +109,7 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
self.num_labels = num_labels
self.num_choices = num_choices
self.scope = scope
self.bos_token_id = bos_token_id
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
@@ -145,6 +150,7 @@ class TFXLMModelTest(TFModelTesterMixin, unittest.TestCase):
initializer_range=self.initializer_range,
summary_type=self.summary_type,
use_proj=self.use_proj,
bos_token_id=self.bos_token_id,
)
return (