Saving Trainer.collator.tokenizer in when Trainer.processing_class is None (#36552)
* feat: Saving tokenizer in collator when processing_class is None * chore: Style issue * chore: Typo * dbg: Check why test failed * dbg: Remove logics and another test failed which successed before, so should be the stablibility issue * test: Init unit-test * chore: Style * chore: Add err log * fix: Case * Update tests/trainer/test_trainer.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * chore: Try to use get_regression_trainer * fix: Impl and style * fix: Style * fix: Case * fix: Import err * fix: Missed import * fix: Import block un-sorted problem * fix: Try another tokenizer * fix: Test logic * chore: Light updates * chore: Reformat --------- Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
@@ -3992,6 +3992,13 @@ class Trainer:
|
|||||||
|
|
||||||
if self.processing_class is not None:
|
if self.processing_class is not None:
|
||||||
self.processing_class.save_pretrained(output_dir)
|
self.processing_class.save_pretrained(output_dir)
|
||||||
|
elif (
|
||||||
|
self.data_collator is not None
|
||||||
|
and hasattr(self.data_collator, "tokenizer")
|
||||||
|
and self.data_collator.tokenizer is not None
|
||||||
|
):
|
||||||
|
logger.info("Saving Trainer.data_collator.tokenizer by default as Trainer.processing_class is `None`")
|
||||||
|
self.data_collator.tokenizer.save_pretrained(output_dir)
|
||||||
|
|
||||||
# Good practice: save your training arguments together with the trained model
|
# Good practice: save your training arguments together with the trained model
|
||||||
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
|
torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ import unittest
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
from itertools import product
|
from itertools import product
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List
|
from typing import Any, Dict, List
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -46,6 +46,7 @@ from transformers import (
|
|||||||
PretrainedConfig,
|
PretrainedConfig,
|
||||||
TrainerCallback,
|
TrainerCallback,
|
||||||
TrainingArguments,
|
TrainingArguments,
|
||||||
|
default_data_collator,
|
||||||
enable_full_determinism,
|
enable_full_determinism,
|
||||||
get_polynomial_decay_schedule_with_warmup,
|
get_polynomial_decay_schedule_with_warmup,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
@@ -2975,6 +2976,24 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
tmp_dir, 5, int(self.n_epochs * 64 / self.batch_size), False, safe_weights=save_safetensors
|
tmp_dir, 5, int(self.n_epochs * 64 / self.batch_size), False, safe_weights=save_safetensors
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_save_collator_tokenizer_by_default(self):
|
||||||
|
class FakeCollator:
|
||||||
|
def __init__(self):
|
||||||
|
self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
|
||||||
|
self.tokenizer.add_tokens(["<NEW_TOKEN1>", "<NEW_TOKEN2>"])
|
||||||
|
|
||||||
|
def __call__(self, features: List[Any], return_tensors="pt") -> Dict[str, Any]:
|
||||||
|
return default_data_collator(features, return_tensors)
|
||||||
|
|
||||||
|
data_collator = FakeCollator()
|
||||||
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
|
trainer = get_regression_trainer(
|
||||||
|
output_dir=tmp_dir, save_steps=5, save_safetensors=True, data_collator=data_collator
|
||||||
|
)
|
||||||
|
trainer.train()
|
||||||
|
loaded_tokenizer = AutoTokenizer.from_pretrained(os.path.join(tmp_dir, os.listdir(tmp_dir)[0]))
|
||||||
|
assert len(loaded_tokenizer) == len(trainer.data_collator.tokenizer), "Failed to load updated tokenizer"
|
||||||
|
|
||||||
def test_load_best_model_with_save(self):
|
def test_load_best_model_with_save(self):
|
||||||
tmp_dir = self.get_auto_remove_tmp_dir()
|
tmp_dir = self.get_auto_remove_tmp_dir()
|
||||||
trainer = get_regression_trainer(
|
trainer = get_regression_trainer(
|
||||||
|
|||||||
Reference in New Issue
Block a user