[examples] Use AutoModels in more examples
This commit is contained in:
@@ -38,28 +38,15 @@ from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
from transformers import (
|
||||
CONFIG_MAPPING,
|
||||
MODEL_WITH_LM_HEAD_MAPPING,
|
||||
WEIGHTS_NAME,
|
||||
AdamW,
|
||||
BertConfig,
|
||||
BertForMaskedLM,
|
||||
BertTokenizer,
|
||||
CamembertConfig,
|
||||
CamembertForMaskedLM,
|
||||
CamembertTokenizer,
|
||||
DistilBertConfig,
|
||||
DistilBertForMaskedLM,
|
||||
DistilBertTokenizer,
|
||||
GPT2Config,
|
||||
GPT2LMHeadModel,
|
||||
GPT2Tokenizer,
|
||||
OpenAIGPTConfig,
|
||||
OpenAIGPTLMHeadModel,
|
||||
OpenAIGPTTokenizer,
|
||||
AutoConfig,
|
||||
AutoModelWithLMHead,
|
||||
AutoTokenizer,
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizer,
|
||||
RobertaConfig,
|
||||
RobertaForMaskedLM,
|
||||
RobertaTokenizer,
|
||||
get_linear_schedule_with_warmup,
|
||||
)
|
||||
|
||||
@@ -73,14 +60,8 @@ except ImportError:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
MODEL_CLASSES = {
|
||||
"gpt2": (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
|
||||
"openai-gpt": (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
|
||||
"bert": (BertConfig, BertForMaskedLM, BertTokenizer),
|
||||
"roberta": (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
|
||||
"distilbert": (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
|
||||
"camembert": (CamembertConfig, CamembertForMaskedLM, CamembertTokenizer),
|
||||
}
|
||||
MODEL_CONFIG_CLASSES = list(MODEL_WITH_LM_HEAD_MAPPING.keys())
|
||||
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||
|
||||
|
||||
class TextDataset(Dataset):
|
||||
@@ -693,23 +674,21 @@ def main():
|
||||
if args.local_rank not in [-1, 0]:
|
||||
torch.distributed.barrier() # Barrier to make sure only the first process in distributed training download model & vocab
|
||||
|
||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||
|
||||
if args.config_name:
|
||||
config = config_class.from_pretrained(args.config_name, cache_dir=args.cache_dir)
|
||||
config = AutoConfig.from_pretrained(args.config_name, cache_dir=args.cache_dir)
|
||||
elif args.model_name_or_path:
|
||||
config = config_class.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
|
||||
config = AutoConfig.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
|
||||
else:
|
||||
config = config_class()
|
||||
config = CONFIG_MAPPING[args.model_type]()
|
||||
|
||||
if args.tokenizer_name:
|
||||
tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir)
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, cache_dir=args.cache_dir)
|
||||
elif args.model_name_or_path:
|
||||
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
|
||||
else:
|
||||
raise ValueError(
|
||||
"You are instantiating a new {} tokenizer. This is not supported, but you can do it from another script, save it,"
|
||||
"and load it from here, using --tokenizer_name".format(tokenizer_class.__name__)
|
||||
"and load it from here, using --tokenizer_name".format(AutoTokenizer.__name__)
|
||||
)
|
||||
|
||||
if args.block_size <= 0:
|
||||
@@ -719,7 +698,7 @@ def main():
|
||||
args.block_size = min(args.block_size, tokenizer.max_len)
|
||||
|
||||
if args.model_name_or_path:
|
||||
model = model_class.from_pretrained(
|
||||
model = AutoModelWithLMHead.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||
config=config,
|
||||
@@ -727,7 +706,7 @@ def main():
|
||||
)
|
||||
else:
|
||||
logger.info("Training new model from scratch")
|
||||
model = model_class(config=config)
|
||||
model = AutoModelWithLMHead(config=config)
|
||||
|
||||
model.to(args.device)
|
||||
|
||||
@@ -768,8 +747,8 @@ def main():
|
||||
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
||||
|
||||
# Load a trained model and vocabulary that you have fine-tuned
|
||||
model = model_class.from_pretrained(args.output_dir)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
|
||||
model = AutoModelWithLMHead.from_pretrained(args.output_dir)
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.output_dir)
|
||||
model.to(args.device)
|
||||
|
||||
# Evaluation
|
||||
@@ -786,7 +765,7 @@ def main():
|
||||
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||
prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
|
||||
|
||||
model = model_class.from_pretrained(checkpoint)
|
||||
model = AutoModelWithLMHead.from_pretrained(checkpoint)
|
||||
model.to(args.device)
|
||||
result = evaluate(args, model, tokenizer, prefix=prefix)
|
||||
result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
|
||||
|
||||
Reference in New Issue
Block a user