dilbert -> distilbert

This commit is contained in:
thomwolf
2019-08-28 13:59:42 +02:00
parent c9bce1811c
commit 912a377e90
15 changed files with 144 additions and 144 deletions

View File

@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
-Training DilBERT.
+Training DistilBERT.
"""
import os
import argparse
@@ -24,7 +24,7 @@ import numpy as np
import torch
from pytorch_transformers import BertTokenizer, BertForMaskedLM
-from pytorch_transformers import DilBertForMaskedLM, DilBertConfig
+from pytorch_transformers import DistilBertForMaskedLM, DistilBertConfig
from distiller import Distiller
from utils import git_log, logger, init_gpu_params, set_seed
@@ -201,13 +201,13 @@ def main():
assert os.path.isfile(os.path.join(args.from_pretrained_config))
logger.info(f'Loading pretrained weights from {args.from_pretrained_weights}')
logger.info(f'Loading pretrained config from {args.from_pretrained_config}')
-stu_architecture_config = DilBertConfig.from_json_file(args.from_pretrained_config)
-student = DilBertForMaskedLM.from_pretrained(args.from_pretrained_weights,
+stu_architecture_config = DistilBertConfig.from_json_file(args.from_pretrained_config)
+student = DistilBertForMaskedLM.from_pretrained(args.from_pretrained_weights,
config=stu_architecture_config)
else:
args.vocab_size_or_config_json_file = args.vocab_size
-stu_architecture_config = DilBertConfig(**vars(args))
-student = DilBertForMaskedLM(stu_architecture_config)
+stu_architecture_config = DistilBertConfig(**vars(args))
+student = DistilBertForMaskedLM(stu_architecture_config)
if args.n_gpu > 0: