[PretrainedTokenizer] Factor out tensor conversion method (#3777)
This commit is contained in:
@@ -517,7 +517,7 @@ class PreTrainedTokenizer(SpecialTokensMixin):
|
|||||||
|
|
||||||
self.max_len = max_len if max_len is not None else int(1e12)
|
self.max_len = max_len if max_len is not None else int(1e12)
|
||||||
|
|
||||||
# Padding side is right by default and over-riden in subclasses. If specified in the kwargs, it is changed.
|
# Padding side is right by default and overridden in subclasses. If specified in the kwargs, it is changed.
|
||||||
self.padding_side = kwargs.pop("padding_side", self.padding_side)
|
self.padding_side = kwargs.pop("padding_side", self.padding_side)
|
||||||
self.model_input_names = kwargs.pop("model_input_names", self.model_input_names)
|
self.model_input_names = kwargs.pop("model_input_names", self.model_input_names)
|
||||||
|
|
||||||
@@ -1447,35 +1447,38 @@ class PreTrainedTokenizer(SpecialTokensMixin):
|
|||||||
|
|
||||||
if return_tensors is not None:
|
if return_tensors is not None:
|
||||||
|
|
||||||
# Do the tensor conversion in batch
|
self.convert_to_tensors_(batch_outputs, return_tensors)
|
||||||
for key, value in batch_outputs.items():
|
|
||||||
if return_tensors == "tf" and is_tf_available():
|
|
||||||
try:
|
|
||||||
batch_outputs[key] = tf.constant(value)
|
|
||||||
except ValueError:
|
|
||||||
if None in [item for sequence in value for item in sequence]:
|
|
||||||
raise ValueError(self.NO_PAD_TOKEN_FOR_BATCH_MSG)
|
|
||||||
else:
|
|
||||||
raise ValueError(self.UNEVEN_SEQUENCES_FOR_BATCH_MSG)
|
|
||||||
elif return_tensors == "pt" and is_torch_available():
|
|
||||||
try:
|
|
||||||
batch_outputs[key] = torch.tensor(value)
|
|
||||||
except ValueError:
|
|
||||||
raise ValueError(self.UNEVEN_SEQUENCES_FOR_BATCH_MSG)
|
|
||||||
except RuntimeError:
|
|
||||||
if None in [item for sequence in value for item in sequence]:
|
|
||||||
raise ValueError(self.NO_PAD_TOKEN_FOR_BATCH_MSG)
|
|
||||||
else:
|
|
||||||
raise
|
|
||||||
elif return_tensors is not None:
|
|
||||||
logger.warning(
|
|
||||||
"Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format(
|
|
||||||
return_tensors
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return BatchEncoding(batch_outputs)
|
return BatchEncoding(batch_outputs)
|
||||||
|
|
||||||
|
def convert_to_tensors_(self, batch_outputs: dict, return_tensors: str) -> None:
    """Convert every value of ``batch_outputs`` to a framework tensor, in place.

    Args:
        batch_outputs: Mapping whose values are replaced by tensors. Each value
            is iterated as a sequence of sequences when diagnosing a failed
            conversion.
        return_tensors: ``"tf"`` to build values with ``tf.constant``, ``"pt"``
            to build them with ``torch.tensor``. ``None`` leaves
            ``batch_outputs`` untouched.

    Returns:
        None. ``batch_outputs`` is mutated; nothing is returned.

    Raises:
        ValueError: With ``self.NO_PAD_TOKEN_FOR_BATCH_MSG`` when conversion
            fails and a ``None`` item is found in the value, or with
            ``self.UNEVEN_SEQUENCES_FOR_BATCH_MSG`` for the other conversion
            failures handled here.
        RuntimeError: Re-raised unchanged when ``torch.tensor`` fails with a
            ``RuntimeError`` and the value contains no ``None`` item.
    """
    if return_tensors is None:
        return

    # The target framework does not depend on the key, so resolve it once
    # instead of re-checking availability on every loop iteration.
    use_tf = return_tensors == "tf" and is_tf_available()
    use_pt = return_tensors == "pt" and is_torch_available()

    if not (use_tf or use_pt):
        # Warn a single time per call rather than once per key.
        logger.warning(
            "Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format(
                return_tensors
            )
        )
        return

    def contains_none(nested) -> bool:
        # A None item presumably marks a slot that could not be padded
        # (see NO_PAD_TOKEN_FOR_BATCH_MSG) -- TODO confirm against the caller.
        return any(item is None for sequence in nested for item in sequence)

    # Only existing keys are reassigned, so the dict size never changes while
    # it is being iterated.
    for key, value in batch_outputs.items():
        if use_tf:
            try:
                batch_outputs[key] = tf.constant(value)
            except ValueError as err:
                if contains_none(value):
                    raise ValueError(self.NO_PAD_TOKEN_FOR_BATCH_MSG) from err
                raise ValueError(self.UNEVEN_SEQUENCES_FOR_BATCH_MSG) from err
        else:
            try:
                batch_outputs[key] = torch.tensor(value)
            except ValueError as err:
                raise ValueError(self.UNEVEN_SEQUENCES_FOR_BATCH_MSG) from err
            except RuntimeError as err:
                if contains_none(value):
                    raise ValueError(self.NO_PAD_TOKEN_FOR_BATCH_MSG) from err
                raise
|
||||||
|
|
||||||
def prepare_for_model(
|
def prepare_for_model(
|
||||||
self,
|
self,
|
||||||
ids: List[int],
|
ids: List[int],
|
||||||
|
|||||||
Reference in New Issue
Block a user