Kill model archive maps (#4636)
* Kill model archive maps * Fixup * Also kill model_archive_map for MaskedBertPreTrainedModel * Unhook config_archive_map * Tokenizers: align with model id changes * make style && make quality * Fix CI
This commit is contained in:
@@ -65,13 +65,6 @@ except ImportError:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ALL_MODELS = sum(
|
||||
(
|
||||
tuple(conf.pretrained_config_archive_map.keys())
|
||||
for conf in (BertConfig, XLNetConfig, XLMConfig, RobertaConfig, DistilBertConfig)
|
||||
),
|
||||
(),
|
||||
)
|
||||
|
||||
MODEL_CLASSES = {
|
||||
"bert": (BertConfig, BertForSequenceClassification, BertTokenizer),
|
||||
@@ -389,7 +382,7 @@ def main():
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task_name",
|
||||
|
||||
@@ -34,26 +34,11 @@ from tqdm import tqdm, trange
|
||||
from transformers import (
|
||||
WEIGHTS_NAME,
|
||||
AdamW,
|
||||
AlbertConfig,
|
||||
AlbertModel,
|
||||
AlbertTokenizer,
|
||||
BertConfig,
|
||||
BertModel,
|
||||
BertTokenizer,
|
||||
DistilBertConfig,
|
||||
DistilBertModel,
|
||||
DistilBertTokenizer,
|
||||
AutoConfig,
|
||||
AutoModel,
|
||||
AutoTokenizer,
|
||||
MMBTConfig,
|
||||
MMBTForClassification,
|
||||
RobertaConfig,
|
||||
RobertaModel,
|
||||
RobertaTokenizer,
|
||||
XLMConfig,
|
||||
XLMModel,
|
||||
XLMTokenizer,
|
||||
XLNetConfig,
|
||||
XLNetModel,
|
||||
XLNetTokenizer,
|
||||
get_linear_schedule_with_warmup,
|
||||
)
|
||||
from utils_mmimdb import ImageEncoder, JsonlDataset, collate_fn, get_image_transforms, get_mmimdb_labels
|
||||
@@ -67,23 +52,6 @@ except ImportError:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ALL_MODELS = sum(
|
||||
(
|
||||
tuple(conf.pretrained_config_archive_map.keys())
|
||||
for conf in (BertConfig, XLNetConfig, XLMConfig, RobertaConfig, DistilBertConfig)
|
||||
),
|
||||
(),
|
||||
)
|
||||
|
||||
MODEL_CLASSES = {
|
||||
"bert": (BertConfig, BertModel, BertTokenizer),
|
||||
"xlnet": (XLNetConfig, XLNetModel, XLNetTokenizer),
|
||||
"xlm": (XLMConfig, XLMModel, XLMTokenizer),
|
||||
"roberta": (RobertaConfig, RobertaModel, RobertaTokenizer),
|
||||
"distilbert": (DistilBertConfig, DistilBertModel, DistilBertTokenizer),
|
||||
"albert": (AlbertConfig, AlbertModel, AlbertTokenizer),
|
||||
}
|
||||
|
||||
|
||||
def set_seed(args):
|
||||
random.seed(args.seed)
|
||||
@@ -351,19 +319,12 @@ def main():
|
||||
required=True,
|
||||
help="The input data dir. Should contain the .jsonl files for MMIMDB.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
@@ -385,7 +346,7 @@ def main():
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
default="",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from s3",
|
||||
)
|
||||
@@ -526,18 +487,14 @@ def main():
|
||||
# Setup model
|
||||
labels = get_mmimdb_labels()
|
||||
num_labels = len(labels)
|
||||
args.model_type = args.model_type.lower()
|
||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||
transformer_config = config_class.from_pretrained(
|
||||
args.config_name if args.config_name else args.model_name_or_path
|
||||
)
|
||||
tokenizer = tokenizer_class.from_pretrained(
|
||||
transformer_config = AutoConfig.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||
do_lower_case=args.do_lower_case,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
cache_dir=args.cache_dir,
|
||||
)
|
||||
transformer = model_class.from_pretrained(
|
||||
args.model_name_or_path, config=transformer_config, cache_dir=args.cache_dir if args.cache_dir else None
|
||||
transformer = AutoModel.from_pretrained(
|
||||
args.model_name_or_path, config=transformer_config, cache_dir=args.cache_dir
|
||||
)
|
||||
img_encoder = ImageEncoder(args)
|
||||
config = MMBTConfig(transformer_config, num_labels=num_labels)
|
||||
@@ -583,13 +540,12 @@ def main():
|
||||
# Load a trained model and vocabulary that you have fine-tuned
|
||||
model = MMBTForClassification(config, transformer, img_encoder)
|
||||
model.load_state_dict(torch.load(os.path.join(args.output_dir, WEIGHTS_NAME)))
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.output_dir)
|
||||
model.to(args.device)
|
||||
|
||||
# Evaluation
|
||||
results = {}
|
||||
if args.do_eval and args.local_rank in [-1, 0]:
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||
checkpoints = [args.output_dir]
|
||||
if args.eval_all_checkpoints:
|
||||
checkpoints = list(
|
||||
|
||||
@@ -31,14 +31,8 @@ from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, Tenso
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
from transformers import (
|
||||
WEIGHTS_NAME,
|
||||
AdamW,
|
||||
BertConfig,
|
||||
BertForMultipleChoice,
|
||||
BertTokenizer,
|
||||
get_linear_schedule_with_warmup,
|
||||
)
|
||||
from transformers import WEIGHTS_NAME, AdamW, AutoConfig, AutoTokenizer, get_linear_schedule_with_warmup
|
||||
from transformers.modeling_auto import AutoModelForMultipleChoice
|
||||
|
||||
|
||||
try:
|
||||
@@ -49,12 +43,6 @@ except ImportError:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in [BertConfig]), ())
|
||||
|
||||
MODEL_CLASSES = {
|
||||
"bert": (BertConfig, BertForMultipleChoice, BertTokenizer),
|
||||
}
|
||||
|
||||
|
||||
class SwagExample(object):
|
||||
"""A single training/test example for the SWAG dataset."""
|
||||
@@ -492,19 +480,12 @@ def main():
|
||||
required=True,
|
||||
help="SWAG csv for predictions. E.g., val.csv or test.csv",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
@@ -536,9 +517,6 @@ def main():
|
||||
parser.add_argument(
|
||||
"--evaluate_during_training", action="store_true", help="Rul evaluation during training at each logging step."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
|
||||
)
|
||||
|
||||
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
|
||||
parser.add_argument(
|
||||
@@ -652,13 +630,9 @@ def main():
|
||||
if args.local_rank not in [-1, 0]:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||
|
||||
args.model_type = args.model_type.lower()
|
||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||
config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
|
||||
tokenizer = tokenizer_class.from_pretrained(
|
||||
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, do_lower_case=args.do_lower_case
|
||||
)
|
||||
model = model_class.from_pretrained(
|
||||
config = AutoConfig.from_pretrained(args.config_name if args.config_name else args.model_name_or_path)
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,)
|
||||
model = AutoModelForMultipleChoice.from_pretrained(
|
||||
args.model_name_or_path, from_tf=bool(".ckpt" in args.model_name_or_path), config=config
|
||||
)
|
||||
|
||||
@@ -694,8 +668,8 @@ def main():
|
||||
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
||||
|
||||
# Load a trained model and vocabulary that you have fine-tuned
|
||||
model = model_class.from_pretrained(args.output_dir)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
|
||||
model = AutoModelForMultipleChoice.from_pretrained(args.output_dir)
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.output_dir)
|
||||
model.to(args.device)
|
||||
|
||||
# Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
|
||||
@@ -718,8 +692,8 @@ def main():
|
||||
for checkpoint in checkpoints:
|
||||
# Reload the model
|
||||
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||
model = model_class.from_pretrained(checkpoint)
|
||||
tokenizer = tokenizer_class.from_pretrained(checkpoint)
|
||||
model = AutoModelForMultipleChoice.from_pretrained(checkpoint)
|
||||
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
|
||||
model.to(args.device)
|
||||
|
||||
# Evaluate
|
||||
|
||||
@@ -67,9 +67,6 @@ except ImportError:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ALL_MODELS = sum(
|
||||
(tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, XLMConfig)), ()
|
||||
)
|
||||
|
||||
MODEL_CLASSES = {
|
||||
"bert": (BertConfig, BertForQuestionAnswering, BertTokenizer),
|
||||
@@ -505,7 +502,7 @@ def main():
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
|
||||
@@ -19,7 +19,6 @@ and adapts it to the specificities of MaskedBert (`pruning_method`, `mask_init`
|
||||
|
||||
import logging
|
||||
|
||||
from transformers.configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
@@ -31,7 +30,6 @@ class MaskedBertConfig(PretrainedConfig):
|
||||
A class replicating the `~transformers.BertConfig` with additional parameters for pruning/masking configuration.
|
||||
"""
|
||||
|
||||
pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
|
||||
model_type = "masked_bert"
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -29,12 +29,7 @@ from torch.nn import CrossEntropyLoss, MSELoss
|
||||
from emmental import MaskedBertConfig
|
||||
from emmental.modules import MaskedLinear
|
||||
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_callable
|
||||
from transformers.modeling_bert import (
|
||||
ACT2FN,
|
||||
BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
BertLayerNorm,
|
||||
load_tf_weights_in_bert,
|
||||
)
|
||||
from transformers.modeling_bert import ACT2FN, BertLayerNorm, load_tf_weights_in_bert
|
||||
from transformers.modeling_utils import PreTrainedModel, prune_linear_layer
|
||||
|
||||
|
||||
@@ -395,7 +390,6 @@ class MaskedBertPreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
|
||||
config_class = MaskedBertConfig
|
||||
pretrained_model_archive_map = BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||
load_tf_weights = load_tf_weights_in_bert
|
||||
base_model_prefix = "bert"
|
||||
|
||||
|
||||
@@ -53,8 +53,6 @@ except ImportError:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig,)), (),)
|
||||
|
||||
MODEL_CLASSES = {
|
||||
"bert": (BertConfig, BertForSequenceClassification, BertTokenizer),
|
||||
"masked_bert": (MaskedBertConfig, MaskedBertForSequenceClassification, BertTokenizer),
|
||||
@@ -576,7 +574,7 @@ def main():
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task_name",
|
||||
|
||||
@@ -57,8 +57,6 @@ except ImportError:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig,)), (),)
|
||||
|
||||
MODEL_CLASSES = {
|
||||
"bert": (BertConfig, BertForQuestionAnswering, BertTokenizer),
|
||||
"masked_bert": (MaskedBertConfig, MaskedBertForQuestionAnswering, BertTokenizer),
|
||||
@@ -673,7 +671,7 @@ def main():
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
|
||||
@@ -58,8 +58,6 @@ logger = logging.getLogger(__name__)
|
||||
MODEL_CONFIG_CLASSES = list(MODEL_FOR_QUESTION_ANSWERING_MAPPING.keys())
|
||||
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||
|
||||
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in MODEL_CONFIG_CLASSES), (),)
|
||||
|
||||
|
||||
def set_seed(args):
|
||||
random.seed(args.seed)
|
||||
@@ -491,7 +489,7 @@ def main():
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
|
||||
@@ -61,7 +61,6 @@ class BertAbsConfig(PretrainedConfig):
|
||||
the decoder.
|
||||
"""
|
||||
|
||||
pretrained_config_archive_map = BERTABS_FINETUNED_CONFIG_MAP
|
||||
model_type = "bertabs"
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -33,14 +33,13 @@ from transformers import BertConfig, BertModel, PreTrainedModel
|
||||
|
||||
MAX_SIZE = 5000
|
||||
|
||||
BERTABS_FINETUNED_MODEL_MAP = {
|
||||
"bertabs-finetuned-cnndm": "https://cdn.huggingface.co/remi/bertabs-finetuned-cnndm-extractive-abstractive-summarization/pytorch_model.bin",
|
||||
}
|
||||
BERTABS_FINETUNED_MODEL_ARCHIVE_LIST = [
|
||||
"remi/bertabs-finetuned-cnndm-extractive-abstractive-summarization",
|
||||
]
|
||||
|
||||
|
||||
class BertAbsPreTrainedModel(PreTrainedModel):
|
||||
config_class = BertAbsConfig
|
||||
pretrained_model_archive_map = BERTABS_FINETUNED_MODEL_MAP
|
||||
load_tf_weights = False
|
||||
base_model_prefix = "bert"
|
||||
|
||||
|
||||
@@ -258,7 +258,7 @@ TEST RESULTS {'val_loss': tensor(0.0707), 'precision': 0.852427800698191, 'recal
|
||||
|
||||
Based on the script [`run_xnli.py`](https://github.com/huggingface/transformers/blob/master/examples/text-classification/run_xnli.py).
|
||||
|
||||
[XNLI](https://www.nyu.edu/projects/bowman/xnli/) is crowd-sourced dataset based on [MultiNLI](http://www.nyu.edu/projects/bowman/multinli/). It is an evaluation benchmark for cross-lingual text representations. Pairs of text are labeled with textual entailment annotations for 15 different languages (including both high-resource language such as English and low-resource languages such as Swahili).
|
||||
[XNLI](https://www.nyu.edu/projects/bowman/xnli/) is a crowd-sourced dataset based on [MultiNLI](http://www.nyu.edu/projects/bowman/multinli/). It is an evaluation benchmark for cross-lingual text representations. Pairs of text are labeled with textual entailment annotations for 15 different languages (including both high-resource languages such as English and low-resource languages such as Swahili).
|
||||
|
||||
#### Fine-tuning on XNLI
|
||||
|
||||
@@ -273,7 +273,6 @@ on a single tesla V100 16GB. The data for XNLI can be downloaded with the follow
|
||||
export XNLI_DIR=/path/to/XNLI
|
||||
|
||||
python run_xnli.py \
|
||||
--model_type bert \
|
||||
--model_name_or_path bert-base-multilingual-cased \
|
||||
--language de \
|
||||
--train_language en \
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Finetuning multi-lingual models on XNLI (Bert, DistilBERT, XLM).
|
||||
""" Finetuning multi-lingual models on XNLI (e.g. Bert, DistilBERT, XLM).
|
||||
Adapted from `examples/text-classification/run_glue.py`"""
|
||||
|
||||
|
||||
@@ -32,15 +32,9 @@ from tqdm import tqdm, trange
|
||||
from transformers import (
|
||||
WEIGHTS_NAME,
|
||||
AdamW,
|
||||
BertConfig,
|
||||
BertForSequenceClassification,
|
||||
BertTokenizer,
|
||||
DistilBertConfig,
|
||||
DistilBertForSequenceClassification,
|
||||
DistilBertTokenizer,
|
||||
XLMConfig,
|
||||
XLMForSequenceClassification,
|
||||
XLMTokenizer,
|
||||
AutoConfig,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoTokenizer,
|
||||
get_linear_schedule_with_warmup,
|
||||
)
|
||||
from transformers import glue_convert_examples_to_features as convert_examples_to_features
|
||||
@@ -57,16 +51,6 @@ except ImportError:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ALL_MODELS = sum(
|
||||
(tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, DistilBertConfig, XLMConfig)), ()
|
||||
)
|
||||
|
||||
MODEL_CLASSES = {
|
||||
"bert": (BertConfig, BertForSequenceClassification, BertTokenizer),
|
||||
"xlm": (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
|
||||
"distilbert": (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer),
|
||||
}
|
||||
|
||||
|
||||
def set_seed(args):
|
||||
random.seed(args.seed)
|
||||
@@ -377,19 +361,12 @@ def main():
|
||||
required=True,
|
||||
help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--language",
|
||||
@@ -421,7 +398,7 @@ def main():
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
default="",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from s3",
|
||||
)
|
||||
@@ -562,24 +539,23 @@ def main():
|
||||
if args.local_rank not in [-1, 0]:
|
||||
torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab
|
||||
|
||||
args.model_type = args.model_type.lower()
|
||||
config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
|
||||
config = config_class.from_pretrained(
|
||||
config = AutoConfig.from_pretrained(
|
||||
args.config_name if args.config_name else args.model_name_or_path,
|
||||
num_labels=num_labels,
|
||||
finetuning_task=args.task_name,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
cache_dir=args.cache_dir,
|
||||
)
|
||||
tokenizer = tokenizer_class.from_pretrained(
|
||||
args.model_type = config.model_type
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
|
||||
do_lower_case=args.do_lower_case,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
cache_dir=args.cache_dir,
|
||||
)
|
||||
model = model_class.from_pretrained(
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in args.model_name_or_path),
|
||||
config=config,
|
||||
cache_dir=args.cache_dir if args.cache_dir else None,
|
||||
cache_dir=args.cache_dir,
|
||||
)
|
||||
|
||||
if args.local_rank == 0:
|
||||
@@ -614,14 +590,13 @@ def main():
|
||||
torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
|
||||
|
||||
# Load a trained model and vocabulary that you have fine-tuned
|
||||
model = model_class.from_pretrained(args.output_dir)
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(args.output_dir)
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.output_dir)
|
||||
model.to(args.device)
|
||||
|
||||
# Evaluation
|
||||
results = {}
|
||||
if args.do_eval and args.local_rank in [-1, 0]:
|
||||
tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
|
||||
checkpoints = [args.output_dir]
|
||||
if args.eval_all_checkpoints:
|
||||
checkpoints = list(
|
||||
@@ -633,7 +608,7 @@ def main():
|
||||
global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
|
||||
prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
|
||||
|
||||
model = model_class.from_pretrained(checkpoint)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
|
||||
model.to(args.device)
|
||||
result = evaluate(args, model, tokenizer, prefix=prefix)
|
||||
result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Fine-tuning the library models for named entity recognition on CoNLL-2003 (Bert or Roberta). """
|
||||
""" Fine-tuning the library models for named entity recognition on CoNLL-2003. """
|
||||
|
||||
|
||||
import logging
|
||||
|
||||
Reference in New Issue
Block a user