Add generate() functionality to TF 2.0 (#3063)

* add first copy past test to tf 2 generate

* add tf top_k_top_p_filter fn

* add generate function for TF

* add generate function for TF

* implemented generate for all models expect transfoXL

* implemented generate for all models expect transfoXL

* implemented generate for all models expect transfoXL

* make style

* change permission of test file to correct ones

* delete ipdb

* delete ipdb

* fix bug and finish simple gpt2 integration test

* clean test file

* clean test file

* make style

* make style

* make style

* make style

* change import style

* change import style

* make style

* make style

* add decorators

* add decorators

* fix tf ctrl bug dim => axis in TF

* make style

* make style

* refactored test file

* refactored test file

* take out test_torch_tf_conversion if nothing is defined

* take out test_torch_tf_conversion if nothing is defined

* remove useless files

* remove useless files

* fix conflicts

* fix conflicts

* fix conflicts

* fix conflicts

* fix conflicts

* solve conflicts

* solve conflicts

* fix conflicts

* fix conflicts

* merge conflicts

* delete ipdb

* exposed top_k_top_p_filtering fns

* delete weirdly created w! file

* add comment to test tf common modeling

* fix conflicts

* fix conflicts

* make style

* merge conflicts

* make style

* change tf.tensor.shape to shape_list(tensor)
This commit is contained in:
Patrick von Platen
2020-03-03 15:42:15 +01:00
committed by GitHub
parent b31f715019
commit 4134100363
20 changed files with 892 additions and 62 deletions

View File

@@ -454,14 +454,12 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
def get_output_embeddings(self):
return self.lm_head
def prepare_inputs_for_generation(self, input_ids, **kwargs):
def prepare_inputs_for_generation(self, input_ids, past, **kwargs):
# only last token for inputs_ids if past is defined in kwargs
if "past" in kwargs and kwargs["past"]:
if past:
input_ids = input_ids[:, -1].unsqueeze(-1)
inputs = {"input_ids": input_ids}
inputs.update(kwargs)
return inputs
return {"input_ids": input_ids, "past": past}
@add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING)
def forward(