From 280db79ac139eff31962a56006b34a9f42886834 Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Thu, 14 Jan 2021 07:57:58 -0500 Subject: [PATCH] BatchEncoding.to with device with tests (#9584) --- src/transformers/tokenization_utils_base.py | 8 +++++++- tests/test_tokenization_common.py | 4 ++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 11d4dce541..f3161c710b 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -65,6 +65,12 @@ def _is_torch(x): return isinstance(x, torch.Tensor) +def _is_torch_device(x): + import torch + + return isinstance(x, torch.device) + + def _is_tensorflow(x): import tensorflow as tf @@ -801,7 +807,7 @@ class BatchEncoding(UserDict): # This check catches things like APEX blindly calling "to" on all inputs to a module # Otherwise it passes the casts down and casts the LongTensor containing the token idxs # into a HalfTensor - if isinstance(device, str) or isinstance(device, torch.device) or isinstance(device, int): + if isinstance(device, str) or _is_torch_device(device) or isinstance(device, int): self.data = {k: v.to(device=device) for k, v in self.data.items()} else: logger.warning( diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index 9462c86b7b..82caaccbe5 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -1704,6 +1704,10 @@ class TokenizerTesterMixin: first_ten_tokens = list(tokenizer.get_vocab().keys())[:10] sequence = " ".join(first_ten_tokens) encoded_sequence = tokenizer.encode_plus(sequence, return_tensors="pt") + + # Ensure that the BatchEncoding.to() method works. + encoded_sequence.to(model.device) + batch_encoded_sequence = tokenizer.batch_encode_plus([sequence, sequence], return_tensors="pt") # This should not fail