From e248e9b0423b3794baa13330ba1a0038ec1bd7ac Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 26 Oct 2021 13:08:18 +0200 Subject: [PATCH] up (#14154) --- .../run_speech_recognition_ctc.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py b/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py index 2b1a5810f7..ef54948d69 100755 --- a/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py +++ b/examples/pytorch/speech-recognition/run_speech_recognition_ctc.py @@ -187,6 +187,13 @@ class DataTrainingArguments: "so that the cached datasets can consequently be loaded in distributed training" }, ) + use_auth_token: Optional[bool] = field( + default=False, + metadata={ + "help": "If :obj:`True`, will use the token generated when running " + ":obj:`transformers-cli login` as HTTP bearer authorization for remote files." + }, + ) @dataclass @@ -408,7 +415,9 @@ def main(): # one local process can concurrently download model & vocab. 
# load config - config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir) + config = AutoConfig.from_pretrained( + model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_auth_token=data_args.use_auth_token + ) # tokenizer is defined by `tokenizer_class` if present in config else by `model_type` config_for_tokenizer = config if config.tokenizer_class is not None else None @@ -422,9 +431,10 @@ def main(): unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|", + use_auth_token=data_args.use_auth_token, ) feature_extractor = AutoFeatureExtractor.from_pretrained( - model_args.model_name_or_path, cache_dir=model_args.cache_dir + model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_auth_token=data_args.use_auth_token ) processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer) @@ -447,7 +457,10 @@ def main(): # create model model = AutoModelForCTC.from_pretrained( - model_args.model_name_or_path, cache_dir=model_args.cache_dir, config=config + model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + config=config, + use_auth_token=data_args.use_auth_token, ) # freeze encoder