From df3961121f23f30a69979720c493c491a5c482ed Mon Sep 17 00:00:00 2001 From: Suvrat Bhooshan Date: Mon, 9 Dec 2019 18:25:28 -0800 Subject: [PATCH] Add MMBT Model to Transformers Repo --- README.md | 3 +- examples/README.md | 25 ++ examples/run_mmimdb.py | 504 +++++++++++++++++++++++++++++ examples/utils_mmimdb.py | 130 ++++++++ transformers/__init__.py | 3 + transformers/configuration_mmbt.py | 38 +++ transformers/modeling_mmbt.py | 368 +++++++++++++++++++++ 7 files changed, 1070 insertions(+), 1 deletion(-) create mode 100644 examples/run_mmimdb.py create mode 100644 examples/utils_mmimdb.py create mode 100644 transformers/configuration_mmbt.py create mode 100644 transformers/modeling_mmbt.py diff --git a/README.md b/README.md index f3aa8a95ee..6c3095bfe6 100644 --- a/README.md +++ b/README.md @@ -144,7 +144,8 @@ At some point in the future, you'll be able to seamlessly move from pre-training 9. **[CTRL](https://github.com/salesforce/ctrl/)** (from Salesforce) released with the paper [CTRL: A Conditional Transformer Language Model for Controllable Generation](https://arxiv.org/abs/1909.05858) by Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher. 10. **[CamemBERT](https://camembert-model.fr)** (from Inria/Facebook/Sorbonne) released with the paper [CamemBERT: a Tasty French Language Model](https://arxiv.org/abs/1911.03894) by Louis Martin*, Benjamin Muller*, Pedro Javier Ortiz Suárez*, Yoann Dupont, Laurent Romary, Éric Villemonte de la Clergerie, Djamé Seddah and Benoît Sagot. 11. **[ALBERT](https://github.com/google-research/ALBERT)** (from Google Research and the Toyota Technological Institute at Chicago) released with the paper [ALBERT: A Lite BERT for Self-supervised Learning of Language Representations](https://arxiv.org/abs/1909.11942), by Zhenzhong Lan, Mingda Chen, Sebastian Goodman, Kevin Gimpel, Piyush Sharma, Radu Soricut. -11. Want to contribute a new model? We have added a **detailed guide and templates** to guide you in the process of adding a new model. You can find them in the [`templates`](./templates) folder of the repository. Be sure to check the [contributing guidelines](./CONTRIBUTING.md) and contact the maintainers or open an issue to collect feedbacks before starting your PR. +12. **[MMBT](https://github.com/facebookresearch/mmbt/)** (from Facebook), released together with the paper a [Supervised Multimodal Bitransformers for Classifying Images and Text](https://arxiv.org/pdf/1909.02950.pdf) by Douwe Kiela, Suvrat Bhooshan, Hamed Firooz, Davide Testuggine. +12. Want to contribute a new model? We have added a **detailed guide and templates** to guide you in the process of adding a new model. You can find them in the [`templates`](./templates) folder of the repository. Be sure to check the [contributing guidelines](./CONTRIBUTING.md) and contact the maintainers or open an issue to collect feedbacks before starting your PR. These implementations have been tested on several datasets (see the example scripts) and should match the performances of the original implementations (e.g. ~93 F1 on SQuAD for BERT Whole-Word-Masking, ~88 F1 on RocStories for OpenAI GPT, ~18.3 perplexity on WikiText 103 for Transformer-XL, ~0.916 Peason R coefficient on STS-B for XLNet). You can find more details on the performances in the Examples section of the [documentation](https://huggingface.co/transformers/examples.html). diff --git a/examples/README.md b/examples/README.md index 620304ea77..1a7912296f 100644 --- a/examples/README.md +++ b/examples/README.md @@ -713,3 +713,28 @@ Training with the previously defined hyper-parameters yields the following resul ```bash acc = 0.7093812375249501 ``` + +## MM-IMDb + +Based on the script [`run_mmimdb.py`](https://github.com/huggingface/transformers/blob/master/examples/run_mmimdb.py). + +[MM-IMDb](http://lisi1.unal.edu.co/mmimdb/) is a Multimodal dataset with around 26,000 movies including images, plots and other metadata. + +### Training on MM-IMDb + +``` +python run_mmimdb.py \ + --data_dir /path/to/mmimdb/dataset/ \ + --model_type bert \ + --model_name_or_path bert-base-uncased \ + --output_dir /path/to/save/dir/ \ + --do_train \ + --do_eval \ + --max_seq_len 512 \ + --gradient_accumulation_steps 20 \ + --num_image_embeds 3 \ + --num_train_epochs 100 \ + --patience 5 +``` + + diff --git a/examples/run_mmimdb.py b/examples/run_mmimdb.py new file mode 100644 index 0000000000..f4a44bf62a --- /dev/null +++ b/examples/run_mmimdb.py @@ -0,0 +1,504 @@ +# coding=utf-8 +# Copyright (c) Facebook, Inc. and its affiliates. +# Copyright (c) HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Finetuning the library models for multimodal multiclass prediction on MM-IMDB dataset.""" + +from __future__ import absolute_import, division, print_function + +import argparse +import glob +import logging +import os +import random +import json +from sklearn.metrics import f1_score + +import numpy as np +import torch +import torch.nn as nn +from torch.utils.data import DataLoader, RandomSampler, SequentialSampler +from torch.utils.data.distributed import DistributedSampler + +try: + from torch.utils.tensorboard import SummaryWriter +except: + from tensorboardX import SummaryWriter + +from tqdm import tqdm, trange + +from utils_mmimdb import ImageEncoder, JsonlDataset, collate_fn, get_mmimdb_labels, get_image_transforms + +from transformers import (WEIGHTS_NAME, + BertConfig, BertModel, BertTokenizer, + RobertaConfig, RobertaModel, RobertaTokenizer, + XLMConfig, XLMModel, XLMTokenizer, + XLNetConfig, XLNetModel, XLNetTokenizer, + DistilBertConfig, DistilBertModel, DistilBertTokenizer, + AlbertConfig, AlbertModel, AlbertTokenizer, + MMBTForClassification, MMBTConfig) + +from transformers import AdamW, get_linear_schedule_with_warmup + +logger = logging.getLogger(__name__) + +ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (BertConfig, XLNetConfig, XLMConfig, + RobertaConfig, DistilBertConfig)), ()) + +MODEL_CLASSES = { + 'bert': (BertConfig, BertModel, BertTokenizer), + 'xlnet': (XLNetConfig, XLNetModel, XLNetTokenizer), + 'xlm': (XLMConfig, XLMModel, XLMTokenizer), + 'roberta': (RobertaConfig, RobertaModel, RobertaTokenizer), + 'distilbert': (DistilBertConfig, DistilBertModel, DistilBertTokenizer), + 'albert': (AlbertConfig, AlbertModel, AlbertTokenizer) +} + + +def set_seed(args): + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + if args.n_gpu > 0: + torch.cuda.manual_seed_all(args.seed) + + +def train(args, train_dataset, model, tokenizer, criterion): + """ Train the model """ + if args.local_rank in [-1, 0]: + tb_writer = SummaryWriter() + + args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) + train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset) + train_dataloader = DataLoader(train_dataset, sampler=train_sampler, + batch_size=args.train_batch_size, + collate_fn=collate_fn, + num_workers=args.num_workers) + + if args.max_steps > 0: + t_total = args.max_steps + args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1 + else: + t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs + + # Prepare optimizer and schedule (linear warmup and decay) + no_decay = ['bias', 'LayerNorm.weight'] + optimizer_grouped_parameters = [ + {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay}, + {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} + ] + + optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) + scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total) + if args.fp16: + try: + from apex import amp + except ImportError: + raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") + model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level) + + # multi-gpu training (should be after apex fp16 initialization) + if args.n_gpu > 1: + model = torch.nn.DataParallel(model) + + # Distributed training (should be after apex fp16 initialization) + if args.local_rank != -1: + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], + output_device=args.local_rank, + find_unused_parameters=True) + + # Train! + logger.info("***** Running training *****") + logger.info(" Num examples = %d", len(train_dataset)) + logger.info(" Num Epochs = %d", args.num_train_epochs) + logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size) + logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", + args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1)) + logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps) + logger.info(" Total optimization steps = %d", t_total) + + global_step = 0 + tr_loss, logging_loss = 0.0, 0.0 + best_f1, n_no_improve = 0, 0 + model.zero_grad() + train_iterator = trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]) + set_seed(args) # Added here for reproductibility (even between python 2 and 3) + for _ in train_iterator: + epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]) + for step, batch in enumerate(epoch_iterator): + model.train() + batch = tuple(t.to(args.device) for t in batch) + labels = batch[5] + inputs = {'input_ids': batch[0], + 'input_modal': batch[2], + 'attention_mask': batch[1], + 'modal_start_tokens': batch[3], + 'modal_end_tokens': batch[4]} + outputs = model(**inputs) + logits = outputs[0] # model outputs are always tuple in transformers (see doc) + loss = criterion(logits, labels) + + if args.n_gpu > 1: + loss = loss.mean() # mean() to average on multi-gpu parallel training + if args.gradient_accumulation_steps > 1: + loss = loss / args.gradient_accumulation_steps + + if args.fp16: + with amp.scale_loss(loss, optimizer) as scaled_loss: + scaled_loss.backward() + else: + loss.backward() + + tr_loss += loss.item() + if (step + 1) % args.gradient_accumulation_steps == 0: + if args.fp16: + torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) + else: + torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) + + optimizer.step() + scheduler.step() # Update learning rate schedule + model.zero_grad() + global_step += 1 + + if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0: + logs = {} + if args.local_rank == -1 and args.evaluate_during_training: # Only evaluate when single GPU otherwise metrics may not average well + results = evaluate(args, model, tokenizer, criterion) + for key, value in results.items(): + eval_key = 'eval_{}'.format(key) + logs[eval_key] = value + + loss_scalar = (tr_loss - logging_loss) / args.logging_steps + learning_rate_scalar = scheduler.get_lr()[0] + logs['learning_rate'] = learning_rate_scalar + logs['loss'] = loss_scalar + logging_loss = tr_loss + + for key, value in logs.items(): + tb_writer.add_scalar(key, value, global_step) + print(json.dumps({**logs, **{'step': global_step}})) + + if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0: + # Save model checkpoint + output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step)) + if not os.path.exists(output_dir): + os.makedirs(output_dir) + model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training + torch.save(model_to_save.state_dict(), os.path.join(output_dir, WEIGHTS_NAME)) + torch.save(args, os.path.join(output_dir, 'training_args.bin')) + logger.info("Saving model checkpoint to %s", output_dir) + + if args.max_steps > 0 and global_step > args.max_steps: + epoch_iterator.close() + break + if args.max_steps > 0 and global_step > args.max_steps: + train_iterator.close() + break + + if args.local_rank == -1: + results = evaluate(args, model, tokenizer, criterion) + if results['micro_f1'] > best_f1: + best_f1 = results['micro_f1'] + n_no_improve = 0 + else: + n_no_improve += 1 + + if n_no_improve > args.patience: + train_iterator.close() + break + + if args.local_rank in [-1, 0]: + tb_writer.close() + + return global_step, tr_loss / global_step + + +def evaluate(args, model, tokenizer, criterion, prefix=""): + # Loop to handle MNLI double evaluation (matched, mis-matched) + eval_output_dir = args.output_dir + eval_dataset = load_examples(args, tokenizer, evaluate=True) + + if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]: + os.makedirs(eval_output_dir) + + args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) + # Note that DistributedSampler samples randomly + eval_sampler = SequentialSampler(eval_dataset) + eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size, collate_fn=collate_fn) + + # multi-gpu eval + if args.n_gpu > 1: + model = torch.nn.DataParallel(model) + + # Eval! + logger.info("***** Running evaluation {} *****".format(prefix)) + logger.info(" Num examples = %d", len(eval_dataset)) + logger.info(" Batch size = %d", args.eval_batch_size) + eval_loss = 0.0 + nb_eval_steps = 0 + preds = None + out_label_ids = None + for batch in tqdm(eval_dataloader, desc="Evaluating"): + model.eval() + batch = tuple(t.to(args.device) for t in batch) + + with torch.no_grad(): + batch = tuple(t.to(args.device) for t in batch) + labels = batch[5] + inputs = {'input_ids': batch[0], + 'input_modal': batch[2], + 'attention_mask': batch[1], + 'modal_start_tokens': batch[3], + 'modal_end_tokens': batch[4]} + outputs = model(**inputs) + logits = outputs[0] # model outputs are always tuple in transformers (see doc) + tmp_eval_loss = criterion(logits, labels) + eval_loss += tmp_eval_loss.mean().item() + nb_eval_steps += 1 + if preds is None: + preds = torch.sigmoid(logits).detach().cpu().numpy() > 0.5 + out_label_ids = labels.detach().cpu().numpy() + else: + preds = np.append(preds, torch.sigmoid(logits).detach().cpu().numpy() > 0.5, axis=0) + out_label_ids = np.append(out_label_ids, labels.detach().cpu().numpy(), axis=0) + + eval_loss = eval_loss / nb_eval_steps + result = { + "loss": eval_loss, + "macro_f1": f1_score(out_label_ids, preds, average="macro"), + "micro_f1": f1_score(out_label_ids, preds, average="micro") + } + + output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt") + with open(output_eval_file, "w") as writer: + logger.info("***** Eval results {} *****".format(prefix)) + for key in sorted(result.keys()): + logger.info(" %s = %s", key, str(result[key])) + writer.write("%s = %s\n" % (key, str(result[key]))) + + return result + + +def load_examples(args, tokenizer, evaluate=False): + path = os.path.join(args.data_dir, "dev.jsonl" if evaluate else "train.jsonl") + transforms = get_image_transforms() + labels = get_mmimdb_labels() + dataset = JsonlDataset(path, tokenizer, transforms, labels, args.max_seq_length - args.num_image_embeds - 2) + return dataset + + +def main(): + parser = argparse.ArgumentParser() + + ## Required parameters + parser.add_argument("--data_dir", default=None, type=str, required=True, + help="The input data dir. Should contain the .jsonl files for MMIMDB.") + parser.add_argument("--model_type", default=None, type=str, required=True, + help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys())) + parser.add_argument("--model_name_or_path", default=None, type=str, required=True, + help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS)) + parser.add_argument("--output_dir", default=None, type=str, required=True, + help="The output directory where the model predictions and checkpoints will be written.") + + ## Other parameters + parser.add_argument("--config_name", default="", type=str, + help="Pretrained config name or path if not the same as model_name") + parser.add_argument("--tokenizer_name", default="", type=str, + help="Pretrained tokenizer name or path if not the same as model_name") + parser.add_argument("--cache_dir", default="", type=str, + help="Where do you want to store the pre-trained models downloaded from s3") + parser.add_argument("--max_seq_length", default=128, type=int, + help="The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded.") + parser.add_argument("--num_image_embeds", default=1, type=int, + help="Number of Image Embeddings from the Image Encoder") + parser.add_argument("--do_train", action='store_true', + help="Whether to run training.") + parser.add_argument("--do_eval", action='store_true', + help="Whether to run eval on the dev set.") + parser.add_argument("--evaluate_during_training", action='store_true', + help="Rul evaluation during training at each logging step.") + parser.add_argument("--do_lower_case", action='store_true', + help="Set this flag if you are using an uncased model.") + + parser.add_argument("--per_gpu_train_batch_size", default=8, type=int, + help="Batch size per GPU/CPU for training.") + parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int, + help="Batch size per GPU/CPU for evaluation.") + parser.add_argument('--gradient_accumulation_steps', type=int, default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.") + parser.add_argument("--learning_rate", default=5e-5, type=float, + help="The initial learning rate for Adam.") + parser.add_argument("--weight_decay", default=0.0, type=float, + help="Weight deay if we apply some.") + parser.add_argument("--adam_epsilon", default=1e-8, type=float, + help="Epsilon for Adam optimizer.") + parser.add_argument("--max_grad_norm", default=1.0, type=float, + help="Max gradient norm.") + parser.add_argument("--num_train_epochs", default=3.0, type=float, + help="Total number of training epochs to perform.") + parser.add_argument("--patience", default=5, type=int, + help="Patience for Early Stopping.") + parser.add_argument("--max_steps", default=-1, type=int, + help="If > 0: set total number of training steps to perform. Override num_train_epochs.") + parser.add_argument("--warmup_steps", default=0, type=int, + help="Linear warmup over warmup_steps.") + + parser.add_argument('--logging_steps', type=int, default=50, + help="Log every X updates steps.") + parser.add_argument('--save_steps', type=int, default=50, + help="Save checkpoint every X updates steps.") + parser.add_argument("--eval_all_checkpoints", action='store_true', + help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number") + parser.add_argument("--no_cuda", action='store_true', + help="Avoid using CUDA when available") + parser.add_argument('--num_workers', type=int, default=8, + help="number of worker threads for dataloading") + parser.add_argument('--overwrite_output_dir', action='store_true', + help="Overwrite the content of the output directory") + parser.add_argument('--overwrite_cache', action='store_true', + help="Overwrite the cached training and evaluation sets") + parser.add_argument('--seed', type=int, default=42, + help="random seed for initialization") + + parser.add_argument('--fp16', action='store_true', + help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit") + parser.add_argument('--fp16_opt_level', type=str, default='O1', + help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." + "See details at https://nvidia.github.io/apex/amp.html") + parser.add_argument("--local_rank", type=int, default=-1, + help="For distributed training: local_rank") + parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.") + parser.add_argument('--server_port', type=str, default='', help="For distant debugging.") + args = parser.parse_args() + + if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir: + raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir)) + + # Setup distant debugging if needed + if args.server_ip and args.server_port: + # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script + import ptvsd + print("Waiting for debugger attach") + ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True) + ptvsd.wait_for_attach() + + # Setup CUDA, GPU & distributed training + if args.local_rank == -1 or args.no_cuda: + device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") + args.n_gpu = torch.cuda.device_count() + else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs + torch.cuda.set_device(args.local_rank) + device = torch.device("cuda", args.local_rank) + torch.distributed.init_process_group(backend='nccl') + args.n_gpu = 1 + + args.device = device + + # Setup logging + logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', + datefmt = '%m/%d/%Y %H:%M:%S', + level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN) + logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s", + args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16) + + # Set seed + set_seed(args) + + # Load pretrained model and tokenizer + if args.local_rank not in [-1, 0]: + torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab + + # Setup model + labels = get_mmimdb_labels() + num_labels = len(labels) + args.model_type = args.model_type.lower() + config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type] + transformer_config = config_class.from_pretrained(args.config_name if args.config_name else args.model_name_or_path) + tokenizer = tokenizer_class.from_pretrained(args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, + do_lower_case=args.do_lower_case, + cache_dir=args.cache_dir if args.cache_dir else None) + transformer = model_class.from_pretrained(args.model_name_or_path, + config=transformer_config, + cache_dir=args.cache_dir if args.cache_dir else None) + img_encoder = ImageEncoder(args) + config = MMBTConfig(transformer_config, num_labels=num_labels) + model = MMBTForClassification(config, transformer, img_encoder) + + if args.local_rank == 0: + torch.distributed.barrier() # Make sure only the first process in distributed training will download model & vocab + + model.to(args.device) + + logger.info("Training/evaluation parameters %s", args) + + # Training + if args.do_train: + train_dataset = load_examples(args, tokenizer, evaluate=False) + label_frequences = train_dataset.get_label_frequencies() + label_frequences = [label_frequences[l] for l in labels] + label_weights = (torch.tensor(label_frequences, device=args.device, dtype=torch.float) / len(train_dataset)) ** -1 + criterion = nn.BCEWithLogitsLoss(pos_weight=label_weights) + global_step, tr_loss = train(args, train_dataset, model, tokenizer, criterion) + logger.info(" global_step = %s, average loss = %s", global_step, tr_loss) + + + # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained() + if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0): + # Create output directory if needed + if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]: + os.makedirs(args.output_dir) + + logger.info("Saving model checkpoint to %s", args.output_dir) + # Save a trained model, configuration and tokenizer using `save_pretrained()`. + # They can then be reloaded using `from_pretrained()` + model_to_save = model.module if hasattr(model, 'module') else model # Take care of distributed/parallel training + torch.save(model_to_save.state_dict(), os.path.join(args.output_dir, WEIGHTS_NAME)) + tokenizer.save_pretrained(args.output_dir) + + # Good practice: save your training arguments together with the trained model + torch.save(args, os.path.join(args.output_dir, 'training_args.bin')) + + # Load a trained model and vocabulary that you have fine-tuned + model = MMBTForClassification(config, transformer, img_encoder) + model.load_state_dict(torch.load(os.path.join(args.output_dir, WEIGHTS_NAME))) + tokenizer = tokenizer_class.from_pretrained(args.output_dir) + model.to(args.device) + + + # Evaluation + results = {} + if args.do_eval and args.local_rank in [-1, 0]: + tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case) + checkpoints = [args.output_dir] + if args.eval_all_checkpoints: + checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True))) + logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN) # Reduce logging + logger.info("Evaluate the following checkpoints: %s", checkpoints) + for checkpoint in checkpoints: + global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else "" + prefix = checkpoint.split('/')[-1] if checkpoint.find('checkpoint') != -1 else "" + model = MMBTForClassification(config, transformer, img_encoder) + model.load_state_dict(torch.load(checkpoint)) + model.to(args.device) + result = evaluate(args, model, tokenizer, criterion, prefix=prefix) + result = dict((k + '_{}'.format(global_step), v) for k, v in result.items()) + results.update(result) + + return results + + +if __name__ == "__main__": + main() diff --git a/examples/utils_mmimdb.py b/examples/utils_mmimdb.py new file mode 100644 index 0000000000..c59da02642 --- /dev/null +++ b/examples/utils_mmimdb.py @@ -0,0 +1,130 @@ +# coding=utf-8 +# Copyright (c) Facebook, Inc. and its affiliates. +# Copyright (c) HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from collections import Counter +from PIL import Image + +import torch +import torch.nn as nn +import torchvision +import torchvision.transforms as transforms +from torch.utils.data import Dataset + +POOLING_BREAKDOWN = { + 1: (1, 1), + 2: (2, 1), + 3: (3, 1), + 4: (2, 2), + 5: (5, 1), + 6: (3, 2), + 7: (7, 1), + 8: (4, 2), + 9: (3, 3) +} + + +class ImageEncoder(nn.Module): + def __init__(self, args): + super(ImageEncoder, self).__init__() + model = torchvision.models.resnet152(pretrained=True) + modules = list(model.children())[:-2] + self.model = nn.Sequential(*modules) + self.pool = nn.AdaptiveAvgPool2d(POOLING_BREAKDOWN[args.num_image_embeds]) + + def forward(self, x): + # Bx3x224x224 -> Bx2048x7x7 -> Bx2048xN -> BxNx2048 + out = self.pool(self.model(x)) + out = torch.flatten(out, start_dim=2) + out = out.transpose(1, 2).contiguous() + return out # BxNx2048 + + + +class JsonlDataset(Dataset): + def __init__(self, data_path, tokenizer, transforms, labels, max_seq_length): + self.data = [json.loads(l) for l in open(data_path)] + self.data_dir = os.path.dirname(data_path) + self.tokenizer = tokenizer + self.labels = labels + self.n_classes = len(labels) + self.max_seq_length = max_seq_length + + self.transforms = transforms + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + sentence = torch.LongTensor(self.tokenizer.encode(self.data[index]["text"], add_special_tokens=True)) + start_token, sentence, end_token = sentence[0], sentence[1:-1], sentence[-1] + sentence = sentence[:self.max_seq_length] + + label = torch.zeros(self.n_classes) + label[[self.labels.index(tgt) for tgt in self.data[index]["label"]]] = 1 + + image = Image.open(os.path.join(self.data_dir, self.data[index]["img"])).convert("RGB") + image = self.transforms(image) + + return {"image_start_token": start_token, "image_end_token": end_token, + "sentence": sentence, "image": image, "label": label} + + def get_label_frequencies(self): + label_freqs = Counter() + for row in self.data: + label_freqs.update(row["label"]) + return label_freqs + + +def collate_fn(batch): + lens = [len(row["sentence"]) for row in batch] + bsz, max_seq_len = len(batch), max(lens) + + mask_tensor = torch.zeros(bsz, max_seq_len, dtype=torch.long) + text_tensor = torch.zeros(bsz, max_seq_len, dtype=torch.long) + + for i_batch, (input_row, length) in enumerate(zip(batch, lens)): + text_tensor[i_batch, :length] = input_row["sentence"] + mask_tensor[i_batch, :length] = 1 + + img_tensor = torch.stack([row["image"] for row in batch]) + tgt_tensor = torch.stack([row["label"] for row in batch]) + img_start_token = torch.stack([row["image_start_token"] for row in batch]) + img_end_token = torch.stack([row["image_end_token"] for row in batch]) + + return text_tensor, mask_tensor, img_tensor, img_start_token, img_end_token, tgt_tensor + + +def get_mmimdb_labels(): + return ['Crime', 'Drama', 'Thriller', 'Action', 'Comedy', 'Romance', + 'Documentary', 'Short', 'Mystery', 'History', 'Family', 'Adventure', + 'Fantasy', 'Sci-Fi', 'Western', 'Horror', 'Sport', 'War', 'Music', + 'Musical', 'Animation', 'Biography', 'Film-Noir'] + + +def get_image_transforms(): + return transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.46777044, 0.44531429, 0.40661017], + std=[0.12221994, 0.12145835, 0.14380469], + ), + ] + ) diff --git a/transformers/__init__.py b/transformers/__init__.py index 1105756456..5f4b8252a4 100644 --- a/transformers/__init__.py +++ b/transformers/__init__.py @@ -61,6 +61,7 @@ from .configuration_roberta import RobertaConfig, ROBERTA_PRETRAINED_CONFIG_ARCH from .configuration_distilbert import DistilBertConfig, DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP from .configuration_albert import AlbertConfig, ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP from .configuration_camembert import CamembertConfig, CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP +from .configuration_mmbt import MMBTConfig # Modeling if is_torch_available(): @@ -112,6 +113,8 @@ if is_torch_available(): AlbertForQuestionAnswering, load_tf_weights_in_albert, ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP) + from .modeling_mmbt import ModalEmbeddings, MMBTModel, MMBTForClassification + # Optimization from .optimization import (AdamW, get_constant_schedule, get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup, get_cosine_with_hard_restarts_schedule_with_warmup, get_linear_schedule_with_warmup) diff --git a/transformers/configuration_mmbt.py b/transformers/configuration_mmbt.py new file mode 100644 index 0000000000..60176c9872 --- /dev/null +++ b/transformers/configuration_mmbt.py @@ -0,0 +1,38 @@ +# coding=utf-8 +# Copyright (c) Facebook, Inc. and its affiliates. +# Copyright (c) HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" MMBT configuration """ + +from __future__ import (absolute_import, division, print_function, + unicode_literals) + +import logging + +logger = logging.getLogger(__name__) + + +class MMBTConfig(object): + """Configuration class to store the configuration of a `MMBT Model`. + + Args: + config: config of the underlying Transformer models. It's values are copied over to use a single config. + num_labels: Size of final Linear layer for classification. + modal_hidden_size: Embedding dimension of the non-text modality encoder. + """ + def __init__(self, config, num_labels=None, modal_hidden_size=2048): + self.__dict__ = config.__dict__ + self.modal_hidden_size = modal_hidden_size + if num_labels: + self.num_labels = num_labels diff --git a/transformers/modeling_mmbt.py b/transformers/modeling_mmbt.py new file mode 100644 index 0000000000..79a717ba2a --- /dev/null +++ b/transformers/modeling_mmbt.py @@ -0,0 +1,368 @@ +# coding=utf-8 +# Copyright (c) Facebook, Inc. and its affiliates. +# Copyright (c) HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch MMBT model. """ + +from __future__ import (absolute_import, division, print_function, + unicode_literals) + +import logging + +import torch +import torch.nn as nn +from torch.nn import CrossEntropyLoss, MSELoss + +from .file_utils import add_start_docstrings + +logger = logging.getLogger(__name__) + + +class ModalEmbeddings(nn.Module): + """Generic Modal Embeddings which takes in an encoder, and a transformer embedding. + """ + def __init__(self, config, encoder, embeddings): + super(ModalEmbeddings, self).__init__() + self.config = config + self.encoder = encoder + self.proj_embeddings = nn.Linear(config.modal_hidden_size, config.hidden_size) + self.position_embeddings = embeddings.position_embeddings + self.token_type_embeddings = embeddings.token_type_embeddings + self.word_embeddings = embeddings.word_embeddings + self.LayerNorm = embeddings.LayerNorm + self.dropout = nn.Dropout(p=config.hidden_dropout_prob) + + def forward(self, input_modal, start_token=None, end_token=None, position_ids=None, token_type_ids=None): + token_embeddings = self.proj_embeddings(self.encoder(input_modal)) + seq_length = token_embeddings.size(1) + + if start_token is not None: + start_token_embeds = self.word_embeddings(start_token) + seq_length += 1 + token_embeddings = torch.cat([start_token_embeds.unsqueeze(1), token_embeddings], dim=1) + + if end_token is not None: + end_token_embeds = self.word_embeddings(end_token) + seq_length += 1 + token_embeddings = torch.cat([token_embeddings, end_token_embeds.unsqueeze(1)], dim=1) + + if position_ids is None: + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_modal.device) + position_ids = position_ids.unsqueeze(0).expand(input_modal.size(0), seq_length) + + if token_type_ids is None: + token_type_ids = torch.zeros((input_modal.size(0), seq_length), dtype=torch.long, device=input_modal.device) + + position_embeddings = self.position_embeddings(position_ids) + token_type_embeddings = self.token_type_embeddings(token_type_ids) + embeddings = token_embeddings + position_embeddings + token_type_embeddings + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + return embeddings + + +MMBT_START_DOCSTRING = r""" MMBT model was proposed in + `Supervised Multimodal Bitransformers for Classifying Images and Text`_ + by Douwe Kiela, Suvrat Bhooshan, Hamed Firooz, Davide Testuggine. + It's a supervised multimodal bitransformer model that fuses information from text and other image encoders, + and obtain state-of-the-art performance on various multimodal classification benchmark tasks. + + This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and + refer to the PyTorch documentation for all matter related to general usage and behavior. + + .. _`Supervised Multimodal Bitransformers for Classifying Images and Text`: + https://www.github.com/salesforce/ctrl + + .. _`torch.nn.Module`: + https://pytorch.org/docs/stable/nn.html#module + + Parameters: + config (:class:`~transformers.MMBTConfig`): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the configuration. + transformer (:class: `~nn.Module`): A text transformer that is used by MMBT. + It should have embeddings, encoder, and pooler attributes. + encoder (:class: `~nn.Module`): Encoder for the second modality. + It should take in a batch of modal inputs and return k, n dimension embeddings. +""" + +MMBT_INPUTS_DOCSTRING = r""" Inputs: + **input_modal**: ``torch.FloatTensor`` of shape ``(batch_size, ***)``: + The other modality data. It will be the shape that the encoder for that type expects. + e.g. With an Image Encoder, the shape would be (batch_size, channels, height, width) + **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: + Indices of input sequence tokens in the vocabulary. + It does not expect [CLS] token to be added as it's appended to the end of other modality embeddings. + See :func:`transformers.PreTrainedTokenizer.encode` and + :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details. + **modal_start_tokens**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: + Optional start token to be added to Other Modality Embedding. [CLS] Most commonly used for Classification tasks. + **modal_end_tokens**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: + Optional end token to be added to Other Modality Embedding. [SEP] Most commonly used. + **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``: + Mask to avoid performing attention on padding token indices. + Mask values selected in ``[0, 1]``: + ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. + **token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: + Segment token indices to indicate different portions of the inputs. + **modal_token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, modal_sequence_length)``: + Segment token indices to indicate different portions of the non-text modality. + The embeddings from these tokens will be summed with the respective token embeddings for the non-text modality. + **position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: + Indices of positions of each input sequence tokens in the position embeddings. + **modal_position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, modal_sequence_length)``: + Indices of positions of each input sequence tokens in the position embeddings for the non-text modality. + **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``: + Mask to nullify selected heads of the self-attention modules. + Mask values selected in ``[0, 1]``: + ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**. + **inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``: + Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + **encoder_hidden_states**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``: + Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if the model + is configured as a decoder. + **encoder_attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``: + Mask to avoid performing attention on the padding token indices of the encoder input. This mask + is used in the cross-attention if the model is configured as a decoder. + Mask values selected in ``[0, 1]``: + ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. +""" + +@add_start_docstrings("The bare MMBT Model outputting raw hidden-states without any specific head on top.", + MMBT_START_DOCSTRING, MMBT_INPUTS_DOCSTRING) +class MMBTModel(nn.Module): + r""" + Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: + **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)`` + Sequence of hidden-states at the output of the last layer of the model. + **pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)`` + Last layer hidden-state of the first token of the sequence (classification token) + further processed by a Linear layer and a Tanh activation function. The Linear + layer weights are trained from the next sentence prediction (classification) + objective during Bert pretraining. This output is usually *not* a good summary + of the semantic content of the input, you're often better with averaging or pooling + the sequence of hidden-states for the whole input sequence. + **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) + list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) + of shape ``(batch_size, sequence_length, hidden_size)``: + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + **attentions**: (`optional`, returned when ``config.output_attentions=True``) + list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. + + Examples:: + transformer = BertModel.from_pretrained('bert-base-uncased') + encoder = ImageEncoder(args) + mmbt = MMBTModel(config, transformer, encoder) + """ + def __init__(self, config, transformer, encoder): + super(MMBTModel, self).__init__() + self.config = config + self.transformer = transformer + self.modal_encoder = ModalEmbeddings(config, encoder, transformer.embeddings) + + def forward(self, input_modal, input_ids=None, modal_start_tokens=None, + modal_end_tokens=None, attention_mask=None, + token_type_ids=None, modal_token_type_ids=None, + position_ids=None, modal_position_ids=None, head_mask=None, + inputs_embeds=None, encoder_hidden_states=None, + encoder_attention_mask=None): + + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_txt_shape = input_ids.size() + elif inputs_embeds is not None: + input_txt_shape = inputs_embeds.size()[:-1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + device = input_ids.device if input_ids is not None else inputs_embeds.device + + modal_embeddings = self.modal_encoder(input_modal, + start_token=modal_start_tokens, + end_token=modal_end_tokens, + position_ids=modal_position_ids, + token_type_ids=modal_token_type_ids) + + input_modal_shape = modal_embeddings.size()[:-1] + + if token_type_ids is None: + token_type_ids = torch.ones(input_txt_shape, dtype=torch.long, device=device) + + txt_embeddings = self.transformer.embeddings(input_ids=input_ids, + position_ids=position_ids, + token_type_ids=token_type_ids, + inputs_embeds=inputs_embeds) + + embedding_output = torch.cat([modal_embeddings, txt_embeddings], 1) + + input_shape = embedding_output.size()[:-1] + + if attention_mask is None: + attention_mask = torch.ones(input_shape, device=device) + else: + attention_mask = torch.cat([torch.ones(input_modal_shape, device=device, dtype=torch.long), attention_mask], dim=1) + + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(input_shape, device=device) + else: + encoder_attention_mask = torch.cat([torch.ones(input_modal_shape, device=device), encoder_attention_mask], dim=1) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + if attention_mask.dim() == 3: + extended_attention_mask = attention_mask[:, None, :, :] + + # Provided a padding mask of dimensions [batch_size, seq_length] + # - if the model is a decoder, apply a causal mask in addition to the padding mask + # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] + if attention_mask.dim() == 2: + if self.config.is_decoder: + batch_size, seq_length = input_shape + seq_ids = torch.arange(seq_length, device=device) + causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None] + extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :] + else: + extended_attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + # If a 2D ou 3D attention mask is provided for the cross-attention + # we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length] + if encoder_attention_mask.dim() == 3: + encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :] + if encoder_attention_mask.dim() == 2: + encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :] + + encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility + encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0 + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] + # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] + if head_mask is not None: + if head_mask.dim() == 1: + head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1) + elif head_mask.dim() == 2: + head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer + head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility + else: + head_mask = [None] * self.config.num_hidden_layers + + + encoder_outputs = self.transformer.encoder(embedding_output, + attention_mask=extended_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask) + + sequence_output = encoder_outputs[0] + pooled_output = self.transformer.pooler(sequence_output) + + outputs = (sequence_output, pooled_output,) + encoder_outputs[1:] # add hidden_states and attentions if they are here + return outputs # sequence_output, pooled_output, (hidden_states), (attentions) + + + def get_input_embeddings(self): + return self.embeddings.word_embeddings + + def set_input_embeddings(self, value): + self.embeddings.word_embeddings = value + + +@add_start_docstrings("""MMBT Model with a sequence classification/regression head on top (a linear layer on top of + the pooled output)""", MMBT_START_DOCSTRING, MMBT_INPUTS_DOCSTRING) +class MMBTForClassification(nn.Module): + r""" + **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: + Labels for computing the sequence classification/regression loss. + Indices should be in ``[0, ..., config.num_labels - 1]``. + If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss), + If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy). + + Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: + **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: + Classification (or regression if config.num_labels==1) loss. + **logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)`` + Classification (or regression if config.num_labels==1) scores (before SoftMax). + **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) + list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) + of shape ``(batch_size, sequence_length, hidden_size)``: + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + **attentions**: (`optional`, returned when ``config.output_attentions=True``) + list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. + + Examples:: + + transformer = BertModel.from_pretrained('bert-base-uncased') + encoder = ImageEncoder(args) + model = MMBTForClassification(config, transformer, encoder) + outputs = model(input_modal, input_ids, labels=labels) + loss, logits = outputs[:2] + """ + + def __init__(self, config, transformer, encoder): + super(MMBTForClassification, self).__init__() + self.num_labels = config.num_labels + + self.mmbt = MMBTModel(config, transformer, encoder) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + def forward(self, input_modal, input_ids=None, modal_start_tokens=None, modal_end_tokens=None, + attention_mask=None, token_type_ids=None, modal_token_type_ids=None, position_ids=None, + modal_position_ids=None, head_mask=None, inputs_embeds=None, labels=None): + + outputs = self.mmbt(input_modal=input_modal, input_ids=input_ids, + modal_start_tokens=modal_start_tokens, + modal_end_tokens=modal_end_tokens, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + modal_token_type_ids=modal_token_type_ids, + position_ids=position_ids, + modal_position_ids=modal_position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds) + + pooled_output = outputs[1] + + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + + outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here + + if labels is not None: + if self.num_labels == 1: + # We are doing regression + loss_fct = MSELoss() + loss = loss_fct(logits.view(-1), labels.view(-1)) + else: + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + outputs = (loss,) + outputs + + return outputs # (loss), logits, (hidden_states), (attentions) \ No newline at end of file