Switch from return_tuple to return_dict (#6138)
* Switch from return_tuple to return_dict
* Fix test
* [WIP] Test TF Flaubert + Add {XLM, Flaubert}{TokenClassification, MultipleC… (#5614)
* Test TF Flaubert + Add {XLM, Flaubert}{TokenClassification, MultipleChoice} models and tests
* AutoModels
Tiny tweaks
* Style
* Final changes before merge
* Re-order for simpler review
* Final fixes
* Addressing @sgugger's comments
* Test MultipleChoice
* Rework TF trainer (#6038)
* Fully rework training/prediction loops
* fix method name
* Fix variable name
* Fix property name
* Fix scope
* Fix method name
* Fix tuple index
* Fix tuple index
* Fix indentation
* Fix variable name
* fix eval before log
* Add drop remainder for test dataset
* Fix step number + fix logging datetime
* fix eval loss value
* use global step instead of step + fix logging at step 0
* Fix logging datetime
* Fix global_step usage
* Fix breaking loop + logging datetime
* Fix step in prediction loop
* Fix step breaking
* Fix train/test loops
* Force TF at least 2.2 for the trainer
* Use assert_cardinality to facilitate the dataset size computation
* Log steps per epoch
* Make tfds compliant with TPU
* Make tfds compliant with TPU
* Use TF dataset enumerate instead of the Python one
* revert previous commit
* Fix data_dir
* Apply style
* rebase on master
* Address Sylvain's comments
* Address Sylvain's and Lysandre comments
* Trigger CI
* Remove unused import
* Switch from return_tuple to return_dict
* Fix test
* Add recent model
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Julien Plu <plu.julien@gmail.com>
This commit is contained in:
@@ -295,8 +295,9 @@ CTRL_INPUTS_DOCSTRING = r"""
|
||||
If set to ``True``, the attentions tensors of all attention layers are returned. See ``attentions`` under returned tensors for more detail.
|
||||
output_hidden_states (:obj:`bool`, `optional`, defaults to :obj:`None`):
|
||||
If set to ``True``, the hidden states of all layers are returned. See ``hidden_states`` under returned tensors for more detail.
|
||||
return_tuple (:obj:`bool`, `optional`, defaults to :obj:`None`):
|
||||
If set to ``True``, the output of the model will be a plain tuple instead of a ``dataclass``.
|
||||
return_dict (:obj:`bool`, `optional`, defaults to :obj:`None`):
|
||||
If set to ``True``, the model will return a :class:`~transformers.file_utils.ModelOutput` instead of a
|
||||
plain tuple.
|
||||
"""
|
||||
|
||||
|
||||
@@ -355,7 +356,7 @@ class CTRLModel(CTRLPreTrainedModel):
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_tuple=None,
|
||||
return_dict=None,
|
||||
**kwargs,
|
||||
):
|
||||
if "past" in kwargs:
|
||||
@@ -371,7 +372,7 @@ class CTRLModel(CTRLPreTrainedModel):
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
@@ -472,7 +473,7 @@ class CTRLModel(CTRLPreTrainedModel):
|
||||
attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
|
||||
all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
|
||||
|
||||
if return_tuple:
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
@@ -526,7 +527,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_tuple=None,
|
||||
return_dict=None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
@@ -544,7 +545,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
|
||||
)
|
||||
past_key_values = kwargs.pop("past")
|
||||
assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
|
||||
return_tuple = return_tuple if return_tuple is not None else self.config.use_return_tuple
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
transformer_outputs = self.transformer(
|
||||
input_ids,
|
||||
@@ -557,7 +558,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_tuple=return_tuple,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
hidden_states = transformer_outputs[0]
|
||||
@@ -573,7 +574,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
||||
|
||||
if return_tuple:
|
||||
if not return_dict:
|
||||
output = (lm_logits,) + transformer_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
|
||||
Reference in New Issue
Block a user