Trainer (#3800)
* doc
* [tests] Add sample files for a regression task
* [HUGE] Trainer
* Feedback from @sshleifer
* Feedback from @thomwolf + logging tweak
* [file_utils] when downloading concurrently, get_from_cache will use the cached file for subsequent processes
* [glue] Use default max_seq_length of 128 like before
* [glue] move DataTrainingArguments around
* [ner] Change interface of InputExample, and align run_{tf,pl}
* Re-align the pl scripts a little bit
* ner
* [ner] Add integration test
* Fix language_modeling with API tweak
* [ci] Tweak loss target
* Don't break console output
* amp.initialize: model must be on right device before
* [multiple-choice] update for Trainer
* Re-align to 827d6d6ef0
This commit is contained in:
@@ -8,7 +8,6 @@ import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
ALL_PRETRAINED_MODEL_ARCHIVE_MAP,
|
||||
AdamW,
|
||||
AutoConfig,
|
||||
AutoModel,
|
||||
@@ -20,15 +19,11 @@ from transformers import (
|
||||
AutoTokenizer,
|
||||
get_linear_schedule_with_warmup,
|
||||
)
|
||||
from transformers.modeling_auto import MODEL_MAPPING
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
ALL_MODELS = tuple(ALL_PRETRAINED_MODEL_ARCHIVE_MAP)
|
||||
MODEL_CLASSES = tuple(m.model_type for m in MODEL_MAPPING)
|
||||
|
||||
MODEL_MODES = {
|
||||
"base": AutoModel,
|
||||
"sequence-classification": AutoModelForSequenceClassification,
|
||||
@@ -51,28 +46,25 @@ class BaseTransformer(pl.LightningModule):
|
||||
def __init__(self, hparams: argparse.Namespace, num_labels=None, mode="base", **config_kwargs):
|
||||
"Initialize a model."
|
||||
|
||||
super(BaseTransformer, self).__init__()
|
||||
super().__init__()
|
||||
self.hparams = hparams
|
||||
cache_dir = self.hparams.cache_dir if self.hparams.cache_dir else None
|
||||
self.hparams.model_type = self.hparams.model_type.lower()
|
||||
config = AutoConfig.from_pretrained(
|
||||
self.config = AutoConfig.from_pretrained(
|
||||
self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path,
|
||||
**({"num_labels": num_labels} if num_labels is not None else {}),
|
||||
cache_dir=cache_dir,
|
||||
**config_kwargs,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path,
|
||||
do_lower_case=self.hparams.do_lower_case,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
model = MODEL_MODES[mode].from_pretrained(
|
||||
self.model = MODEL_MODES[mode].from_pretrained(
|
||||
self.hparams.model_name_or_path,
|
||||
from_tf=bool(".ckpt" in self.hparams.model_name_or_path),
|
||||
config=config,
|
||||
config=self.config,
|
||||
cache_dir=cache_dir,
|
||||
)
|
||||
self.config, self.tokenizer, self.model = config, tokenizer, model
|
||||
|
||||
def is_logger(self):
|
||||
return self.trainer.proc_rank <= 0
|
||||
@@ -148,19 +140,12 @@ class BaseTransformer(pl.LightningModule):
|
||||
|
||||
@staticmethod
|
||||
def add_model_specific_args(parser, root_dir):
|
||||
parser.add_argument(
|
||||
"--model_type",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS),
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
|
||||
@@ -177,9 +162,6 @@ class BaseTransformer(pl.LightningModule):
|
||||
type=str,
|
||||
help="Where do you want to store the pre-trained models downloaded from s3",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
|
||||
)
|
||||
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
||||
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||
@@ -252,8 +234,6 @@ def add_generic_args(parser, root_dir):
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
|
||||
parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
|
||||
parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
|
||||
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||
|
||||
|
||||
@@ -261,15 +241,6 @@ def generic_train(model: BaseTransformer, args: argparse.Namespace):
|
||||
# init model
|
||||
set_seed(args)
|
||||
|
||||
# Setup distant debugging if needed
|
||||
if args.server_ip and args.server_port:
|
||||
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
|
||||
import ptvsd
|
||||
|
||||
print("Waiting for debugger attach")
|
||||
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
|
||||
ptvsd.wait_for_attach()
|
||||
|
||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
|
||||
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user