Add support for Albert and XLMRoberta for the Glue example (#2403)
* Add support for Albert and XLMRoberta for the Glue example
This commit is contained in:
committed by
Lysandre Debut
parent
9261c7f771
commit
176d3b3079
@@ -13,7 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Finetuning the library models for sequence classification on GLUE (Bert, XLM, XLNet, RoBERTa)."""
|
||||
""" Finetuning the library models for sequence classification on GLUE (Bert, XLM, XLNet, RoBERTa, Albert, XLM-RoBERTa)."""
|
||||
|
||||
|
||||
import argparse
|
||||
@@ -72,7 +72,15 @@ logger = logging.getLogger(__name__)
|
||||
ALL_MODELS = sum(
|
||||
(
|
||||
tuple(conf.pretrained_config_archive_map.keys())
|
||||
for conf in (BertConfig, XLNetConfig, XLMConfig, RobertaConfig, DistilBertConfig)
|
||||
for conf in (
|
||||
BertConfig,
|
||||
XLNetConfig,
|
||||
XLMConfig,
|
||||
RobertaConfig,
|
||||
DistilBertConfig,
|
||||
AlbertConfig,
|
||||
XLMRobertaConfig,
|
||||
)
|
||||
),
|
||||
(),
|
||||
)
|
||||
@@ -148,7 +156,7 @@ def train(args, train_dataset, model, tokenizer):
|
||||
# Distributed training (should be after apex fp16 initialization)
|
||||
if args.local_rank != -1:
|
||||
model = torch.nn.parallel.DistributedDataParallel(
|
||||
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
|
||||
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True,
|
||||
)
|
||||
|
||||
# Train!
|
||||
@@ -183,7 +191,7 @@ def train(args, train_dataset, model, tokenizer):
|
||||
tr_loss, logging_loss = 0.0, 0.0
|
||||
model.zero_grad()
|
||||
train_iterator = trange(
|
||||
epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]
|
||||
epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0],
|
||||
)
|
||||
set_seed(args) # Added here for reproducibility
|
||||
for _ in train_iterator:
|
||||
@@ -200,8 +208,8 @@ def train(args, train_dataset, model, tokenizer):
|
||||
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
|
||||
if args.model_type != "distilbert":
|
||||
inputs["token_type_ids"] = (
|
||||
batch[2] if args.model_type in ["bert", "xlnet"] else None
|
||||
) # XLM, DistilBERT and RoBERTa don't use segment_ids
|
||||
batch[2] if args.model_type in ["bert", "xlnet", "albert"] else None
|
||||
) # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
|
||||
outputs = model(**inputs)
|
||||
loss = outputs[0] # model outputs are always tuple in transformers (see doc)
|
||||
|
||||
@@ -316,8 +324,8 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
|
||||
if args.model_type != "distilbert":
|
||||
inputs["token_type_ids"] = (
|
||||
batch[2] if args.model_type in ["bert", "xlnet"] else None
|
||||
) # XLM, DistilBERT and RoBERTa don't use segment_ids
|
||||
batch[2] if args.model_type in ["bert", "xlnet", "albert"] else None
|
||||
) # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
|
||||
outputs = model(**inputs)
|
||||
tmp_eval_loss, logits = outputs[:2]
|
||||
|
||||
@@ -448,7 +456,7 @@ def main():
|
||||
|
||||
# Other parameters
|
||||
parser.add_argument(
|
||||
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name"
|
||||
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_name",
|
||||
@@ -472,15 +480,17 @@ def main():
|
||||
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
||||
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
||||
parser.add_argument(
|
||||
"--evaluate_during_training", action="store_true", help="Rul evaluation during training at each logging step."
|
||||
"--evaluate_during_training", action="store_true", help="Rul evaluation during training at each logging step.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model."
|
||||
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model.",
|
||||
)
|
||||
|
||||
parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.")
|
||||
parser.add_argument(
|
||||
"--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation."
|
||||
"--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
@@ -493,7 +503,7 @@ def main():
|
||||
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument(
|
||||
"--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform."
|
||||
"--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_steps",
|
||||
@@ -512,10 +522,10 @@ def main():
|
||||
)
|
||||
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
|
||||
parser.add_argument(
|
||||
"--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory"
|
||||
"--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets"
|
||||
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user