Fix tokenizer saving during training with Trainer (#12806)
* add test in trainer and test tokenizer saving with trainer * quality * reverse trainer changes * replace test in test_trainer by a test for all the tokenizers * format * add can_save_slow_tokenizer attribute to all tokenizers * fix Herbert * format * Change comment in error * add comments and a new assert * Update src/transformers/models/albert/tokenization_albert_fast.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * change ValueError barthez * change ValueError BigBird * change ValueError Camembert * change ValueError Mbart50 * change ValueError Pegasus * change ValueError ReFormer * change ValueError T5 * change ValueError RoBERTa * XLNET fast * Update tests/test_tokenization_common.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * change `assert` into `self.assertIn` * format Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -38,6 +38,8 @@ from transformers import (
|
||||
PreTrainedTokenizerBase,
|
||||
PreTrainedTokenizerFast,
|
||||
SpecialTokensMixin,
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
)
|
||||
@@ -56,6 +58,10 @@ from transformers.testing_utils import (
|
||||
from transformers.tokenization_utils import AddedToken
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PretrainedConfig, PreTrainedModel, TFPreTrainedModel
|
||||
|
||||
@@ -3389,6 +3395,27 @@ class TokenizerTesterMixin:
|
||||
)
|
||||
)
|
||||
|
||||
@require_torch
def test_saving_tokenizer_trainer(self):
    """Check that `Trainer.save_model` writes `tokenizer.json` for a fast-only tokenizer.

    For every entry of `self.tokenizers_list`, the fast tokenizer is saved
    with `legacy_format=False`, reloaded from that directory, handed to a
    `Trainer` built around an empty `nn.Module`, and the saved checkpoint
    folder is inspected for the `tokenizer.json` file.
    """
    for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
        subtest_label = f"{tokenizer.__class__.__name__} ({pretrained_name})"
        with self.subTest(subtest_label), tempfile.TemporaryDirectory() as tmp_dir:
            # Write only the fast-tokenizer files (no legacy files) to disk.
            fast_tokenizer = self.rust_tokenizer_class.from_pretrained(
                pretrained_name, **kwargs, use_fast=True
            )
            fast_tokenizer.save_pretrained(tmp_dir, legacy_format=False)

            # An empty module is used as the trainer's toy model.
            toy_model = nn.Module()

            # Reload the tokenizer from the directory that lacks legacy files.
            reloaded_tokenizer = self.rust_tokenizer_class.from_pretrained(tmp_dir)
            trainer = Trainer(
                model=toy_model,
                args=TrainingArguments(output_dir=tmp_dir, do_train=True, no_cuda=True),
                tokenizer=reloaded_tokenizer,
            )

            # Saving must succeed and must produce the fast tokenizer file.
            checkpoint_dir = os.path.join(tmp_dir, "checkpoint")
            trainer.save_model(checkpoint_dir)
            self.assertIn("tokenizer.json", os.listdir(checkpoint_dir))
|
||||
|
||||
|
||||
@is_staging_test
|
||||
class TokenizerPushToHubTester(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user