From 16469fedbdfb67c05d8c8018f895dc0e046b1b40 Mon Sep 17 00:00:00 2001 From: Sam Shleifer Date: Thu, 16 Apr 2020 15:02:43 -0400 Subject: [PATCH] [PretrainedTokenizer] Factor out tensor conversion method (#3777) --- src/transformers/tokenization_utils.py | 59 ++++++++++++++------------ 1 file changed, 31 insertions(+), 28 deletions(-) diff --git a/src/transformers/tokenization_utils.py b/src/transformers/tokenization_utils.py index b45cbd2ea0..bdfda4f835 100644 --- a/src/transformers/tokenization_utils.py +++ b/src/transformers/tokenization_utils.py @@ -517,7 +517,7 @@ class PreTrainedTokenizer(SpecialTokensMixin): self.max_len = max_len if max_len is not None else int(1e12) - # Padding side is right by default and over-riden in subclasses. If specified in the kwargs, it is changed. + # Padding side is right by default and overridden in subclasses. If specified in the kwargs, it is changed. self.padding_side = kwargs.pop("padding_side", self.padding_side) self.model_input_names = kwargs.pop("model_input_names", self.model_input_names) @@ -1447,35 +1447,38 @@ class PreTrainedTokenizer(SpecialTokensMixin): if return_tensors is not None: - # Do the tensor conversion in batch - for key, value in batch_outputs.items(): - if return_tensors == "tf" and is_tf_available(): - try: - batch_outputs[key] = tf.constant(value) - except ValueError: - if None in [item for sequence in value for item in sequence]: - raise ValueError(self.NO_PAD_TOKEN_FOR_BATCH_MSG) - else: - raise ValueError(self.UNEVEN_SEQUENCES_FOR_BATCH_MSG) - elif return_tensors == "pt" and is_torch_available(): - try: - batch_outputs[key] = torch.tensor(value) - except ValueError: - raise ValueError(self.UNEVEN_SEQUENCES_FOR_BATCH_MSG) - except RuntimeError: - if None in [item for sequence in value for item in sequence]: - raise ValueError(self.NO_PAD_TOKEN_FOR_BATCH_MSG) - else: - raise - elif return_tensors is not None: - logger.warning( - "Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format( - return_tensors - ) - ) - + self.convert_to_tensors_(batch_outputs, return_tensors) return BatchEncoding(batch_outputs) + def convert_to_tensors_(self, batch_outputs: dict, return_tensors: str) -> None: + # Do the tensor conversion in batch + for key, value in batch_outputs.items(): + if return_tensors == "tf" and is_tf_available(): + try: + batch_outputs[key] = tf.constant(value) + except ValueError: + if None in [item for sequence in value for item in sequence]: + raise ValueError(self.NO_PAD_TOKEN_FOR_BATCH_MSG) + else: + raise ValueError(self.UNEVEN_SEQUENCES_FOR_BATCH_MSG) + elif return_tensors == "pt" and is_torch_available(): + try: + batch_outputs[key] = torch.tensor(value) + except ValueError: + raise ValueError(self.UNEVEN_SEQUENCES_FOR_BATCH_MSG) + except RuntimeError: + if None in [item for sequence in value for item in sequence]: + raise ValueError(self.NO_PAD_TOKEN_FOR_BATCH_MSG) + else: + raise + + elif return_tensors is not None: + logger.warning( + "Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format( + return_tensors + ) + ) + def prepare_for_model( self, ids: List[int],