[Generate] Remove attention_mask and integrate model_main_input_name (#14856)
* up * save * correct * up * correct more * up * up * up * up * up * correct * fix tf * fix * remove tokenizer
This commit is contained in:
committed by
GitHub
parent
86b40073e9
commit
fe4197ab11
@@ -14,6 +14,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import inspect
|
||||||
import warnings
|
import warnings
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
||||||
@@ -349,9 +350,6 @@ BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOu
|
|||||||
BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput]
|
BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput]
|
||||||
|
|
||||||
|
|
||||||
ENCODER_MODEL_INPUT_NAMES = ["input_ids", "inputs_embeds", "input_values", "input_features", "pixel_values"]
|
|
||||||
|
|
||||||
|
|
||||||
class GenerationMixin:
|
class GenerationMixin:
|
||||||
"""
|
"""
|
||||||
A class containing all of the functions supporting generation, to be used as a mixin in
|
A class containing all of the functions supporting generation, to be used as a mixin in
|
||||||
@@ -363,58 +361,69 @@ class GenerationMixin:
|
|||||||
inputs: Optional[torch.Tensor] = None,
|
inputs: Optional[torch.Tensor] = None,
|
||||||
bos_token_id: Optional[int] = None,
|
bos_token_id: Optional[int] = None,
|
||||||
model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
model_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
||||||
) -> Tuple[torch.Tensor, Optional[str]]:
|
) -> Tuple[torch.Tensor, Optional[str], Dict[str, torch.Tensor]]:
|
||||||
"""
|
"""
|
||||||
This function extracts the model-specific `inputs` for generation.
|
This function extracts the model-specific `inputs` for generation.
|
||||||
"""
|
"""
|
||||||
# filter model input names that are `None`
|
# 1. retrieve all kwargs that are non-None or non-model input related.
|
||||||
model_kwargs = {k: v for k, v in model_kwargs.items() if k not in ENCODER_MODEL_INPUT_NAMES or v is not None}
|
# some encoder-decoder models have different names for model and encoder
|
||||||
# extract keyword arguments that are model input specific
|
if (
|
||||||
model_input_kwarg_names = set(ENCODER_MODEL_INPUT_NAMES) & set(model_kwargs.keys())
|
self.config.is_encoder_decoder
|
||||||
|
and hasattr(self, "encoder")
|
||||||
|
and self.encoder.main_input_name != self.main_input_name
|
||||||
|
):
|
||||||
|
input_name = self.encoder.main_input_name
|
||||||
|
else:
|
||||||
|
input_name = self.main_input_name
|
||||||
|
|
||||||
# There are 5 possible scenarios
|
model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name}
|
||||||
if inputs is not None and len(model_input_kwarg_names) == 0:
|
|
||||||
# 1. `inputs` are passed and no model-specific keyword inputs
|
# 2. check whether model_input_name is passed as kwarg
|
||||||
# -> return input
|
# if yes and `inputs` is None use kwarg inputs
|
||||||
model_input_name = None
|
inputs_kwarg = model_kwargs.pop(input_name, None)
|
||||||
return inputs, model_input_name, model_kwargs
|
if inputs_kwarg is not None and inputs is not None:
|
||||||
elif inputs is not None and len(model_input_kwarg_names) > 0:
|
|
||||||
# 2. `inputs` are passed as well as model-specific keyword inputs
|
|
||||||
# -> not allowed, raise Error
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"`inputs`: {inputs}` were passed alongside "
|
f"`inputs`: {inputs}` were passed alongside "
|
||||||
f"{model_input_kwarg_names} which is not allowed."
|
f"{input_name} which is not allowed."
|
||||||
f"Make sure to not pass any of {model_input_kwarg_names} "
|
f"Make sure to either pass {inputs} or {input_name}=..."
|
||||||
"when `inputs` is defined."
|
|
||||||
)
|
)
|
||||||
elif inputs is None and len(model_input_kwarg_names) == 0:
|
elif inputs_kwarg is not None:
|
||||||
# 3. no `inputs` and no model-specific keyword inputs are passed
|
inputs = inputs_kwarg
|
||||||
# -> try to create `input_ids` from BOS
|
|
||||||
input_tensor = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))
|
|
||||||
return input_tensor, "input_ids", model_kwargs
|
|
||||||
elif inputs is None and len(model_input_kwarg_names) == 1:
|
|
||||||
# 4. no `inputs` are passed and exactly one model-specific keyword input
|
|
||||||
# -> return that model-specific keyword input tensor
|
|
||||||
model_input_name = model_input_kwarg_names.pop()
|
|
||||||
input_tensor = model_kwargs.pop(model_input_name)
|
|
||||||
|
|
||||||
# make sure model is encoder decoder if not `input_ids`
|
# 3. models with `input_ids` can also make use of `inputs_embeds`
|
||||||
if not self.config.is_encoder_decoder and model_input_name != "input_ids":
|
if self._can_retrieve_inputs_from_name(inputs, "inputs_embeds", model_kwargs):
|
||||||
raise ValueError(
|
inputs, input_name = model_kwargs["inputs_embeds"], "inputs_embeds"
|
||||||
f"If {model_input_name} is passed as model-specific keyword "
|
|
||||||
"input then model has to be an encoder-decoder and not a "
|
# 4. Only encoder-decoder models can have non `input_ids` input format
|
||||||
f"{self.__class__.__name__}."
|
if not self.config.is_encoder_decoder and input_name != "input_ids":
|
||||||
)
|
|
||||||
return input_tensor, model_input_name, model_kwargs
|
|
||||||
else:
|
|
||||||
# 5. no `inputs` are passed and multiple model-specific keyword inputs
|
|
||||||
# -> not allowed, raise Error
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Can only pass one of {ENCODER_MODEL_INPUT_NAMES}, "
|
f"If {input_name} is passed as model-specific keyword "
|
||||||
f"but passed {model_input_kwarg_names}."
|
"input then model has to be an encoder-decoder and not a "
|
||||||
f"Make sure to only pass one of {model_input_kwarg_names}."
|
f"{self.__class__.__name__}."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 5. if `inputs` is still None, try to create `input_ids` from BOS token
|
||||||
|
if inputs is None:
|
||||||
|
inputs = self._prepare_input_ids_for_generation(bos_token_id, model_kwargs.get("encoder_outputs"))
|
||||||
|
|
||||||
|
return inputs, input_name, model_kwargs
|
||||||
|
|
||||||
|
def _can_retrieve_inputs_from_name(
|
||||||
|
self, inputs: Optional[torch.Tensor], name: str, model_kwargs: Dict[str, torch.Tensor]
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
If `inputs` is None and `name` is in both forward function and keyword
|
||||||
|
arguments, then inputs can be retrieved from name
|
||||||
|
"""
|
||||||
|
can_retrieve_inputs = model_kwargs.get(name, None) is not None and name in set(
|
||||||
|
inspect.signature(self.forward).parameters.keys()
|
||||||
|
)
|
||||||
|
|
||||||
|
if can_retrieve_inputs and inputs is not None:
|
||||||
|
raise ValueError(f"Cannot only pass one of {name} and {self.main_input_name}")
|
||||||
|
|
||||||
|
return can_retrieve_inputs
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]:
|
def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
Implement in subclasses of [`PreTrainedModel`] for custom behavior to prepare inputs in the
|
Implement in subclasses of [`PreTrainedModel`] for custom behavior to prepare inputs in the
|
||||||
@@ -461,29 +470,22 @@ class GenerationMixin:
|
|||||||
def _prepare_encoder_decoder_kwargs_for_generation(
|
def _prepare_encoder_decoder_kwargs_for_generation(
|
||||||
self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
|
self, inputs_tensor: torch.Tensor, model_kwargs, model_input_name: Optional[str] = None
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
if "encoder_outputs" not in model_kwargs:
|
# 1. get encoder
|
||||||
# 1. get encoder
|
encoder = self.get_encoder()
|
||||||
encoder = self.get_encoder()
|
|
||||||
# 2. prepare encoder args and encoder kwargs from model kwargs
|
|
||||||
encoder_args = (inputs_tensor,)
|
|
||||||
irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
|
|
||||||
encoder_kwargs = {
|
|
||||||
argument: value
|
|
||||||
for argument, value in model_kwargs.items()
|
|
||||||
if not any(argument.startswith(p) for p in irrelevant_prefix)
|
|
||||||
}
|
|
||||||
# 3. make sure that encoder returns `ModelOutput`
|
|
||||||
encoder_kwargs["return_dict"] = True
|
|
||||||
|
|
||||||
# 4. if model_input_name is not defined then pass input_tensor as
|
# 2. prepare encoder args and encoder kwargs from model kwargs
|
||||||
# first input argument and remove from args
|
irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
|
||||||
if model_input_name is not None:
|
encoder_kwargs = {
|
||||||
# make sure inputs_tensor is None in case model
|
argument: value
|
||||||
# accepts multiple model input arguments
|
for argument, value in model_kwargs.items()
|
||||||
encoder_kwargs[model_input_name] = inputs_tensor
|
if not any(argument.startswith(p) for p in irrelevant_prefix)
|
||||||
encoder_args = ()
|
}
|
||||||
|
|
||||||
model_kwargs["encoder_outputs"]: ModelOutput = encoder(*encoder_args, **encoder_kwargs)
|
# 3. make sure that encoder returns `ModelOutput`
|
||||||
|
model_input_name = model_input_name if model_input_name is not None else self.main_input_name
|
||||||
|
encoder_kwargs["return_dict"] = True
|
||||||
|
encoder_kwargs[model_input_name] = inputs_tensor
|
||||||
|
model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs)
|
||||||
|
|
||||||
return model_kwargs
|
return model_kwargs
|
||||||
|
|
||||||
@@ -1013,12 +1015,13 @@ class GenerationMixin:
|
|||||||
model_kwargs["output_hidden_states"] = output_hidden_states
|
model_kwargs["output_hidden_states"] = output_hidden_states
|
||||||
model_kwargs["use_cache"] = use_cache
|
model_kwargs["use_cache"] = use_cache
|
||||||
|
|
||||||
if model_kwargs.get("attention_mask", None) is None:
|
has_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
|
||||||
|
if model_kwargs.get("attention_mask", None) is None and has_attention_mask:
|
||||||
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
|
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
|
||||||
inputs_tensor, pad_token_id, eos_token_id
|
inputs_tensor, pad_token_id, eos_token_id
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.config.is_encoder_decoder:
|
if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
|
||||||
# if model is encoder decoder encoder_outputs are created
|
# if model is encoder decoder encoder_outputs are created
|
||||||
# and added to `model_kwargs`
|
# and added to `model_kwargs`
|
||||||
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
|
model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
|
||||||
|
|||||||
@@ -57,8 +57,6 @@ class KerasMetricCallback(Callback):
|
|||||||
Validation data to be used to generate predictions for the `metric_fn`.
|
Validation data to be used to generate predictions for the `metric_fn`.
|
||||||
metric_fn_kwargs (`dict`, *optional*):
|
metric_fn_kwargs (`dict`, *optional*):
|
||||||
Additional keyword arguments to be passed to the metric_fn.
|
Additional keyword arguments to be passed to the metric_fn.
|
||||||
tokenizer ([`PretrainedTokenizerBase`], *optional*):
|
|
||||||
Tokenizer used to validate column names to be passed to the generate() function.
|
|
||||||
output_cols (`List[str], *optional*):
|
output_cols (`List[str], *optional*):
|
||||||
A list of columns to be retained from the model output as the predictions. Defaults to all.
|
A list of columns to be retained from the model output as the predictions. Defaults to all.
|
||||||
label_cols ('`List[str]`, *optional*'):
|
label_cols ('`List[str]`, *optional*'):
|
||||||
@@ -75,7 +73,6 @@ class KerasMetricCallback(Callback):
|
|||||||
self,
|
self,
|
||||||
metric_fn: Callable,
|
metric_fn: Callable,
|
||||||
eval_dataset: Union[tf.data.Dataset, np.ndarray, tf.Tensor, tuple, dict],
|
eval_dataset: Union[tf.data.Dataset, np.ndarray, tf.Tensor, tuple, dict],
|
||||||
tokenizer: Optional[PreTrainedTokenizerBase] = None,
|
|
||||||
metric_fn_kwargs: Optional[dict] = None,
|
metric_fn_kwargs: Optional[dict] = None,
|
||||||
output_cols: Optional[List[str]] = None,
|
output_cols: Optional[List[str]] = None,
|
||||||
label_cols: Optional[List[str]] = None,
|
label_cols: Optional[List[str]] = None,
|
||||||
@@ -97,10 +94,11 @@ class KerasMetricCallback(Callback):
|
|||||||
self.predict_with_generate = predict_with_generate
|
self.predict_with_generate = predict_with_generate
|
||||||
self.output_cols = output_cols
|
self.output_cols = output_cols
|
||||||
self.metric_fn_kwargs = metric_fn_kwargs or dict()
|
self.metric_fn_kwargs = metric_fn_kwargs or dict()
|
||||||
if tokenizer is not None:
|
|
||||||
self.model_input_names = tokenizer.model_input_names
|
if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name:
|
||||||
|
self.main_input_name = self.model.encoder.main_input_name
|
||||||
else:
|
else:
|
||||||
self.model_input_names = ["input_ids"]
|
self.main_input_name = self.model.main_input_name
|
||||||
|
|
||||||
# This next block attempts to parse out which elements of the dataset should be appended to the labels list
|
# This next block attempts to parse out which elements of the dataset should be appended to the labels list
|
||||||
# that is passed to the metric_fn
|
# that is passed to the metric_fn
|
||||||
@@ -161,9 +159,13 @@ class KerasMetricCallback(Callback):
|
|||||||
labels = None
|
labels = None
|
||||||
if self.predict_with_generate:
|
if self.predict_with_generate:
|
||||||
if isinstance(batch, dict):
|
if isinstance(batch, dict):
|
||||||
# generate() gets stressed out by any unexpected keys
|
generation_inputs = batch[self.main_input_name]
|
||||||
batch = {key: array for key, array in batch.items() if key in self.model_input_names}
|
attention_mask = batch.get("attention_mask", None)
|
||||||
predictions = self.model.generate(batch)
|
else:
|
||||||
|
generation_inputs = batch
|
||||||
|
attention_mask = None
|
||||||
|
|
||||||
|
predictions = self.model.generate(generation_inputs, attention_mask=attention_mask)
|
||||||
else:
|
else:
|
||||||
predictions = self.model.predict(batch)
|
predictions = self.model.predict(batch)
|
||||||
predictions = dict(predictions)
|
predictions = dict(predictions)
|
||||||
|
|||||||
@@ -478,7 +478,6 @@ class DeiTModel(DeiTPreTrainedModel):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
pixel_values=None,
|
pixel_values=None,
|
||||||
attention_mask=None,
|
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
|
|||||||
@@ -69,19 +69,11 @@ SPEECH_ENCODER_DECODER_START_DOCSTRING = r"""
|
|||||||
|
|
||||||
SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
|
SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
|
||||||
Args:
|
Args:
|
||||||
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
inputs (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, feature_dim)`, *optional*):
|
||||||
Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file
|
Float values of input raw speech waveform or speech features. Values can be obtained by loading a *.flac* or *.wav* audio file
|
||||||
into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile library (*pip install
|
into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile library (*pip install
|
||||||
soundfile*). To prepare the array into *input_values*, the [`Wav2Vec2Processor`] should
|
soundfile*). To prepare the array into *inputs*, either the [`Wav2Vec2Processor`] or [`Speech2TextProcessor`] should
|
||||||
be used for padding and conversion into a tensor of type *torch.FloatTensor*. See
|
be used for padding and conversion into a tensor of type *torch.FloatTensor*.
|
||||||
[`Wav2Vec2Processor.__call__`] for details.
|
|
||||||
input_features (`torch.LongTensor` of shape `(batch_size, sequence_length, feature_size)`, *optional*):
|
|
||||||
Float values of fbank features extracted from the raw speech waveform. Raw speech waveform can be obtained
|
|
||||||
by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a
|
|
||||||
`numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array
|
|
||||||
into `input_features`, the [`Speech2TextTokenizer`] should be used for extracting
|
|
||||||
the fbank features, padding and conversion into a tensor of type `torch.FloatTensor`. See
|
|
||||||
[`~Speech2TextTokenizer.__call__`]
|
|
||||||
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||||
|
|
||||||
@@ -137,6 +129,19 @@ SPEECH_ENCODER_DECODER_INPUTS_DOCSTRING = r"""
|
|||||||
output_hidden_states (`bool`, *optional*):
|
output_hidden_states (`bool`, *optional*):
|
||||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
||||||
more detail.
|
more detail.
|
||||||
|
input_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
|
Float values of input raw speech waveform. Values can be obtained by loading a *.flac* or *.wav* audio file
|
||||||
|
into an array of type *List[float]* or a *numpy.ndarray*, *e.g.* via the soundfile library (*pip install
|
||||||
|
soundfile*). To prepare the array into *input_values*, the [`Wav2Vec2Processor`] should
|
||||||
|
be used for padding and conversion into a tensor of type *torch.FloatTensor*. See
|
||||||
|
[`Wav2Vec2Processor.__call__`] for details.
|
||||||
|
input_features (`torch.LongTensor` of shape `(batch_size, sequence_length, feature_size)`, *optional*):
|
||||||
|
Float values of fbank features extracted from the raw speech waveform. Raw speech waveform can be obtained
|
||||||
|
by loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a
|
||||||
|
`numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array
|
||||||
|
into `input_features`, the [`Speech2TextTokenizer`] should be used for extracting
|
||||||
|
the fbank features, padding and conversion into a tensor of type `torch.FloatTensor`. See
|
||||||
|
[`~Speech2TextTokenizer.__call__`]
|
||||||
return_dict (`bool`, *optional*):
|
return_dict (`bool`, *optional*):
|
||||||
If set to `True`, the model will return a [`~file_utils.Seq2SeqLMOutput`] instead of a
|
If set to `True`, the model will return a [`~file_utils.Seq2SeqLMOutput`] instead of a
|
||||||
plain tuple.
|
plain tuple.
|
||||||
@@ -176,7 +181,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
|
|||||||
"""
|
"""
|
||||||
config_class = SpeechEncoderDecoderConfig
|
config_class = SpeechEncoderDecoderConfig
|
||||||
base_model_prefix = "speech_encoder_decoder"
|
base_model_prefix = "speech_encoder_decoder"
|
||||||
main_input_name = "input_values"
|
main_input_name = "inputs"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -417,8 +422,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
|
|||||||
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_values=None,
|
inputs=None,
|
||||||
input_features=None,
|
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
decoder_input_ids=None,
|
decoder_input_ids=None,
|
||||||
decoder_attention_mask=None,
|
decoder_attention_mask=None,
|
||||||
@@ -429,6 +433,8 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
|
|||||||
use_cache=None,
|
use_cache=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
|
input_values=None,
|
||||||
|
input_features=None,
|
||||||
return_dict=None,
|
return_dict=None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
@@ -463,7 +469,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
|
|||||||
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
|
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
|
||||||
}
|
}
|
||||||
|
|
||||||
if encoder_outputs is None:
|
if encoder_outputs is None and inputs is None:
|
||||||
if input_values is not None and input_features is not None:
|
if input_values is not None and input_features is not None:
|
||||||
raise ValueError("You cannot specify both input_values and input_features at the same time")
|
raise ValueError("You cannot specify both input_values and input_features at the same time")
|
||||||
elif input_values is not None:
|
elif input_values is not None:
|
||||||
|
|||||||
@@ -507,7 +507,6 @@ class ViTModel(ViTPreTrainedModel):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
pixel_values=None,
|
pixel_values=None,
|
||||||
attention_mask=None,
|
|
||||||
head_mask=None,
|
head_mask=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
|
|||||||
@@ -161,11 +161,17 @@ class Seq2SeqTrainer(Trainer):
|
|||||||
"synced_gpus": True if is_deepspeed_zero3_enabled() else False,
|
"synced_gpus": True if is_deepspeed_zero3_enabled() else False,
|
||||||
}
|
}
|
||||||
|
|
||||||
model_input_names = self.tokenizer.model_input_names if self.tokenizer is not None else ["input_ids"]
|
# prepare generation inputs
|
||||||
generation_inputs = {k: v for k, v in inputs.items() if k in model_input_names}
|
# some encoder-decoder models can have varying encder's and thus
|
||||||
|
# varying model input names
|
||||||
|
if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name:
|
||||||
|
generation_inputs = inputs[self.model.encoder.main_input_name]
|
||||||
|
else:
|
||||||
|
generation_inputs = inputs[self.model.main_input_name]
|
||||||
|
|
||||||
generated_tokens = self.model.generate(
|
generated_tokens = self.model.generate(
|
||||||
**generation_inputs,
|
generation_inputs,
|
||||||
|
attention_mask=inputs.get("attention_mask", None),
|
||||||
**gen_kwargs,
|
**gen_kwargs,
|
||||||
)
|
)
|
||||||
# in case the batch is shorter than max length, the output should be padded
|
# in case the batch is shorter than max length, the output should be padded
|
||||||
|
|||||||
@@ -1856,7 +1856,7 @@ class GenerationIntegrationTests(unittest.TestCase):
|
|||||||
model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2", max_length=10).to(torch_device)
|
model = GPT2LMHeadModel.from_pretrained("hf-internal-testing/tiny-random-gpt2", max_length=10).to(torch_device)
|
||||||
input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
input_ids = tokenizer(article, return_tensors="pt").input_ids.to(torch_device)
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
model.generate(input_ids=input_ids, input_values=input_ids)
|
model.generate(input_ids=input_ids, inputs_embeds=input_ids)
|
||||||
|
|
||||||
def test_generate_input_values_as_encoder_kwarg(self):
|
def test_generate_input_values_as_encoder_kwarg(self):
|
||||||
input_values = floats_tensor((2, 250))
|
input_values = floats_tensor((2, 250))
|
||||||
|
|||||||
@@ -64,14 +64,7 @@ class EncoderDecoderMixin:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
def check_encoder_decoder_model_from_pretrained_configs(
|
def check_encoder_decoder_model_from_pretrained_configs(
|
||||||
self,
|
self, config, decoder_config, decoder_input_ids, decoder_attention_mask, pixel_values=None, **kwargs
|
||||||
config,
|
|
||||||
attention_mask,
|
|
||||||
decoder_config,
|
|
||||||
decoder_input_ids,
|
|
||||||
decoder_attention_mask,
|
|
||||||
pixel_values=None,
|
|
||||||
**kwargs
|
|
||||||
):
|
):
|
||||||
encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
|
encoder_decoder_config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config, decoder_config)
|
||||||
self.assertTrue(encoder_decoder_config.decoder.is_decoder)
|
self.assertTrue(encoder_decoder_config.decoder.is_decoder)
|
||||||
@@ -84,7 +77,6 @@ class EncoderDecoderMixin:
|
|||||||
|
|
||||||
outputs_encoder_decoder = enc_dec_model(
|
outputs_encoder_decoder = enc_dec_model(
|
||||||
pixel_values=pixel_values,
|
pixel_values=pixel_values,
|
||||||
attention_mask=attention_mask,
|
|
||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
)
|
)
|
||||||
@@ -94,14 +86,7 @@ class EncoderDecoderMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def check_encoder_decoder_model(
|
def check_encoder_decoder_model(
|
||||||
self,
|
self, config, decoder_config, decoder_input_ids, decoder_attention_mask, pixel_values=None, **kwargs
|
||||||
config,
|
|
||||||
attention_mask,
|
|
||||||
decoder_config,
|
|
||||||
decoder_input_ids,
|
|
||||||
decoder_attention_mask,
|
|
||||||
pixel_values=None,
|
|
||||||
**kwargs
|
|
||||||
):
|
):
|
||||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||||
enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||||
@@ -111,7 +96,6 @@ class EncoderDecoderMixin:
|
|||||||
enc_dec_model.to(torch_device)
|
enc_dec_model.to(torch_device)
|
||||||
outputs_encoder_decoder = enc_dec_model(
|
outputs_encoder_decoder = enc_dec_model(
|
||||||
pixel_values=pixel_values,
|
pixel_values=pixel_values,
|
||||||
attention_mask=attention_mask,
|
|
||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
@@ -122,7 +106,6 @@ class EncoderDecoderMixin:
|
|||||||
encoder_outputs = BaseModelOutput(last_hidden_state=outputs_encoder_decoder.encoder_hidden_states[-1])
|
encoder_outputs = BaseModelOutput(last_hidden_state=outputs_encoder_decoder.encoder_hidden_states[-1])
|
||||||
outputs_encoder_decoder = enc_dec_model(
|
outputs_encoder_decoder = enc_dec_model(
|
||||||
encoder_outputs=encoder_outputs,
|
encoder_outputs=encoder_outputs,
|
||||||
attention_mask=attention_mask,
|
|
||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
)
|
)
|
||||||
@@ -134,7 +117,6 @@ class EncoderDecoderMixin:
|
|||||||
def check_encoder_decoder_model_from_pretrained(
|
def check_encoder_decoder_model_from_pretrained(
|
||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
attention_mask,
|
|
||||||
decoder_config,
|
decoder_config,
|
||||||
decoder_input_ids,
|
decoder_input_ids,
|
||||||
decoder_attention_mask,
|
decoder_attention_mask,
|
||||||
@@ -148,7 +130,6 @@ class EncoderDecoderMixin:
|
|||||||
enc_dec_model.to(torch_device)
|
enc_dec_model.to(torch_device)
|
||||||
outputs_encoder_decoder = enc_dec_model(
|
outputs_encoder_decoder = enc_dec_model(
|
||||||
pixel_values=pixel_values,
|
pixel_values=pixel_values,
|
||||||
attention_mask=attention_mask,
|
|
||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
output_hidden_states=True,
|
output_hidden_states=True,
|
||||||
@@ -160,14 +141,7 @@ class EncoderDecoderMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def check_save_and_load(
|
def check_save_and_load(
|
||||||
self,
|
self, config, decoder_config, decoder_input_ids, decoder_attention_mask, pixel_values=None, **kwargs
|
||||||
config,
|
|
||||||
attention_mask,
|
|
||||||
decoder_config,
|
|
||||||
decoder_input_ids,
|
|
||||||
decoder_attention_mask,
|
|
||||||
pixel_values=None,
|
|
||||||
**kwargs
|
|
||||||
):
|
):
|
||||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||||
enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||||
@@ -176,7 +150,6 @@ class EncoderDecoderMixin:
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = enc_dec_model(
|
outputs = enc_dec_model(
|
||||||
pixel_values=pixel_values,
|
pixel_values=pixel_values,
|
||||||
attention_mask=attention_mask,
|
|
||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
)
|
)
|
||||||
@@ -190,7 +163,6 @@ class EncoderDecoderMixin:
|
|||||||
|
|
||||||
after_outputs = enc_dec_model(
|
after_outputs = enc_dec_model(
|
||||||
pixel_values=pixel_values,
|
pixel_values=pixel_values,
|
||||||
attention_mask=attention_mask,
|
|
||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
)
|
)
|
||||||
@@ -200,14 +172,7 @@ class EncoderDecoderMixin:
|
|||||||
self.assertLessEqual(max_diff, 1e-5)
|
self.assertLessEqual(max_diff, 1e-5)
|
||||||
|
|
||||||
def check_save_and_load_encoder_decoder_model(
|
def check_save_and_load_encoder_decoder_model(
|
||||||
self,
|
self, config, decoder_config, decoder_input_ids, decoder_attention_mask, pixel_values=None, **kwargs
|
||||||
config,
|
|
||||||
attention_mask,
|
|
||||||
decoder_config,
|
|
||||||
decoder_input_ids,
|
|
||||||
decoder_attention_mask,
|
|
||||||
pixel_values=None,
|
|
||||||
**kwargs
|
|
||||||
):
|
):
|
||||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||||
enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
enc_dec_model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||||
@@ -216,7 +181,6 @@ class EncoderDecoderMixin:
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
outputs = enc_dec_model(
|
outputs = enc_dec_model(
|
||||||
pixel_values=pixel_values,
|
pixel_values=pixel_values,
|
||||||
attention_mask=attention_mask,
|
|
||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
)
|
)
|
||||||
@@ -233,7 +197,6 @@ class EncoderDecoderMixin:
|
|||||||
|
|
||||||
after_outputs = enc_dec_model(
|
after_outputs = enc_dec_model(
|
||||||
pixel_values=pixel_values,
|
pixel_values=pixel_values,
|
||||||
attention_mask=attention_mask,
|
|
||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
)
|
)
|
||||||
@@ -245,7 +208,6 @@ class EncoderDecoderMixin:
|
|||||||
def check_encoder_decoder_model_output_attentions(
|
def check_encoder_decoder_model_output_attentions(
|
||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
attention_mask,
|
|
||||||
decoder_config,
|
decoder_config,
|
||||||
decoder_input_ids,
|
decoder_input_ids,
|
||||||
decoder_attention_mask,
|
decoder_attention_mask,
|
||||||
@@ -261,7 +223,6 @@ class EncoderDecoderMixin:
|
|||||||
enc_dec_model.to(torch_device)
|
enc_dec_model.to(torch_device)
|
||||||
outputs_encoder_decoder = enc_dec_model(
|
outputs_encoder_decoder = enc_dec_model(
|
||||||
pixel_values=pixel_values,
|
pixel_values=pixel_values,
|
||||||
attention_mask=attention_mask,
|
|
||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
output_attentions=True,
|
output_attentions=True,
|
||||||
@@ -382,13 +343,10 @@ class DeiT2RobertaModelTest(EncoderDecoderMixin, unittest.TestCase):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
# for DEiT, the sequence length is equal to the number of patches + 2 (for the [CLS] and distillation tokens)
|
# for DEiT, the sequence length is equal to the number of patches + 2 (for the [CLS] and distillation tokens)
|
||||||
seq_len = (model.encoder.config.image_size // model.encoder.config.patch_size) ** 2 + 2
|
|
||||||
attention_mask = random_attention_mask([batch_size, seq_len])
|
|
||||||
decoder_input_ids = ids_tensor([batch_size, 4], model.decoder.config.vocab_size)
|
decoder_input_ids = ids_tensor([batch_size, 4], model.decoder.config.vocab_size)
|
||||||
decoder_attention_mask = random_attention_mask([batch_size, 4])
|
decoder_attention_mask = random_attention_mask([batch_size, 4])
|
||||||
inputs = {
|
inputs = {
|
||||||
"pixel_values": pixel_values,
|
"pixel_values": pixel_values,
|
||||||
"attention_mask": attention_mask,
|
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"decoder_attention_mask": decoder_attention_mask,
|
"decoder_attention_mask": decoder_attention_mask,
|
||||||
}
|
}
|
||||||
@@ -398,7 +356,6 @@ class DeiT2RobertaModelTest(EncoderDecoderMixin, unittest.TestCase):
|
|||||||
def check_encoder_decoder_model_output_attentions(
|
def check_encoder_decoder_model_output_attentions(
|
||||||
self,
|
self,
|
||||||
config,
|
config,
|
||||||
attention_mask,
|
|
||||||
decoder_config,
|
decoder_config,
|
||||||
decoder_input_ids,
|
decoder_input_ids,
|
||||||
decoder_attention_mask,
|
decoder_attention_mask,
|
||||||
@@ -414,7 +371,6 @@ class DeiT2RobertaModelTest(EncoderDecoderMixin, unittest.TestCase):
|
|||||||
enc_dec_model.to(torch_device)
|
enc_dec_model.to(torch_device)
|
||||||
outputs_encoder_decoder = enc_dec_model(
|
outputs_encoder_decoder = enc_dec_model(
|
||||||
pixel_values=pixel_values,
|
pixel_values=pixel_values,
|
||||||
attention_mask=attention_mask,
|
|
||||||
decoder_input_ids=decoder_input_ids,
|
decoder_input_ids=decoder_input_ids,
|
||||||
decoder_attention_mask=decoder_attention_mask,
|
decoder_attention_mask=decoder_attention_mask,
|
||||||
output_attentions=True,
|
output_attentions=True,
|
||||||
@@ -463,7 +419,6 @@ class DeiT2RobertaModelTest(EncoderDecoderMixin, unittest.TestCase):
|
|||||||
encoder_config_and_inputs = deit_model_tester.prepare_config_and_inputs()
|
encoder_config_and_inputs = deit_model_tester.prepare_config_and_inputs()
|
||||||
decoder_config_and_inputs = bert_model_tester.prepare_config_and_inputs_for_decoder()
|
decoder_config_and_inputs = bert_model_tester.prepare_config_and_inputs_for_decoder()
|
||||||
config, pixel_values, _ = encoder_config_and_inputs
|
config, pixel_values, _ = encoder_config_and_inputs
|
||||||
input_mask = None # TODO add once attention_mask is supported for vision models
|
|
||||||
(
|
(
|
||||||
decoder_config,
|
decoder_config,
|
||||||
decoder_input_ids,
|
decoder_input_ids,
|
||||||
@@ -481,7 +436,6 @@ class DeiT2RobertaModelTest(EncoderDecoderMixin, unittest.TestCase):
|
|||||||
return {
|
return {
|
||||||
"config": config,
|
"config": config,
|
||||||
"pixel_values": pixel_values,
|
"pixel_values": pixel_values,
|
||||||
"attention_mask": input_mask,
|
|
||||||
"decoder_config": decoder_config,
|
"decoder_config": decoder_config,
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"decoder_token_type_ids": decoder_token_type_ids,
|
"decoder_token_type_ids": decoder_token_type_ids,
|
||||||
@@ -509,13 +463,10 @@ class ViT2BertModelTest(EncoderDecoderMixin, unittest.TestCase):
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
# for ViT, the sequence length is equal to the number of patches + 1 (for the [CLS] token)
|
# for ViT, the sequence length is equal to the number of patches + 1 (for the [CLS] token)
|
||||||
seq_len = (model.encoder.config.image_size // model.encoder.config.patch_size) ** 2 + 1
|
|
||||||
attention_mask = random_attention_mask([batch_size, seq_len])
|
|
||||||
decoder_input_ids = ids_tensor([batch_size, 4], model.decoder.config.vocab_size)
|
decoder_input_ids = ids_tensor([batch_size, 4], model.decoder.config.vocab_size)
|
||||||
decoder_attention_mask = random_attention_mask([batch_size, 4])
|
decoder_attention_mask = random_attention_mask([batch_size, 4])
|
||||||
inputs = {
|
inputs = {
|
||||||
"pixel_values": pixel_values,
|
"pixel_values": pixel_values,
|
||||||
"attention_mask": attention_mask,
|
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"decoder_attention_mask": decoder_attention_mask,
|
"decoder_attention_mask": decoder_attention_mask,
|
||||||
}
|
}
|
||||||
@@ -534,7 +485,6 @@ class ViT2BertModelTest(EncoderDecoderMixin, unittest.TestCase):
|
|||||||
decoder_config_and_inputs = bert_model_tester.prepare_config_and_inputs_for_decoder()
|
decoder_config_and_inputs = bert_model_tester.prepare_config_and_inputs_for_decoder()
|
||||||
|
|
||||||
config, pixel_values, _ = encoder_config_and_inputs
|
config, pixel_values, _ = encoder_config_and_inputs
|
||||||
input_mask = None # TODO add once attention_mask is supported for vision models
|
|
||||||
|
|
||||||
(
|
(
|
||||||
decoder_config,
|
decoder_config,
|
||||||
@@ -553,7 +503,6 @@ class ViT2BertModelTest(EncoderDecoderMixin, unittest.TestCase):
|
|||||||
return {
|
return {
|
||||||
"config": config,
|
"config": config,
|
||||||
"pixel_values": pixel_values,
|
"pixel_values": pixel_values,
|
||||||
"attention_mask": input_mask,
|
|
||||||
"decoder_config": decoder_config,
|
"decoder_config": decoder_config,
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"decoder_token_type_ids": decoder_token_type_ids,
|
"decoder_token_type_ids": decoder_token_type_ids,
|
||||||
@@ -580,7 +529,6 @@ class ViT2TrOCR(EncoderDecoderMixin, unittest.TestCase):
|
|||||||
encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
|
encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
|
||||||
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs()
|
decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs()
|
||||||
config, pixel_values, _ = encoder_config_and_inputs
|
config, pixel_values, _ = encoder_config_and_inputs
|
||||||
input_mask = None # TODO add once attention_mask is supported for vision models
|
|
||||||
(decoder_config, decoder_input_ids, decoder_attention_mask, _) = decoder_config_and_inputs
|
(decoder_config, decoder_input_ids, decoder_attention_mask, _) = decoder_config_and_inputs
|
||||||
|
|
||||||
# make sure that cross attention layers are added
|
# make sure that cross attention layers are added
|
||||||
@@ -590,7 +538,6 @@ class ViT2TrOCR(EncoderDecoderMixin, unittest.TestCase):
|
|||||||
return {
|
return {
|
||||||
"config": config,
|
"config": config,
|
||||||
"pixel_values": pixel_values,
|
"pixel_values": pixel_values,
|
||||||
"attention_mask": input_mask,
|
|
||||||
"decoder_config": decoder_config,
|
"decoder_config": decoder_config,
|
||||||
"decoder_input_ids": decoder_input_ids,
|
"decoder_input_ids": decoder_input_ids,
|
||||||
"decoder_attention_mask": decoder_attention_mask,
|
"decoder_attention_mask": decoder_attention_mask,
|
||||||
|
|||||||
Reference in New Issue
Block a user