From 9d8fd2d40e225781e2a5663c6e6c93fd5bc89293 Mon Sep 17 00:00:00 2001
From: Julien Chaumond
Date: Wed, 15 Jan 2020 17:36:52 +0000
Subject: [PATCH] tokenizer.save_pretrained: only save file if non-empty

---
 src/transformers/configuration_auto.py | 2 +-
 src/transformers/tokenization_utils.py | 8 +++-----
 tests/test_tokenization_auto.py | 4 ++--
 3 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/src/transformers/configuration_auto.py b/src/transformers/configuration_auto.py
index 10cfe96929..14f5f4f3da 100644
--- a/src/transformers/configuration_auto.py
+++ b/src/transformers/configuration_auto.py
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" Auto Model class. """
+""" Auto Config class. """
 
 
 import logging
diff --git a/src/transformers/tokenization_utils.py b/src/transformers/tokenization_utils.py
index f05ad23654..fb8c0f0ff6 100644
--- a/src/transformers/tokenization_utils.py
+++ b/src/transformers/tokenization_utils.py
@@ -513,12 +513,10 @@ class PreTrainedTokenizer(object):
         with open(special_tokens_map_file, "w", encoding="utf-8") as f:
             f.write(json.dumps(self.special_tokens_map, ensure_ascii=False))
 
-        with open(added_tokens_file, "w", encoding="utf-8") as f:
-            if self.added_tokens_encoder:
+        if len(self.added_tokens_encoder) > 0:
+            with open(added_tokens_file, "w", encoding="utf-8") as f:
                 out_str = json.dumps(self.added_tokens_encoder, ensure_ascii=False)
-            else:
-                out_str = "{}"
-            f.write(out_str)
+                f.write(out_str)
 
         vocab_files = self.save_vocabulary(save_directory)
 
diff --git a/tests/test_tokenization_auto.py b/tests/test_tokenization_auto.py
index 5581a1199b..cd7187c4f2 100644
--- a/tests/test_tokenization_auto.py
+++ b/tests/test_tokenization_auto.py
@@ -33,13 +33,13 @@ class AutoTokenizerTest(unittest.TestCase):
     # @slow
     def test_tokenizer_from_pretrained(self):
         logging.basicConfig(level=logging.INFO)
-        for model_name in [x for x in BERT_PRETRAINED_CONFIG_ARCHIVE_MAP.keys() if "japanese" not in x]:
+        for model_name in (x for x in BERT_PRETRAINED_CONFIG_ARCHIVE_MAP.keys() if "japanese" not in x):
             tokenizer = AutoTokenizer.from_pretrained(model_name)
             self.assertIsNotNone(tokenizer)
             self.assertIsInstance(tokenizer, BertTokenizer)
             self.assertGreater(len(tokenizer), 0)
 
-        for model_name in list(GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP.keys())[:1]:
+        for model_name in GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP.keys():
             tokenizer = AutoTokenizer.from_pretrained(model_name)
             self.assertIsNotNone(tokenizer)
             self.assertIsInstance(tokenizer, GPT2Tokenizer)