[TokenizationRoformerFast] Fix the save and loading (#28527)
* cleanup * add a test * update the test * style * revert part that allows to pickle the tokenizer
This commit is contained in:
@@ -122,15 +122,19 @@ class RoFormerTokenizerFast(PreTrainedTokenizerFast):
|
|||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
pre_tok_state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
|
normalizer_state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
|
||||||
if (
|
if (
|
||||||
pre_tok_state.get("lowercase", do_lower_case) != do_lower_case
|
normalizer_state.get("lowercase", do_lower_case) != do_lower_case
|
||||||
or pre_tok_state.get("strip_accents", strip_accents) != strip_accents
|
or normalizer_state.get("strip_accents", strip_accents) != strip_accents
|
||||||
):
|
):
|
||||||
pre_tok_class = getattr(normalizers, pre_tok_state.pop("type"))
|
normalizer_class = getattr(normalizers, normalizer_state.pop("type"))
|
||||||
pre_tok_state["lowercase"] = do_lower_case
|
normalizer_state["lowercase"] = do_lower_case
|
||||||
pre_tok_state["strip_accents"] = strip_accents
|
normalizer_state["strip_accents"] = strip_accents
|
||||||
self.backend_tokenizer.normalizer = pre_tok_class(**pre_tok_state)
|
self.backend_tokenizer.normalizer = normalizer_class(**normalizer_state)
|
||||||
|
|
||||||
|
# Make sure we correctly set the custom PreTokenizer
|
||||||
|
vocab = self.backend_tokenizer.get_vocab()
|
||||||
|
self.backend_tokenizer.pre_tokenizer = PreTokenizer.custom(JiebaPreTokenizer(vocab))
|
||||||
|
|
||||||
self.do_lower_case = do_lower_case
|
self.do_lower_case = do_lower_case
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from transformers import RoFormerTokenizer, RoFormerTokenizerFast
|
from transformers import RoFormerTokenizer, RoFormerTokenizerFast
|
||||||
@@ -71,6 +72,12 @@ class RoFormerTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
def test_training_new_tokenizer_with_special_tokens_change(self):
|
def test_training_new_tokenizer_with_special_tokens_change(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# can't serialise custom PreTokenizer
|
|
||||||
def test_save_slow_from_fast_and_reload_fast(self):
|
def test_save_slow_from_fast_and_reload_fast(self):
|
||||||
pass
|
for cls in [RoFormerTokenizer, RoFormerTokenizerFast]:
|
||||||
|
original = cls.from_pretrained("alchemab/antiberta2")
|
||||||
|
self.assertEqual(original.encode("生活的真谛是"), [1, 4, 4, 4, 4, 4, 4, 2])
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
original.save_pretrained(tmp_dir)
|
||||||
|
new = cls.from_pretrained(tmp_dir)
|
||||||
|
self.assertEqual(new.encode("生活的真谛是"), [1, 4, 4, 4, 4, 4, 4, 2])
|
||||||
|
|||||||
Reference in New Issue
Block a user