Add generate() functionality to TF 2.0 (#3063)
* add first copy past test to tf 2 generate * add tf top_k_top_p_filter fn * add generate function for TF * add generate function for TF * implemented generate for all models expect transfoXL * implemented generate for all models expect transfoXL * implemented generate for all models expect transfoXL * make style * change permission of test file to correct ones * delete ipdb * delete ipdb * fix bug and finish simple gpt2 integration test * clean test file * clean test file * make style * make style * make style * make style * change import style * change import style * make style * make style * add decorators * add decorators * fix tf ctrl bug dim => axis in TF * make style * make style * refactored test file * refactored test file * take out test_torch_tf_conversion if nothing is defined * take out test_torch_tf_conversion if nothing is defined * remove useless files * remove useless files * fix conflicts * fix conflicts * fix conflicts * fix conflicts * fix conflicts * solve conflicts * solve conflicts * fix conflicts * fix conflicts * merge conflicts * delete ipdb * exposed top_k_top_p_filtering fns * delete weirdly created w! file * add comment to test tf common modeling * fix conflicts * fix conflicts * make style * merge conflicts * make style * change tf.tensor.shape to shape_list(tensor)
This commit is contained in:
committed by
GitHub
parent
b31f715019
commit
4134100363
@@ -454,14 +454,12 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, **kwargs):
|
||||
def prepare_inputs_for_generation(self, input_ids, past, **kwargs):
|
||||
# only last token for inputs_ids if past is defined in kwargs
|
||||
if "past" in kwargs and kwargs["past"]:
|
||||
if past:
|
||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||
|
||||
inputs = {"input_ids": input_ids}
|
||||
inputs.update(kwargs)
|
||||
return inputs
|
||||
return {"input_ids": input_ids, "past": past}
|
||||
|
||||
@add_start_docstrings_to_callable(CTRL_INPUTS_DOCSTRING)
|
||||
def forward(
|
||||
|
||||
Reference in New Issue
Block a user