[Rag] Fix loading of pretrained Rag Tokenizer (#7756)
* fix rag * Update tokenizer save_pretrained Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
2d4e928d97
commit
82b09a8481
@@ -1637,9 +1637,11 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
||||
if special_tokens_map_file is not None:
|
||||
with open(special_tokens_map_file, encoding="utf-8") as special_tokens_map_handle:
|
||||
special_tokens_map = json.load(special_tokens_map_handle)
|
||||
|
||||
special_tokens_map = convert_added_tokens(special_tokens_map)
|
||||
for key, value in special_tokens_map.items():
|
||||
if isinstance(value, dict):
|
||||
value = AddedToken(**value)
|
||||
elif isinstance(value, list):
|
||||
value = [AddedToken(**token) if isinstance(token, dict) else token for token in value]
|
||||
setattr(tokenizer, key, value)
|
||||
|
||||
# Add supplementary tokens.
|
||||
@@ -1706,23 +1708,25 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
||||
tokenizer_config.pop(file_id, None)
|
||||
|
||||
# Sanitize AddedTokens
|
||||
def convert_added_tokens(obj: Union[AddedToken, Any]):
|
||||
def convert_added_tokens(obj: Union[AddedToken, Any], add_type_field=True):
|
||||
if isinstance(obj, AddedToken):
|
||||
out = obj.__getstate__()
|
||||
out["__type"] = "AddedToken"
|
||||
if add_type_field:
|
||||
out["__type"] = "AddedToken"
|
||||
return out
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
return list(convert_added_tokens(o) for o in obj)
|
||||
return list(convert_added_tokens(o, add_type_field=add_type_field) for o in obj)
|
||||
elif isinstance(obj, dict):
|
||||
return {k: convert_added_tokens(v) for k, v in obj.items()}
|
||||
return {k: convert_added_tokens(v, add_type_field=add_type_field) for k, v in obj.items()}
|
||||
return obj
|
||||
|
||||
tokenizer_config = convert_added_tokens(tokenizer_config)
|
||||
# add_type_field=True to allow dicts in the kwargs / differentiate from AddedToken serialization
|
||||
tokenizer_config = convert_added_tokens(tokenizer_config, add_type_field=True)
|
||||
with open(tokenizer_config_file, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(tokenizer_config, ensure_ascii=False))
|
||||
|
||||
# Sanitize AddedTokens in special_tokens_map
|
||||
write_dict = convert_added_tokens(self.special_tokens_map_extended)
|
||||
write_dict = convert_added_tokens(self.special_tokens_map_extended, add_type_field=False)
|
||||
with open(special_tokens_map_file, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(write_dict, ensure_ascii=False))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user