[Rag] Fix loading of pretrained Rag Tokenizer (#7756)

* fix rag

* Update tokenizer save_pretrained

Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>
This commit is contained in:
Patrick von Platen
2020-10-13 14:34:22 +02:00
committed by GitHub
parent 2d4e928d97
commit 82b09a8481
2 changed files with 59 additions and 9 deletions

View File

@@ -1637,9 +1637,11 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
if special_tokens_map_file is not None:
with open(special_tokens_map_file, encoding="utf-8") as special_tokens_map_handle:
special_tokens_map = json.load(special_tokens_map_handle)
special_tokens_map = convert_added_tokens(special_tokens_map)
for key, value in special_tokens_map.items():
if isinstance(value, dict):
value = AddedToken(**value)
elif isinstance(value, list):
value = [AddedToken(**token) if isinstance(token, dict) else token for token in value]
setattr(tokenizer, key, value)
# Add supplementary tokens.
@@ -1706,23 +1708,25 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
tokenizer_config.pop(file_id, None)
# Sanitize AddedTokens
def convert_added_tokens(obj: Union[AddedToken, Any]):
def convert_added_tokens(obj: Union[AddedToken, Any], add_type_field=True):
if isinstance(obj, AddedToken):
out = obj.__getstate__()
out["__type"] = "AddedToken"
if add_type_field:
out["__type"] = "AddedToken"
return out
elif isinstance(obj, (list, tuple)):
return list(convert_added_tokens(o) for o in obj)
return list(convert_added_tokens(o, add_type_field=add_type_field) for o in obj)
elif isinstance(obj, dict):
return {k: convert_added_tokens(v) for k, v in obj.items()}
return {k: convert_added_tokens(v, add_type_field=add_type_field) for k, v in obj.items()}
return obj
tokenizer_config = convert_added_tokens(tokenizer_config)
# add_type_field=True to allow dicts in the kwargs / differentiate from AddedToken serialization
tokenizer_config = convert_added_tokens(tokenizer_config, add_type_field=True)
with open(tokenizer_config_file, "w", encoding="utf-8") as f:
f.write(json.dumps(tokenizer_config, ensure_ascii=False))
# Sanitize AddedTokens in special_tokens_map
write_dict = convert_added_tokens(self.special_tokens_map_extended)
write_dict = convert_added_tokens(self.special_tokens_map_extended, add_type_field=False)
with open(special_tokens_map_file, "w", encoding="utf-8") as f:
f.write(json.dumps(write_dict, ensure_ascii=False))