Kill model archive maps (#4636)
* Kill model archive maps * Fixup * Also kill model_archive_map for MaskedBertPreTrainedModel * Unhook config_archive_map * Tokenizers: align with model id changes * make style && make quality * Fix CI
This commit is contained in:
@@ -34,26 +34,11 @@ from tqdm import tqdm, trange
|
||||
from transformers import (
|
||||
WEIGHTS_NAME,
|
||||
AdamW,
|
||||
AlbertConfig,
|
||||
AlbertModel,
|
||||
AlbertTokenizer,
|
||||
BertConfig,
|
||||
BertModel,
|
||||
BertTokenizer,
|
||||
DistilBertConfig,
|
||||
DistilBertModel,
|
||||
DistilBertTokenizer,
|
||||
AutoConfig,
|
||||
AutoModel,
|
||||
AutoTokenizer,
|
||||
MMBTConfig,
|
||||
MMBTForClassification,
|
||||
RobertaConfig,
|
||||
RobertaModel,
|
||||
RobertaTokenizer,
|
||||
XLMConfig,
|
||||
XLMModel,
|
||||
XLMTokenizer,
|
||||
XLNetConfig,
|
||||
XLNetModel,
|
||||
XLNetTokenizer,
|
||||
get_linear_schedule_with_warmup,
|
||||
)
|
||||
from utils_mmimdb import ImageEncoder, JsonlDataset, collate_fn, get_image_transforms, get_mmimdb_labels
|
||||
@@ -67,23 +52,6 @@ except ImportError:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ALL_MODELS = sum(
|
||||
(
|
||||
tuple(conf.pretrained_config_archive_map.keys())
|
||||
for conf in (BertConfig, XLNetConfig, XLMConfig, RobertaConfig, DistilBertConfig)
|
||||
),
|
||||
(),
|
||||
)
|
||||
|
||||
MODEL_CLASSES = {
|
||||
"bert": (BertConfig, BertModel, BertTokenizer),
|
||||
"xlnet": (XLNetConfig, XLNetModel, XLNetTokenizer),
|
||||
"xlm": (XLMConfig, XLMModel, XLMTokenizer),
|
||||
"roberta": (RobertaConfig, RobertaModel, RobertaTokenizer),
|
||||
"distilbert": (DistilBertConfig, DistilBertModel, DistilBertTokenizer),
|
||||
"albert": (AlbertConfig, AlbertModel, AlbertTokenizer),
|
||||
}
|
||||
|
||||
|
||||
def set_seed(args):
|
||||
random.seed(args.seed)
|
||||
@@ -351,19 +319,12 @@ def main():
|
||||
required=True,
|
||||
help="The input data dir. Should contain the .jsonl files for MMIMDB.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
@@ -385,7 +346,7 @@ def main():
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
default="",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from s3",
|
||||
)
|
||||
@@ -526,18 +487,14 @@ def main():
|
||||
# Setup model
|
||||
labels = get_mmimdb_labels()
|
||||
num_labels = len(labels)
|
||||
args.model_type = args.model_type.lower()
|
||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||
transformer_config = config_class.from_pretrained(
|
||||
args.config_name if args.config_name else args.model_name_or_path
|
||||
)
|
||||
tokenizer = tokenizer_class.from_pretrained(
|
||||
transformer_config = AutoConfig.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||
do_lower_case=args.do_lower_case,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
cache_dir=args.cache_dir,
|
||||
)
|
||||
transformer = model_class.from_pretrained(
|
||||
args.model_name_or_path, config=transformer_config, cache_dir=args.cache_dir if args.cache_dir else None
|
||||
transformer = AutoModel.from_pretrained(
|
||||
args.model_name_or_path, config=transformer_config, cache_dir=args.cache_dir
|
||||
)
|
||||
img_encoder = ImageEncoder(args)
|
||||
config = MMBTConfig(transformer_config, num_labels=num_labels)
|
||||
@@ -583,13 +540,12 @@ def main():
|
||||
# Load a trained model and vocabulary that you have fine-tuned
|
||||
model = MMBTForClassification(config, transformer, img_encoder)
|
||||
model.load_state_dict(torch.load(os.path.join(args.output_dir, WEIGHTS_NAME)))
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.output_dir)
|
||||
model.to(args.device)
|
||||
|
||||
# Evaluation
|
||||
results = {}
|
||||
if args.do_eval and args.local_rank in [-1, 0]:
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||
checkpoints = [args.output_dir]
|
||||
if args.eval_all_checkpoints:
|
||||
checkpoints = list(
|
||||
|
||||
Reference in New Issue
Block a user