From 6dacc79d395bd41e0ef76c2a043c2ef90cc79925 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Fri, 5 Jul 2019 15:11:59 +0200 Subject: [PATCH] fix python2 tests --- pytorch_transformers/tests/tokenization_tests_commons.py | 6 ++---- pytorch_transformers/tokenization_utils.py | 6 +++++- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/pytorch_transformers/tests/tokenization_tests_commons.py b/pytorch_transformers/tests/tokenization_tests_commons.py index e8f7ee7a25..876f7747be 100644 --- a/pytorch_transformers/tests/tokenization_tests_commons.py +++ b/pytorch_transformers/tests/tokenization_tests_commons.py @@ -12,9 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function +from __future__ import absolute_import, division, print_function, unicode_literals import os import sys @@ -47,7 +45,7 @@ def create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, * def create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs): tokenizer = tokenizer_class(*inputs, **kwargs) - text = "Munich and Berlin are nice cities" + text = u"Munich and Berlin are nice cities" filename = u"/tmp/tokenizer.bin" subwords = tokenizer.tokenize(text) diff --git a/pytorch_transformers/tokenization_utils.py b/pytorch_transformers/tokenization_utils.py index 98a2968539..c6f08c41ae 100644 --- a/pytorch_transformers/tokenization_utils.py +++ b/pytorch_transformers/tokenization_utils.py @@ -101,8 +101,12 @@ class PreTrainedTokenizer(object): max_len = cls.max_model_input_sizes[pretrained_model_name_or_path] kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len) + # Merge resolved_vocab_files arguments in kwargs. + for args_name, file_path in resolved_vocab_files.items(): + kwargs[args_name] = file_path + # Instantiate tokenizer. - tokenizer = cls(*inputs, **resolved_vocab_files, **kwargs) + tokenizer = cls(*inputs, **kwargs) return tokenizer