Tokenizers: ability to load from model subfolder (#8586)
* <small>tiny typo</small>
* Tokenizers: ability to load from model subfolder
* use subfolder for local files as well
* Uniformize model shortcut name => model id
* from s3 => from huggingface.co

Co-authored-by: Quentin Lhoest <lhoest.q@gmail.com>
This commit is contained in:
@@ -238,10 +238,9 @@ class RagPreTrainedModel(PreTrainedModel):
|
||||
question_encoder_pretrained_model_name_or_path (:obj: `str`, `optional`, defaults to `None`):
|
||||
Information necessary to initiate the question encoder. Can be either:
|
||||
|
||||
- - A string with the `shortcut name` of a pretrained model to load from cache or download, e.g.,
|
||||
- ``bert-base-uncased``.
|
||||
- - A string with the `identifier name` of a pretrained model that was user-uploaded to our S3, e.g.,
|
||||
- ``dbmdz/bert-base-german-cased``.
|
||||
+ - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
|
||||
+ Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
|
||||
+ a user or organization name, like ``dbmdz/bert-base-german-cased``.
|
||||
- A path to a `directory` containing model weights saved using
|
||||
:func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
|
||||
- A path or url to a `tensorflow index checkpoint file` (e.g., ``./tf_model/model.ckpt.index``). In
|
||||
@@ -252,10 +251,9 @@ class RagPreTrainedModel(PreTrainedModel):
|
||||
generator_pretrained_model_name_or_path (:obj: `str`, `optional`, defaults to `None`):
|
||||
Information necessary to initiate the generator. Can be either:
|
||||
|
||||
- - A string with the `shortcut name` of a pretrained model to load from cache or download, e.g.,
|
||||
- ``bert-base-uncased``.
|
||||
- - A string with the `identifier name` of a pretrained model that was user-uploaded to our S3, e.g.,
|
||||
- ``dbmdz/bert-base-german-cased``.
|
||||
+ - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
|
||||
+ Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
|
||||
+ a user or organization name, like ``dbmdz/bert-base-german-cased``.
|
||||
- A path to a `directory` containing model weights saved using
|
||||
:func:`~transformers.PreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
|
||||
- A path or url to a `tensorflow index checkpoint file` (e.g., ``./tf_model/model.ckpt.index``). In
|
||||
|
||||
@@ -49,10 +49,12 @@ class RagTokenizer:
|
||||
if config is None:
|
||||
config = RagConfig.from_pretrained(pretrained_model_name_or_path)
|
||||
|
||||
- question_encoder_path = os.path.join(pretrained_model_name_or_path, "question_encoder_tokenizer")
|
||||
- generator_path = os.path.join(pretrained_model_name_or_path, "generator_tokenizer")
|
||||
- question_encoder = AutoTokenizer.from_pretrained(question_encoder_path, config=config.question_encoder)
|
||||
- generator = AutoTokenizer.from_pretrained(generator_path, config=config.generator)
|
||||
+ question_encoder = AutoTokenizer.from_pretrained(
|
||||
+ pretrained_model_name_or_path, config=config.question_encoder, subfolder="question_encoder_tokenizer"
|
||||
+ )
|
||||
+ generator = AutoTokenizer.from_pretrained(
|
||||
+ pretrained_model_name_or_path, config=config.generator, subfolder="generator_tokenizer"
|
||||
+ )
|
||||
return cls(question_encoder=question_encoder, generator=generator)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
|
||||
Reference in New Issue
Block a user